jiminaa committed on
Commit
f6b9677
·
1 Parent(s): 7592cae

comparing base with finetuned

Browse files
Files changed (1) hide show
  1. main.py +128 -42
main.py CHANGED
@@ -22,8 +22,8 @@ MODEL = "meta-llama/Llama-3.2-1B-Instruct"
22
 
23
  app = FastAPI()
24
 
25
- # base model and tokenizer
26
- base_model = AutoModelForCausalLM.from_pretrained(
27
  MODEL,
28
  token=HF_TOKEN,
29
  dtype=torch.bfloat16, # faster than float32, matches GPU training
@@ -31,8 +31,19 @@ base_model = AutoModelForCausalLM.from_pretrained(
31
  low_cpu_mem_usage=True,
32
  attn_implementation="sdpa", # PyTorch optimized attention
33
  )
 
34
 
 
 
 
 
 
 
 
 
 
35
  base_model.config.use_cache = True
 
36
 
37
  tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
38
 
@@ -50,7 +61,7 @@ languages = list(adapter_paths.keys())
50
 
51
  # Create PeftModel with first adapter
52
  peft_model = PeftModel.from_pretrained(
53
- base_model,
54
  adapter_paths[languages[0]],
55
  adapter_name=languages[0],
56
  is_trainable=False
@@ -63,6 +74,44 @@ for lang in languages[1:]:
63
  peft_model.eval()
64
  print("All adapters ready.")
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # the input will be a list of messages that include system, user, and assistant prompts
67
  def generate_text_stream(messages, language, max_length=256, temperature=0.7):
68
 
@@ -249,14 +298,32 @@ async def chat_completions(request: ChatCompletionRequest):
249
  }
250
  )
251
 
252
- def chat_gradio(message, history, language, system_prompt, max_length, temperature):
 
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  messages = []
255
 
256
  if system_prompt:
257
  messages.append({"role": "system", "content": system_prompt})
258
 
259
- # only uses the last 10 messages to keep within context limit
260
  messages.extend(history[-10:])
261
 
262
  user_msg = {"role": "user", "content": message}
@@ -273,48 +340,57 @@ def chat_gradio(message, history, language, system_prompt, max_length, temperatu
273
  yield history + [user_msg, assistant_msg]
274
 
275
  with gr.Blocks(
276
- title="Language Learning Chatbot",
277
  theme=gr.themes.Soft()
278
  ) as demo:
279
 
280
  with gr.Row():
281
- with gr.Column(scale=2):
282
- chatbot = gr.Chatbot(
283
- label="Conversation",
284
- height=500,
 
285
  type="messages"
286
  )
287
 
288
- # User input
289
- with gr.Row():
290
- msg = gr.Textbox(
291
- label="Your message",
292
- placeholder="Type your message here and press Enter...",
293
- lines=2,
294
- scale=4
295
- )
296
 
297
- with gr.Row():
298
- submit_btn = gr.Button("Send", variant="primary", scale=1)
299
- clear_btn = gr.Button("Clear Chat", scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
- with gr.Column(scale=1):
302
- gr.Markdown("### ⚙️ Settings")
303
-
304
  language_dropdown = gr.Dropdown(
305
  choices=list(adapter_paths.keys()),
306
- label="Language",
307
  value=list(adapter_paths.keys())[0],
308
- info="Select the language model to use"
309
  )
310
 
311
  system_prompt_input = gr.Textbox(
312
  label="System Prompt (Optional)",
313
  placeholder="e.g., You are a helpful assistant...",
314
  lines=3,
315
- info="Set the assistant's behavior"
316
  )
317
-
318
  max_length_slider = gr.Slider(
319
  minimum=50,
320
  maximum=512,
@@ -333,31 +409,41 @@ with gr.Blocks(
333
  info="Higher = more creative"
334
  )
335
 
336
- # handling enter key in textbox
 
 
 
 
 
337
  msg.submit(
338
- fn=chat_gradio,
339
- inputs=[msg, chatbot, language_dropdown, system_prompt_input, max_length_slider, temperature_slider],
340
- outputs=[chatbot], # Update chatbot with streaming response
341
  ).then(
342
- fn=lambda: gr.update(value=""), # Clear input after sending
343
  outputs=[msg]
344
  )
345
 
346
- # Handle button click
347
  submit_btn.click(
348
- fn=chat_gradio,
349
- inputs=[msg, chatbot, language_dropdown, system_prompt_input, max_length_slider, temperature_slider],
350
- outputs=[chatbot],
 
 
 
 
 
351
  ).then(
352
  fn=lambda: gr.update(value=""),
353
  outputs=[msg]
354
  )
355
-
356
- # Clear chat button
357
  clear_btn.click(
358
- fn=lambda: None, # Return None to clear chatbot
359
- outputs=[chatbot],
360
- queue=False # Don't queue this action
361
  )
362
 
363
  demo.queue(False)
 
22
 
23
  app = FastAPI()
24
 
25
+ # base model for finetuned (LoRA) inference
26
+ finetuned_base = AutoModelForCausalLM.from_pretrained(
27
  MODEL,
28
  token=HF_TOKEN,
29
  dtype=torch.bfloat16, # faster than float32, matches GPU training
 
31
  low_cpu_mem_usage=True,
32
  attn_implementation="sdpa", # PyTorch optimized attention
33
  )
34
+ finetuned_base.config.use_cache = True
35
 
36
+ # separate base model for comparison (no LoRA)
37
+ base_model = AutoModelForCausalLM.from_pretrained(
38
+ MODEL,
39
+ token=HF_TOKEN,
40
+ dtype=torch.bfloat16,
41
+ device_map="cpu",
42
+ low_cpu_mem_usage=True,
43
+ attn_implementation="sdpa",
44
+ )
45
  base_model.config.use_cache = True
46
+ base_model.eval()
47
 
48
  tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
49
 
 
61
 
62
  # Create PeftModel with first adapter
63
  peft_model = PeftModel.from_pretrained(
64
+ finetuned_base,
65
  adapter_paths[languages[0]],
66
  adapter_name=languages[0],
67
  is_trainable=False
 
74
  peft_model.eval()
75
  print("All adapters ready.")
76
 
77
+ # base model generation (no LoRA)
78
+ def generate_base_model_stream(messages, max_length=256, temperature=0.7):
79
+
80
+ print(f"Base model (no LoRA)")
81
+ print(f"Messages: {messages}")
82
+
83
+ inputs = tokenizer.apply_chat_template(
84
+ messages,
85
+ tokenize=True,
86
+ add_generation_prompt=True,
87
+ return_tensors="pt",
88
+ return_dict=True
89
+ ).to(base_model.device)
90
+
91
+ streamer = TextIteratorStreamer(
92
+ tokenizer,
93
+ skip_prompt=True,
94
+ skip_special_tokens=True)
95
+
96
+ generation_kwargs = {
97
+ **inputs,
98
+ "max_new_tokens": max_length,
99
+ "temperature": temperature,
100
+ "do_sample": True,
101
+ "pad_token_id": tokenizer.eos_token_id,
102
+ "streamer": streamer,
103
+ "num_beams": 1,
104
+ "use_cache": True,
105
+ }
106
+
107
+ thread = Thread(target=base_model.generate, kwargs=generation_kwargs)
108
+ thread.start()
109
+
110
+ for text in streamer:
111
+ yield text
112
+
113
+ thread.join()
114
+
115
  # the input will be a list of messages that include system, user, and assistant prompts
116
  def generate_text_stream(messages, language, max_length=256, temperature=0.7):
117
 
 
298
  }
299
  )
300
 
301
+ def chat_base_model(message, history, system_prompt, max_length, temperature):
302
+ messages = []
303
 
304
+ if system_prompt:
305
+ messages.append({"role": "system", "content": system_prompt})
306
+
307
+ messages.extend(history[-10:])
308
+
309
+ user_msg = {"role": "user", "content": message}
310
+ messages.append(user_msg)
311
+
312
+ assistant_msg = {"role": "assistant", "content": ""}
313
+ for token in generate_base_model_stream(
314
+ messages,
315
+ max_length,
316
+ temperature
317
+ ):
318
+ assistant_msg["content"] += token
319
+ yield history + [user_msg, assistant_msg]
320
+
321
+ def chat_finetuned(message, history, language, system_prompt, max_length, temperature):
322
  messages = []
323
 
324
  if system_prompt:
325
  messages.append({"role": "system", "content": system_prompt})
326
 
 
327
  messages.extend(history[-10:])
328
 
329
  user_msg = {"role": "user", "content": message}
 
340
  yield history + [user_msg, assistant_msg]
341
 
342
  with gr.Blocks(
343
+ title="Language Learning Chatbot",
344
  theme=gr.themes.Soft()
345
  ) as demo:
346
 
347
  with gr.Row():
348
+ with gr.Column(scale=1):
349
+ gr.Markdown("### Base Model (No LoRA)")
350
+ chatbot_base = gr.Chatbot(
351
+ label="Base Model",
352
+ height=400,
353
  type="messages"
354
  )
355
 
356
+ with gr.Column(scale=1):
357
+ gr.Markdown("### Finetuned Model (LoRA)")
358
+ chatbot_finetuned = gr.Chatbot(
359
+ label="Finetuned Model",
360
+ height=400,
361
+ type="messages"
362
+ )
 
363
 
364
+ with gr.Row():
365
+ msg = gr.Textbox(
366
+ label="Your message",
367
+ placeholder="Type your message here and press Enter...",
368
+ lines=2,
369
+ scale=4
370
+ )
371
+
372
+ with gr.Row():
373
+ submit_btn = gr.Button("Send", variant="primary", scale=1)
374
+ clear_btn = gr.Button("Clear Both Chats", scale=1)
375
+
376
+ with gr.Row():
377
+ with gr.Column():
378
+ gr.Markdown("### Settings")
379
 
 
 
 
380
  language_dropdown = gr.Dropdown(
381
  choices=list(adapter_paths.keys()),
382
+ label="Language (for Finetuned Model)",
383
  value=list(adapter_paths.keys())[0],
384
+ info="Select the language adapter to use"
385
  )
386
 
387
  system_prompt_input = gr.Textbox(
388
  label="System Prompt (Optional)",
389
  placeholder="e.g., You are a helpful assistant...",
390
  lines=3,
391
+ info="Shared between both models"
392
  )
393
+
394
  max_length_slider = gr.Slider(
395
  minimum=50,
396
  maximum=512,
 
409
  info="Higher = more creative"
410
  )
411
 
412
+ # handling enter key in textbox - send to both models
413
+ msg.submit(
414
+ fn=chat_base_model,
415
+ inputs=[msg, chatbot_base, system_prompt_input, max_length_slider, temperature_slider],
416
+ outputs=[chatbot_base],
417
+ )
418
  msg.submit(
419
+ fn=chat_finetuned,
420
+ inputs=[msg, chatbot_finetuned, language_dropdown, system_prompt_input, max_length_slider, temperature_slider],
421
+ outputs=[chatbot_finetuned],
422
  ).then(
423
+ fn=lambda: gr.update(value=""),
424
  outputs=[msg]
425
  )
426
 
427
+ # Handle button click - send to both models
428
  submit_btn.click(
429
+ fn=chat_base_model,
430
+ inputs=[msg, chatbot_base, system_prompt_input, max_length_slider, temperature_slider],
431
+ outputs=[chatbot_base],
432
+ )
433
+ submit_btn.click(
434
+ fn=chat_finetuned,
435
+ inputs=[msg, chatbot_finetuned, language_dropdown, system_prompt_input, max_length_slider, temperature_slider],
436
+ outputs=[chatbot_finetuned],
437
  ).then(
438
  fn=lambda: gr.update(value=""),
439
  outputs=[msg]
440
  )
441
+
442
+ # Clear both chats
443
  clear_btn.click(
444
+ fn=lambda: (None, None),
445
+ outputs=[chatbot_base, chatbot_finetuned],
446
+ queue=False
447
  )
448
 
449
  demo.queue(False)