Spaces:
Sleeping
Sleeping
File size: 16,350 Bytes
2c5d3c9 c6632f8 2c5d3c9 61ed7bb 2c5d3c9 c6632f8 2c5d3c9 8d4d042 2c5d3c9 c6632f8 2c5d3c9 c6632f8 2c5d3c9 c6632f8 2c5d3c9 c6632f8 2c5d3c9 c6632f8 2c5d3c9 c6632f8 2c5d3c9 8cb77ff 2c5d3c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 |
import os
import torch
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
LogitsProcessorList,
LogitsProcessor,
)
from peft import PeftModel
# CONFIGURATION
CHECKPOINT_PATH = "pcalhoun/ILR-Assistant-LoRA"  # Hub repo holding the LoRA adapter weights
MODEL_NAME = "Qwen/Qwen3-4B"  # base model the adapter is applied to
LOAD_IN_4BIT = True  # use bitsandbytes NF4 quantization when CUDA is available
MAX_NEW_TOKENS = 1024  # generation budget per assistant turn
ILR_LEVELS = ['1', '1+', '2', '2+', '3', '3+']  # choices offered in the level dropdown
# First user turn of every conversation: describes the ILR scale to the model
# and sets the starting level via the {ilr_level} placeholder.
INITIAL_USER_MESSAGE_TEMPLATE = """ILR Level 1 (Elementary):
Reads very simple texts (e.g., tourist materials) with high-frequency vocabulary. Misunderstandings common; grasps basic ideas in familiar contexts.
ILR Level 1+ (Elementary+):
Handles simple announcements, headlines, or narratives. Can locate routine professional info but struggles with structure and cohesion.
ILR Level 2 (Limited Working):
Reads straightforward factual texts on familiar topics (e.g., news, basic reports). Understands main ideas but slowly; inferences are limited.
ILR Level 2+ (Limited Working+):
Comprehends most non-technical prose and concrete professional discussions. Separates main ideas from details but misses nuance.
ILR Level 3 (General Professional):
Reads diverse authentic texts (e.g., news, reports) with near-complete comprehension. Interprets implicit meaning but struggles with complex idioms.
ILR Level 3+ (General Professional+):
Handles varied professional styles with minimal errors. Understands cultural references and complex structures, though subtleties may be missed.
Initial ILR level for this conversation: {ilr_level}
Test my comprehension of Modern Standard Arabic."""
# Seed text placed inside the assistant's <scorer> block on the first turn.
INITIAL_ASSISTANT_SCORER = "I am administering an ILR level assessment."
# ChatML delimiters used when building raw prompts by hand (see build_chat_prompt).
IM_START = "<|im_start|>"
IM_END = "<|im_end|>"
# Global variables
model = None  # lazily populated by load_model_and_tokenizer()
tokenizer = None  # lazily populated by load_model_and_tokenizer()
class BanTokensLogitsProcessor(LogitsProcessor):
    """LogitsProcessor that drives the listed words' token logits to -inf.

    Token IDs are collected for each word plus leading/trailing-space
    variants, since many tokenizers encode " word" differently from "word".
    """

    def __init__(self, tokenizer, banned_words, device):
        self.banned_token_ids = set()
        self.device = device
        for word in banned_words:
            for form in (word, f" {word}", f"{word} ", f" {word} "):
                try:
                    ids = tokenizer.encode(form, add_special_tokens=False)
                    self.banned_token_ids.update(ids)
                except Exception as e:
                    print(f"Warning: Could not encode variant '{form}': {e}")
        print(f"Banned token IDs: {self.banned_token_ids}")
        print(f"LogitsProcessor device: {self.device}")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        vocab_size = scores.shape[-1]
        for banned_id in self.banned_token_ids:
            # Guard against ids beyond the logits dimension before masking.
            if banned_id < vocab_size:
                scores[:, banned_id] = float('-inf')
        return scores
def get_banned_token_ids(tokenizer, bad_words):
    """Return token-ID sequences for *bad_words* in ``bad_words_ids`` format.

    Each word is encoded in four whitespace variants (bare, leading space,
    trailing space, both) because tokenizers often assign different IDs to
    " word" and "word".

    Args:
        tokenizer: object with an ``encode(text, add_special_tokens=False)``
            method returning a list of ints.
        bad_words: iterable of strings to ban during generation.

    Returns:
        List of token-ID lists suitable for ``generate(bad_words_ids=...)``.
        Duplicate sequences (common when the tokenizer normalizes whitespace)
        are dropped so ``generate`` does not receive redundant constraints.
    """
    bad_words_ids = []
    seen = set()  # tuples of already-collected sequences, for dedup
    for word in bad_words:
        for variant in (word, f" {word}", f"{word} ", f" {word} "):
            try:
                token_ids = tokenizer.encode(variant, add_special_tokens=False)
            except Exception as e:
                print(f"Warning: Could not encode variant '{variant}': {e}")
                continue
            key = tuple(token_ids)
            if token_ids and key not in seen:
                seen.add(key)
                bad_words_ids.append(token_ids)
    return bad_words_ids
def load_model_and_tokenizer():
    """Load (once) the base model, optional 4-bit quantization, and LoRA adapter.

    Results are cached in the module-level ``model``/``tokenizer`` globals so
    repeated calls are cheap.
    """
    global model, tokenizer
    # Fast path: everything already loaded.
    if model is not None and tokenizer is not None:
        return model, tokenizer
    print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    # Build the from_pretrained kwargs once; branch only on quantization.
    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if LOAD_IN_4BIT and torch.cuda.is_available():
        # NF4 + double quantization keeps GPU memory usage low.
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        load_kwargs["torch_dtype"] = torch.bfloat16
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
    # Attach the LoRA adapter on top of the (possibly quantized) base model.
    model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH)
    model.eval()
    print("β Model and LoRA adapter loaded successfully")
    print(f"β Model device: {next(model.parameters()).device}")
    return model, tokenizer
def debug_tokenization(tokenizer, words):
    """Print token IDs and token pieces for *words* and whitespace variants."""
    print("=== TOKENIZATION DEBUG ===")
    for word in words:
        for form in (word, f" {word}", f"{word} ", f" {word} "):
            try:
                ids = tokenizer.encode(form, add_special_tokens=False)
                pieces = tokenizer.tokenize(form)
                print(f"'{form}' -> IDs: {ids}, Tokens: {pieces}")
            except Exception as e:
                print(f"Error tokenizing '{form}': {e}")
    print("=========================")
def text_completion(prompt):
    """Generate a completion for *prompt* with <think> tokens banned.

    Two complementary banning mechanisms are applied during generation:
    ``bad_words_ids`` (sequence-level) and a custom LogitsProcessor
    (per-token masking with a vocab-bounds safety check).

    Args:
        prompt: fully-assembled ChatML prompt string.

    Returns:
        The generated text up to the first IM_END token, stripped; on any
        failure, an error-message string (the caller displays it verbatim).
    """
    try:
        model, tokenizer = load_model_and_tokenizer()
        # Print the full prompt to CLI
        print("=" * 80)
        print("FULL PROMPT:")
        print("=" * 80)
        print(prompt)
        print("=" * 80)
        # Get model device
        model_device = next(model.parameters()).device
        print(f"Model device: {model_device}")
        # Method 1: bad_words_ids
        banned_words = ["<think>", "</think>"]
        bad_words_ids = get_banned_token_ids(tokenizer, banned_words)
        print(f"Bad words IDs: {bad_words_ids}")
        # Method 2: Custom LogitsProcessor with proper device handling
        ban_processor = BanTokensLogitsProcessor(tokenizer, banned_words, model_device)
        logits_processor = LogitsProcessorList([ban_processor])
        # Debug tokenization (run once to see how tokens are encoded)
        # debug_tokenization(tokenizer, banned_words)
        inputs = tokenizer(prompt, return_tensors="pt").to(model_device)
        print(f"Input device: {inputs['input_ids'].device}")
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.6,
                top_p=0.95,
                top_k=20,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                bad_words_ids=bad_words_ids,  # Filter out <think> tokens
                # BUG FIX: the processor list was built but never passed to
                # generate(), so the custom token ban was dead code.
                logits_processor=logits_processor,
            )
        # Decode only the newly generated tokens (skip the prompt prefix).
        completion = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)
        # Print the raw response to CLI
        print("RAW MODEL OUTPUT:")
        print("=" * 80)
        print(completion)
        print("=" * 80)
        # Clean up the response - stop at first IM_END token
        if IM_END in completion:
            completion = completion.split(IM_END)[0]
        return completion.strip()
    except Exception as e:
        # Best-effort UX: surface the error as the "completion" text rather
        # than crashing the Gradio handler.
        error_msg = f"Error generating completion: {str(e)}"
        print(error_msg)
        print(f"Exception type: {type(e)}")
        import traceback
        traceback.print_exc()
        return error_msg
def format_message_for_display(content, role):
    """Return *content* as shown in the Gradio chat pane.

    Every role currently displays its raw content unchanged: user text is
    shown as-is and assistant <scorer> sections are intentionally kept
    visible. The *role* parameter is retained for the existing call sites
    and future role-specific formatting.
    """
    # All branches of the role dispatch collapse to the same result.
    return content
def build_chat_prompt(messages):
    """Assemble the raw ChatML prompt string from the message dicts.

    Messages carry a ``complete`` flag: a complete assistant turn is closed
    with IM_END, while an incomplete one is left open so the model will
    continue it during generation.
    """
    parts = []
    for message in messages:
        role = message["role"]
        body = message["content"]
        if role == "user":
            parts.append(f"{IM_START}user\n{body}{IM_END}\n")
        elif role == "assistant":
            if message.get("complete", False):
                # Finished turn: terminated with the end-of-message token.
                parts.append(f"{IM_START}assistant\n{body}{IM_END}\n")
            else:
                # Open turn: unterminated so generation continues from here.
                parts.append(f"{IM_START}assistant\n{body}")
    prompt = "".join(parts)
    print("BUILT CHAT PROMPT:")
    print("=" * 60)
    print(prompt)
    print("=" * 60)
    return prompt
def initialize_conversation(ilr_level):
    """Start a fresh assessment conversation at *ilr_level*.

    Returns a (display_history, messages, raw_output) triple: the Gradio
    chatbot pairs, the full internal message list, and the raw model text
    framed for the debug textbox.
    """
    print(f"π Initializing conversation at ILR level: {ilr_level}")
    user_msg = INITIAL_USER_MESSAGE_TEMPLATE.format(ilr_level=ilr_level)
    # Seed the assistant turn with an opened-and-closed <scorer> block.
    assistant_seed = f"<scorer>\n{INITIAL_ASSISTANT_SCORER}\n</scorer>\n"
    messages = [
        {"role": "user", "content": user_msg, "complete": True},
        {"role": "assistant", "content": assistant_seed, "complete": False},
    ]
    # Let the model finish the seeded (incomplete) assistant turn.
    response = text_completion(build_chat_prompt(messages))
    messages[-1]["content"] = assistant_seed + response
    messages[-1]["complete"] = True
    # Chatbot display format: list of [user, assistant] pairs.
    display_history = [[
        format_message_for_display(user_msg, "user"),
        format_message_for_display(messages[-1]["content"], "assistant"),
    ]]
    raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
    return display_history, messages, raw_output
def send_message(user_input, chat_history, messages, ilr_level):
    """Append the user's turn, generate the assistant reply, update displays.

    Returns (chat_history, textbox_value, messages, raw_output); the empty
    textbox value clears the input field after sending. *ilr_level* is part
    of the Gradio handler signature but is not consulted here — the level
    only matters at conversation initialization.
    """
    if not user_input.strip():
        # Ignore blank submissions; just clear the textbox.
        return chat_history, "", messages, ""
    print("π SENDING MESSAGE:")
    print("=" * 60)
    print(f"User Input: {repr(user_input)}")
    print(f"Current Messages: {len(messages)}")
    print("=" * 60)
    messages.append({"role": "user", "content": user_input, "complete": True})
    # Seed an open assistant turn so the model responds inside <scorer>.
    scorer_prefix = "<scorer>\n"
    messages.append({"role": "assistant", "content": scorer_prefix, "complete": False})
    response = text_completion(build_chat_prompt(messages))
    assistant_text = scorer_prefix + response
    messages[-1]["content"] = assistant_text
    messages[-1]["complete"] = True
    chat_history.append([
        format_message_for_display(user_input, "user"),
        format_message_for_display(assistant_text, "assistant"),
    ])
    raw_output = f"RAW MODEL OUTPUT:\n{'=' * 80}\n{response}\n{'=' * 80}"
    return chat_history, "", messages, raw_output
def reset_conversation(ilr_level):
    """Discard the current state and begin a new conversation at *ilr_level*."""
    # initialize_conversation already returns the exact
    # (chat_history, messages, raw_output) triple the handler needs.
    return initialize_conversation(ilr_level)
def create_interface():
    """Create the Gradio interface.

    Layout: a narrow left column (ILR level dropdown, reset button, help
    text) beside a wide right column (chatbot, input row, raw-output panel).
    Wires reset/send/submit/load events to the conversation handlers.
    NOTE(review): nesting of the `with` blocks was reconstructed from a
    source whose indentation was stripped — confirm against the original.
    """
    with gr.Blocks(title="ILR Arabic Assistant", theme=gr.themes.Soft()) as demo:
        # NOTE(review): heading text looks mojibake'd (likely a lost emoji/flag) — confirm encoding.
        gr.Markdown("# πΈπ¦ ILR Arabic Assistant")
        # State to store messages: the full internal message list (with
        # `complete` flags), kept per-session alongside the display-only chatbot.
        messages_state = gr.State([])
        with gr.Row():
            with gr.Column(scale=1):
                ilr_level = gr.Dropdown(
                    choices=ILR_LEVELS,
                    value="2+",
                    label="ILR Level",
                    info="Select your proficiency level"
                )
                reset_btn = gr.Button(
                    "π Reset Conversation",
                    variant="primary"
                )
                gr.Markdown("""The ILR Assistant generates Arabic reading comprehension assessments that adapt to your performance level. It presents Arabic passages with questions and automatically adjusts difficulty based on your responses - moving to easier content when you struggle or maintaining challenge when you succeed. The system was trained on authentic Arabic learning materials from the Defense Language Institute using the official ILR (Interagency Language Roundtable) proficiency scale. Try it out to see how AI can create personalized language assessments that respond to your Arabic reading comprehension skills.
### ILR Levels:
- **1**: Elementary
- **1+**: Elementary+
- **2**: Limited Working
- **2+**: Limited Working+
- **3**: General Professional
- **3+**: General Professional+
""")
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_copy_button=True,
                    avatar_images=("π€", "π€"),
                )
                with gr.Row():
                    msg = gr.Textbox(
                        label="Your message",
                        placeholder="Type your response in English...",
                        scale=4
                    )
                    send_btn = gr.Button("π€ Send", scale=1, variant="primary")
                # Raw output display: mirrors the model's unprocessed generation
                # for debugging, below the chat in the wide column.
                raw_output_display = gr.Textbox(
                    label="Raw Model Output",
                    lines=10,
                    max_lines=20,
                    interactive=False,
                    show_copy_button=True,
                    autoscroll=True,
                    placeholder="Raw model output will appear here...",
                )
        # Event handlers: thin wrappers so Gradio's input/output lists map
        # directly onto the module-level conversation functions.
        def handle_reset(level):
            return reset_conversation(level)
        def handle_send(user_input, chat_history, messages, level):
            return send_message(user_input, chat_history, messages, level)
        reset_btn.click(
            handle_reset,
            inputs=[ilr_level],
            outputs=[chatbot, messages_state, raw_output_display]
        )
        send_btn.click(
            handle_send,
            inputs=[msg, chatbot, messages_state, ilr_level],
            outputs=[chatbot, msg, messages_state, raw_output_display]
        )
        # Pressing Enter in the textbox behaves like clicking Send.
        msg.submit(
            handle_send,
            inputs=[msg, chatbot, messages_state, ilr_level],
            outputs=[chatbot, msg, messages_state, raw_output_display]
        )
        # Initialize conversation on load so the first assistant turn is
        # generated before the user types anything.
        def on_load(level):
            chat_history, messages, raw_output = initialize_conversation(level)
            return chat_history, messages, raw_output
        demo.load(
            on_load,
            inputs=[ilr_level],
            outputs=[chatbot, messages_state, raw_output_display]
        )
    return demo
if __name__ == "__main__":
    # Build the UI first, then warm the model so demo.load's initial
    # generation doesn't pay the full model-loading cost per first request.
    demo = create_interface()
    load_model_and_tokenizer()
    demo.launch()