pcalhoun committed
Commit 2c5d3c9 · verified · 1 Parent(s): 180920d

Create app.py

Files changed (1)
  1. app.py +355 -0
app.py ADDED
@@ -0,0 +1,355 @@
+ import os
+ import torch
+ import gradio as gr
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+ )
+ from peft import PeftModel
+
+ # CONFIGURATION
+ CHECKPOINT_PATH = "pcalhoun/ILR-Assistant"
+ MODEL_NAME = "Qwen/Qwen3-4B"
+ LOAD_IN_4BIT = True
+ MAX_NEW_TOKENS = 1024
+
+ ILR_LEVELS = ['1', '1+', '2', '2+', '3', '3+']
+
+ INITIAL_USER_MESSAGE_TEMPLATE = """ILR Level 1 (Elementary):
+ Reads very simple texts (e.g., tourist materials) with high-frequency vocabulary. Misunderstandings common; grasps basic ideas in familiar contexts.
+ ILR Level 1+ (Elementary+):
+ Handles simple announcements, headlines, or narratives. Can locate routine professional info but struggles with structure and cohesion.
+ ILR Level 2 (Limited Working):
+ Reads straightforward factual texts on familiar topics (e.g., news, basic reports). Understands main ideas but slowly; inferences are limited.
+ ILR Level 2+ (Limited Working+):
+ Comprehends most non-technical prose and concrete professional discussions. Separates main ideas from details but misses nuance.
+ ILR Level 3 (General Professional):
+ Reads diverse authentic texts (e.g., news, reports) with near-complete comprehension. Interprets implicit meaning but struggles with complex idioms.
+ ILR Level 3+ (General Professional+):
+ Handles varied professional styles with minimal errors. Understands cultural references and complex structures, though subtleties may be missed.
+ Initial ILR level for this conversation: {ilr_level}
+ Test my comprehension of Modern Standard Arabic."""
+
+ INITIAL_ASSISTANT_SCORER = "I am administering an ILR level assessment."
+
+ IM_START = "<|im_start|>"
+ IM_END = "<|im_end|>"
+
+ # Global variables
+ model = None
+ tokenizer = None
+
+ def load_model_and_tokenizer():
+     """Load the base model with LoRA adapter."""
+     global model, tokenizer
+
+     if model is not None and tokenizer is not None:
+         return model, tokenizer
+
+     print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")
+
+     # Load tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+     tokenizer.pad_token = tokenizer.eos_token
+
+     # Load base model with quantization
+     if LOAD_IN_4BIT and torch.cuda.is_available():
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16,
+         )
+         base_model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             quantization_config=bnb_config,
+             device_map="auto",
+             trust_remote_code=True,
+         )
+     else:
+         base_model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             trust_remote_code=True,
+         )
+
+     # Load LoRA adapter if checkpoint exists
+     if os.path.exists(CHECKPOINT_PATH):
+         model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH)
+     else:
+         print("Warning: Checkpoint path not found, using base model only")
+         model = base_model
+
+     model.eval()
+     print("✓ Model and LoRA adapter loaded successfully")
+     return model, tokenizer
+
+ def text_completion(prompt):
+     """
+     Generate text completion for the given prompt.
+
+     Args:
+         prompt (str): The input prompt text
+
+     Returns:
+         str: The generated completion text
+     """
+     try:
+         model, tokenizer = load_model_and_tokenizer()
+
+         # Print the full prompt to CLI
+         print("=" * 80)
+         print("FULL PROMPT:")
+         print("=" * 80)
+         print(prompt)
+         print("=" * 80)
+
+         # Tokenize
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         # Generate with stricter stopping conditions
+         with torch.no_grad():
+             output = model.generate(
+                 **inputs,
+                 max_new_tokens=MAX_NEW_TOKENS,
+                 temperature=0.6,
+                 top_p=0.95,
+                 top_k=20,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 stopping_criteria=None,
+             )
+
+         # Decode response
+         completion = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)
+
+         # Print the raw response to CLI
+         print("RAW MODEL OUTPUT:")
+         print("=" * 80)
+         print(completion)
+         print("=" * 80)
+
+         # Clean up the response - stop at first IM_END token
+         if IM_END in completion:
+             completion = completion.split(IM_END)[0]
+
+         return completion.strip()
+
+     except Exception as e:
+         error_msg = f"Error generating completion: {str(e)}"
+         print(error_msg)
+         return error_msg
+
+ def format_message_for_display(content, role):
+     """Format a message for display in the Gradio interface (remove chat tokens but keep scorer content)."""
+     if role == "user":
+         return content
+     elif role == "assistant":
+         # Keep the <scorer> content visible but remove chat tokens
+         return content
+     return content
+
+ def build_chat_prompt(messages):
+     """Build the full chat prompt with proper tokens for model generation."""
+     prompt = ""
+     for msg in messages:
+         role = msg["role"]
+         content = msg["content"]
+
+         if role == "user":
+             prompt += f"{IM_START}user\n{content}{IM_END}\n"
+         elif role == "assistant":
+             if msg.get("complete", False):
+                 # Complete message with IM_END
+                 prompt += f"{IM_START}assistant\n{content}{IM_END}\n"
+             else:
+                 # Incomplete message for generation
+                 prompt += f"{IM_START}assistant\n{content}"
+
+     print("BUILT CHAT PROMPT:")
+     print("=" * 60)
+     print(prompt)
+     print("=" * 60)
+
+     return prompt
+
+ def initialize_conversation(ilr_level):
+     """Initialize a new conversation with the given ILR level."""
+     print(f"🔄 Initializing conversation at ILR level: {ilr_level}")
+
+     # Create initial messages
+     initial_user_content = INITIAL_USER_MESSAGE_TEMPLATE.format(ilr_level=ilr_level)
+     initial_assistant_content = f"<scorer>\n{INITIAL_ASSISTANT_SCORER}\n</scorer>\n"
+
+     messages = [
+         {"role": "user", "content": initial_user_content, "complete": True},
+         {"role": "assistant", "content": initial_assistant_content, "complete": False}
+     ]
+
+     # Generate the initial assistant response
+     prompt = build_chat_prompt(messages)
+     response = text_completion(prompt)
+
+     # Update the assistant message with the complete response
+     messages[-1]["content"] = initial_assistant_content + response
+     messages[-1]["complete"] = True
+
+     # Convert to display format for Gradio
+     display_history = []
+     display_history.append([
+         format_message_for_display(initial_user_content, "user"),
+         format_message_for_display(messages[-1]["content"], "assistant")
+     ])
+
+     # Format raw output for display
+     raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
+
+     return display_history, messages, raw_output
+
+ def send_message(user_input, chat_history, messages, ilr_level):
+     """Handle sending a user message and generating assistant response."""
+     if not user_input.strip():
+         return chat_history, "", messages, ""
+
+     print("📝 SENDING MESSAGE:")
+     print("=" * 60)
+     print(f"User Input: {repr(user_input)}")
+     print(f"Current Messages: {len(messages)}")
+     print("=" * 60)
+
+     # Add user message
+     messages.append({"role": "user", "content": user_input, "complete": True})
+
+     # Start assistant response with scorer tag
+     assistant_start = "<scorer>\n"
+     messages.append({"role": "assistant", "content": assistant_start, "complete": False})
+
+     # Generate assistant response
+     prompt = build_chat_prompt(messages)
+     response = text_completion(prompt)
+
+     # Complete the assistant message
+     full_assistant_content = assistant_start + response
+     messages[-1]["content"] = full_assistant_content
+     messages[-1]["complete"] = True
+
+     # Update chat history for display
+     chat_history.append([
+         format_message_for_display(user_input, "user"),
+         format_message_for_display(full_assistant_content, "assistant")
+     ])
+
+     # Format raw output for display
+     raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
+
+     return chat_history, "", messages, raw_output
+
+ def reset_conversation(ilr_level):
+     """Reset the conversation with a new ILR level."""
+     chat_history, messages, raw_output = initialize_conversation(ilr_level)
+     return chat_history, messages, raw_output
+
+ def create_interface():
+     """Create the Gradio interface."""
+     with gr.Blocks(title="ILR Arabic Assistant", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🇸🇦 ILR Arabic Assistant")
+
+         # State to store messages
+         messages_state = gr.State([])
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 ilr_level = gr.Dropdown(
+                     choices=ILR_LEVELS,
+                     value="2+",
+                     label="ILR Level",
+                     info="Select your proficiency level"
+                 )
+
+                 reset_btn = gr.Button(
+                     "🔄 Reset Conversation",
+                     variant="primary"
+                 )
+
+                 gr.Markdown("""
+ ### ILR Levels:
+ - **1**: Elementary
+ - **1+**: Elementary+
+ - **2**: Limited Working
+ - **2+**: Limited Working+
+ - **3**: General Professional
+ - **3+**: General Professional+
+ """)
+
+             with gr.Column(scale=3):
+                 chatbot = gr.Chatbot(
+                     label="Conversation",
+                     height=500,
+                     show_copy_button=True,
+                     avatar_images=("👤", "🤖"),
+                 )
+
+                 with gr.Row():
+                     msg = gr.Textbox(
+                         label="Your message",
+                         placeholder="Type your response in English...",
+                         scale=4
+                     )
+                     send_btn = gr.Button("📤 Send", scale=1, variant="primary")
+
+                 # Raw output display
+                 raw_output_display = gr.Textbox(
+                     label="Raw Model Output",
+                     lines=10,
+                     max_lines=20,
+                     interactive=False,
+                     show_copy_button=True,
+                     autoscroll=True,
+                     placeholder="Raw model output will appear here...",
+                 )
+
+         # Event handlers
+         def handle_reset(level):
+             return reset_conversation(level)
+
+         def handle_send(user_input, chat_history, messages, level):
+             return send_message(user_input, chat_history, messages, level)
+
+         reset_btn.click(
+             handle_reset,
+             inputs=[ilr_level],
+             outputs=[chatbot, messages_state, raw_output_display]
+         )
+
+         send_btn.click(
+             handle_send,
+             inputs=[msg, chatbot, messages_state, ilr_level],
+             outputs=[chatbot, msg, messages_state, raw_output_display]
+         )
+
+         msg.submit(
+             handle_send,
+             inputs=[msg, chatbot, messages_state, ilr_level],
+             outputs=[chatbot, msg, messages_state, raw_output_display]
+         )
+
+         # Initialize conversation on load
+         def on_load(level):
+             chat_history, messages, raw_output = initialize_conversation(level)
+             return chat_history, messages, raw_output
+
+         demo.load(
+             on_load,
+             inputs=[ilr_level],
+             outputs=[chatbot, messages_state, raw_output_display]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     load_model_and_tokenizer()
+     demo.launch()
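
For a quick check of the generation path outside the Gradio UI, a minimal sketch like the one below could be used. It is not part of the commit: it assumes app.py is importable in the same environment and that the base model (and, if available, the adapter) can be loaded; the ILR level "2+" is purely illustrative.

# smoke_test.py -- hypothetical helper, not included in this commit
from app import (
    INITIAL_USER_MESSAGE_TEMPLATE,
    build_chat_prompt,
    text_completion,
)

# Rebuild the opening exchange that initialize_conversation() constructs:
# a complete user turn plus an open assistant turn starting with <scorer>.
messages = [
    {
        "role": "user",
        "content": INITIAL_USER_MESSAGE_TEMPLATE.format(ilr_level="2+"),
        "complete": True,
    },
    # The incomplete assistant turn is left open so the model continues it.
    {"role": "assistant", "content": "<scorer>\n", "complete": False},
]

# text_completion() lazily loads the model/adapter on first call.
print(text_completion(build_chat_prompt(messages)))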