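"""Gradio demo for an ILR Arabic reading-comprehension assistant.

Loads the Qwen/Qwen3-4B base model, applies the pcalhoun/ILR-Assistant-LoRA
adapter, and serves an adaptive ILR-level assessment chat.
"""
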
import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    LogitsProcessorList,
    LogitsProcessor,
)
from peft import PeftModel

# CONFIGURATION
CHECKPOINT_PATH = "pcalhoun/ILR-Assistant-LoRA"
MODEL_NAME = "Qwen/Qwen3-4B"
LOAD_IN_4BIT = True
MAX_NEW_TOKENS = 1024

ILR_LEVELS = ['1', '1+', '2', '2+', '3', '3+']

INITIAL_USER_MESSAGE_TEMPLATE = """ILR Level 1 (Elementary):
Reads very simple texts (e.g., tourist materials) with high-frequency vocabulary. Misunderstandings common; grasps basic ideas in familiar contexts.
ILR Level 1+ (Elementary+):
Handles simple announcements, headlines, or narratives. Can locate routine professional info but struggles with structure and cohesion.
ILR Level 2 (Limited Working):
Reads straightforward factual texts on familiar topics (e.g., news, basic reports). Understands main ideas but slowly; inferences are limited.
ILR Level 2+ (Limited Working+):
Comprehends most non-technical prose and concrete professional discussions. Separates main ideas from details but misses nuance.
ILR Level 3 (General Professional):
Reads diverse authentic texts (e.g., news, reports) with near-complete comprehension. Interprets implicit meaning but struggles with complex idioms.
ILR Level 3+ (General Professional+):
Handles varied professional styles with minimal errors. Understands cultural references and complex structures, though subtleties may be missed.
Initial ILR level for this conversation: {ilr_level}
Test my comprehension of Modern Standard Arabic."""

INITIAL_ASSISTANT_SCORER = "I am administering an ILR level assessment."

IM_START = "<|im_start|>"
IM_END = "<|im_end|>"
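# ChatML-style turn delimiters used by Qwen chat models; prompts are assembled
# manually below instead of via tokenizer.apply_chat_template.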

# Global variables
model = None
tokenizer = None

class BanTokensLogitsProcessor(LogitsProcessor):
    """Custom LogitsProcessor to completely ban specific tokens with proper device handling."""
    
    def __init__(self, tokenizer, banned_words, device):
        self.banned_token_ids = set()
        self.device = device
        
        # Get all possible token IDs for banned words
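        # BPE tokenizers encode "word" and " word" to different IDs,
        # so the space-padded variants must be banned as well.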
        for word in banned_words:
            variants = [word, f" {word}", f"{word} ", f" {word} "]
            for variant in variants:
                try:
                    token_ids = tokenizer.encode(variant, add_special_tokens=False)
                    self.banned_token_ids.update(token_ids)
                except Exception as e:
                    print(f"Warning: Could not encode variant '{variant}': {e}")
        
        print(f"Banned token IDs: {self.banned_token_ids}")
        print(f"LogitsProcessor device: {self.device}")
    
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Set logits of banned tokens to negative infinity
        for token_id in self.banned_token_ids:
            if token_id < scores.shape[-1]:  # Safety check
                scores[:, token_id] = float('-inf')
        return scores

def get_banned_token_ids(tokenizer, bad_words):
    """Get token IDs for words that should be banned using bad_words_ids format."""
    bad_words_ids = []
    for word in bad_words:
        # Try different variations to handle tokenization edge cases
        variants = [
            word,                    # exact word
            f" {word}",             # with leading space
            f"{word} ",             # with trailing space
            f" {word} "             # with both spaces
        ]
        
        for variant in variants:
            try:
                token_ids = tokenizer.encode(variant, add_special_tokens=False)
                if token_ids:  # Only add if tokenization succeeded
                    bad_words_ids.append(token_ids)
            except Exception as e:
                print(f"Warning: Could not encode variant '{variant}': {e}")
    
    return bad_words_ids
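
# Note: transformers' bad_words_ids entries are banned as whole token sequences
# (NoBadWordsLogitsProcessor): single-token entries are banned outright, while
# multi-token variants are only blocked when the full sequence would complete.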

def load_model_and_tokenizer():
    """Load the base model with LoRA adapter."""
    global model, tokenizer
    
    if model is not None and tokenizer is not None:
        return model, tokenizer
    
    print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
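    # Some causal-LM tokenizers define no pad token; fall back to EOS for padding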
    tokenizer.pad_token = tokenizer.eos_token
    
    # Load base model with quantization
    if LOAD_IN_4BIT and torch.cuda.is_available():
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
    
    # Load the LoRA adapter on top of the base model
    model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH)
    
    model.eval()
    print("βœ“ Model and LoRA adapter loaded successfully")
    print(f"βœ“ Model device: {next(model.parameters()).device}")
    return model, tokenizer

def debug_tokenization(tokenizer, words):
    """Debug tokenization of specific words."""
    print("=== TOKENIZATION DEBUG ===")
    for word in words:
        variants = [word, f" {word}", f"{word} ", f" {word} "]
        for variant in variants:
            try:
                token_ids = tokenizer.encode(variant, add_special_tokens=False)
                tokens = tokenizer.tokenize(variant)
                print(f"'{variant}' -> IDs: {token_ids}, Tokens: {tokens}")
            except Exception as e:
                print(f"Error tokenizing '{variant}': {e}")
    print("=========================")

def text_completion(prompt):
    """Enhanced text completion with comprehensive token banning."""
    try:
        model, tokenizer = load_model_and_tokenizer()
        
        # Print the full prompt to CLI
        print("=" * 80)
        print("FULL PROMPT:")
        print("=" * 80)
        print(prompt)
        print("=" * 80)
        
        # Get model device
        model_device = next(model.parameters()).device
        print(f"Model device: {model_device}")
        
        # Method 1: bad_words_ids
        banned_words = ["<think>", "</think>"]
        bad_words_ids = get_banned_token_ids(tokenizer, banned_words)
        print(f"Bad words IDs: {bad_words_ids}")
        
        # Method 2: Custom LogitsProcessor with proper device handling
        ban_processor = BanTokensLogitsProcessor(tokenizer, banned_words, model_device)
        logits_processor = LogitsProcessorList([ban_processor])
        
        # Debug tokenization (run once to see how tokens are encoded)
        # debug_tokenization(tokenizer, banned_words)
        
        inputs = tokenizer(prompt, return_tensors="pt").to(model_device)
        print(f"Input device: {inputs['input_ids'].device}")
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.6,
                top_p=0.95,
                top_k=20,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                bad_words_ids=bad_words_ids,        # Method 1: ban <think> token sequences
                logits_processor=logits_processor,  # Method 2: hard-ban individual token IDs
            )
        
        # Decode only the newly generated tokens (slice off the prompt)
        completion = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)
        
        # Print the raw response to CLI
        print("RAW MODEL OUTPUT:")
        print("=" * 80)
        print(completion)
        print("=" * 80)
        
        # Clean up the response - stop at first IM_END token
        if IM_END in completion:
            completion = completion.split(IM_END)[0]
        
        return completion.strip()
        
    except Exception as e:
        error_msg = f"Error generating completion: {str(e)}"
        print(error_msg)
        print(f"Exception type: {type(e)}")
        import traceback
        traceback.print_exc()
        return error_msg

def format_message_for_display(content, role):
    """Format a message for display in the Gradio interface.

    Chat-control tokens are never stored in message content, so nothing needs
    stripping here; <scorer> blocks are intentionally left visible. The role
    parameter is kept for future role-specific formatting.
    """
    return content

def build_chat_prompt(messages):
    """Build the full chat prompt with proper tokens for model generation."""
    prompt = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        
        if role == "user":
            prompt += f"{IM_START}user\n{content}{IM_END}\n"
        elif role == "assistant":
            if msg.get("complete", False):
                # Complete message with IM_END
                prompt += f"{IM_START}assistant\n{content}{IM_END}\n"
            else:
                # Incomplete message for generation
                prompt += f"{IM_START}assistant\n{content}"
    
    print("BUILT CHAT PROMPT:")
    print("=" * 60)
    print(prompt)
    print("=" * 60)
    
    return prompt
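
# Illustrative shape of the prompt build_chat_prompt produces:
#   <|im_start|>user
#   ...user message...<|im_end|>
#   <|im_start|>assistant
#   <scorer>
#   ...   <- generation continues here while the turn is marked incomplete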

def initialize_conversation(ilr_level):
    """Initialize a new conversation with the given ILR level."""
    print(f"πŸ”„ Initializing conversation at ILR level: {ilr_level}")
    
    # Create initial messages
    initial_user_content = INITIAL_USER_MESSAGE_TEMPLATE.format(ilr_level=ilr_level)
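    # Pre-seed the assistant turn with a closed <scorer> block; the turn is
    # marked incomplete so the model keeps generating after it.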
    initial_assistant_content = f"<scorer>\n{INITIAL_ASSISTANT_SCORER}\n</scorer>\n"
    
    messages = [
        {"role": "user", "content": initial_user_content, "complete": True},
        {"role": "assistant", "content": initial_assistant_content, "complete": False}
    ]
    
    # Generate the initial assistant response
    prompt = build_chat_prompt(messages)
    response = text_completion(prompt)
    
    # Update the assistant message with the complete response
    messages[-1]["content"] = initial_assistant_content + response
    messages[-1]["complete"] = True
    
    # Convert to display format for Gradio
    display_history = []
    display_history.append([
        format_message_for_display(initial_user_content, "user"),
        format_message_for_display(messages[-1]["content"], "assistant")
    ])
    
    # Format raw output for display
    raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
    
    return display_history, messages, raw_output

def send_message(user_input, chat_history, messages, ilr_level):
    """Handle sending a user message and generating assistant response."""
    if not user_input.strip():
        return chat_history, "", messages, ""
    
    print("πŸ“ SENDING MESSAGE:")
    print("=" * 60)
    print(f"User Input: {repr(user_input)}")
    print(f"Current Messages: {len(messages)}")
    print("=" * 60)
    
    # Add user message
    messages.append({"role": "user", "content": user_input, "complete": True})
    
    # Open a <scorer> block; the model completes it with its assessment
    # before producing the visible reply
    assistant_start = "<scorer>\n"
    messages.append({"role": "assistant", "content": assistant_start, "complete": False})
    
    # Generate assistant response
    prompt = build_chat_prompt(messages)
    response = text_completion(prompt)
    
    # Complete the assistant message
    full_assistant_content = assistant_start + response
    messages[-1]["content"] = full_assistant_content
    messages[-1]["complete"] = True
    
    # Update chat history for display
    chat_history.append([
        format_message_for_display(user_input, "user"),
        format_message_for_display(full_assistant_content, "assistant")
    ])
    
    # Format raw output for display
    raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
    
    return chat_history, "", messages, raw_output

def reset_conversation(ilr_level):
    """Reset the conversation with a new ILR level."""
    chat_history, messages, raw_output = initialize_conversation(ilr_level)
    return chat_history, messages, raw_output

def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="ILR Arabic Assistant", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# πŸ‡ΈπŸ‡¦ ILR Arabic Assistant")
        
        # State to store messages
        messages_state = gr.State([])
        
        with gr.Row():
            with gr.Column(scale=1):
                ilr_level = gr.Dropdown(
                    choices=ILR_LEVELS,
                    value="2+",
                    label="ILR Level",
                    info="Select your proficiency level"
                )
                
                reset_btn = gr.Button(
                    "πŸ”„ Reset Conversation",
                    variant="primary"
                )
                
                gr.Markdown("""The ILR Assistant generates Arabic reading comprehension assessments that adapt to your performance level. It presents Arabic passages with questions and automatically adjusts difficulty based on your responses - moving to easier content when you struggle or maintaining challenge when you succeed. The system was trained on authentic Arabic learning materials from the Defense Language Institute using the official ILR (Interagency Language Roundtable) proficiency scale. Try it out to see how AI can create personalized language assessments that respond to your Arabic reading comprehension skills.

                ### ILR Levels:
                - **1**: Elementary
                - **1+**: Elementary+  
                - **2**: Limited Working
                - **2+**: Limited Working+
                - **3**: General Professional
                - **3+**: General Professional+
                """)
            
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_copy_button=True,
                    avatar_images=("πŸ‘€", "πŸ€–"),
                )
                
                with gr.Row():
                    msg = gr.Textbox(
                        label="Your message",
                        placeholder="Type your response in English...",
                        scale=4
                    )
                    send_btn = gr.Button("πŸ“€ Send", scale=1, variant="primary")
                
                # Raw output display
                raw_output_display = gr.Textbox(
                    label="Raw Model Output",
                    lines=10,
                    max_lines=20,
                    interactive=False,
                    show_copy_button=True,
                    autoscroll=True,
                    placeholder="Raw model output will appear here...",
                )
        
        # Event handlers
        def handle_reset(level):
            return reset_conversation(level)
        
        def handle_send(user_input, chat_history, messages, level):
            return send_message(user_input, chat_history, messages, level)
        
        reset_btn.click(
            handle_reset,
            inputs=[ilr_level],
            outputs=[chatbot, messages_state, raw_output_display]
        )
        
        send_btn.click(
            handle_send,
            inputs=[msg, chatbot, messages_state, ilr_level],
            outputs=[chatbot, msg, messages_state, raw_output_display]
        )
        
        msg.submit(
            handle_send,
            inputs=[msg, chatbot, messages_state, ilr_level],
            outputs=[chatbot, msg, messages_state, raw_output_display]
        )

        # Initialize conversation on load
        def on_load(level):
            chat_history, messages, raw_output = initialize_conversation(level)
            return chat_history, messages, raw_output
        
        demo.load(
            on_load,
            inputs=[ilr_level],
            outputs=[chatbot, messages_state, raw_output_display]
        )
    
    return demo

if __name__ == "__main__":
    demo = create_interface()
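    # Warm-load the model before serving so the first request is not blocked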
    load_model_and_tokenizer()
    demo.launch()