Phonsiri commited on
Commit
ff7b2fd
·
verified ·
1 Parent(s): 19dc641

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -24
app.py CHANGED
@@ -2,31 +2,26 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
4
  import torch
 
5
 
6
  # --- Configuration ---
7
- BASE_MODEL_ID = "google/gemma-2-2b-it" # โมเดลหลัก
8
- ADAPTER_ID = "Phonsiri/gemma-2-2b-it-grpo-v6-checkpoints" # โมเดล Adapter (LoRA) ที่ต้องการโหลด
9
 
10
  # --- Load Tokenizer & Model ---
11
  print(f"Loading base model: {BASE_MODEL_ID}...")
12
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
13
 
14
- # 1. โหลด Base Model ก่อน
15
  base_model = AutoModelForCausalLM.from_pretrained(
16
  BASE_MODEL_ID,
17
  device_map="auto",
18
  torch_dtype=torch.float16
19
  )
20
 
21
- # 2. โหลด Adapter (LoRA) มาประกบ
22
  print(f"Loading adapter: {ADAPTER_ID}...")
23
  model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
24
 
25
- # (Optional) ถ้าต้องการให้ Inference เร็วขึ้นนิดหน่อย สามารถ Merge ได้เลย (กิน RAM ตอนโหลดเพิ่มชั่วคราว)
26
- # model = model.merge_and_unload()
27
-
28
  def generate(prompt):
29
- # สร้าง Chat Template
30
  messages = [{"role": "user", "content": prompt}]
31
 
32
  formatted_prompt = tokenizer.apply_chat_template(
@@ -37,39 +32,35 @@ def generate(prompt):
37
 
38
  inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
39
 
40
- # Generate
41
  with torch.no_grad():
42
  outputs = model.generate(
43
  **inputs,
44
- max_new_tokens=2048, # เพิ่มความยาวเผื่อการคิดแบบ Chain-of-Thought
45
- temperature=0.6, # ลดลงนิดหน่อยเพื่อให้ Reasoning นิ่งขึ้น
46
  top_p=0.9,
47
  do_sample=True,
48
  pad_token_id=tokenizer.eos_token_id,
49
  eos_token_id=tokenizer.eos_token_id
50
  )
51
 
52
- # Decode
53
  full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
54
 
55
- # --- Response Cleaning Logic ---
56
- # ตัดส่วน Prompt ออกเพื่อให้เหลือแค่คำตอบของโมเดล
57
  if "model\n" in full_response:
58
- # ตัดที่ token model ตัวสุดท้าย (Gemma chat format)
59
  response = full_response.split("model\n")[-1].strip()
60
  elif "<start_of_turn>model" in full_response:
61
  response = full_response.split("<start_of_turn>model")[-1].strip()
62
  else:
63
- # Fallback: ตัดตามความยาว prompt
64
- # (วิธีนี้อาจไม่แม่นยำ 100% ถ้า prompt ถูก format ใหม่ แต่ใช้กันเหนียว)
65
- response = full_response[len(formatted_prompt):].strip() # ตัดจาก formatted prompt ดีกว่า
66
- if len(response) == 0: # ถ้าตัดแล้วหายหมด ให้ใช้ raw decode
67
  response = full_response
68
 
69
- # ลบ tags ที่อาจหลงเหลือ
70
  response = response.replace("<end_of_turn>", "").strip()
71
 
72
- return response
 
 
 
73
 
74
  # --- Gradio UI ---
75
  examples = [
@@ -78,6 +69,7 @@ examples = [
78
  ["Solve for x: 2x + 5 = 15"]
79
  ]
80
 
 
81
  demo = gr.Interface(
82
  fn=generate,
83
  inputs=gr.Textbox(
@@ -85,9 +77,8 @@ demo = gr.Interface(
85
  lines=3,
86
  placeholder="Ask a math or reasoning question..."
87
  ),
88
- outputs=gr.Textbox(
89
- label="Reasoning & Answer",
90
- lines=15 # เพิ่มบรรทัดเพราะ GRPO มักตอบยาว
91
  ),
92
  title="Gemma-2-2B GRPO (Adapter Version)",
93
  description=f"Running Adapter: {ADAPTER_ID}\nBase Model: {BASE_MODEL_ID}",
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
4
  import torch
5
+ import html # เพิ่ม html library
6
 
7
  # --- Configuration ---
8
+ BASE_MODEL_ID = "google/gemma-2-2b-it"
9
+ ADAPTER_ID = "Phonsiri/gemma-2-2b-it-grpo-v6-checkpoints"
10
 
11
  # --- Load Tokenizer & Model ---
12
  print(f"Loading base model: {BASE_MODEL_ID}...")
13
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
14
 
 
15
  base_model = AutoModelForCausalLM.from_pretrained(
16
  BASE_MODEL_ID,
17
  device_map="auto",
18
  torch_dtype=torch.float16
19
  )
20
 
 
21
  print(f"Loading adapter: {ADAPTER_ID}...")
22
  model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
23
 
 
 
 
24
  def generate(prompt):
 
25
  messages = [{"role": "user", "content": prompt}]
26
 
27
  formatted_prompt = tokenizer.apply_chat_template(
 
32
 
33
  inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
34
 
 
35
  with torch.no_grad():
36
  outputs = model.generate(
37
  **inputs,
38
+ max_new_tokens=2048,
39
+ temperature=0.6,
40
  top_p=0.9,
41
  do_sample=True,
42
  pad_token_id=tokenizer.eos_token_id,
43
  eos_token_id=tokenizer.eos_token_id
44
  )
45
 
 
46
  full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
47
 
48
+ # Cleaning Logic
 
49
  if "model\n" in full_response:
 
50
  response = full_response.split("model\n")[-1].strip()
51
  elif "<start_of_turn>model" in full_response:
52
  response = full_response.split("<start_of_turn>model")[-1].strip()
53
  else:
54
+ response = full_response[len(formatted_prompt):].strip()
55
+ if len(response) == 0:
 
 
56
  response = full_response
57
 
 
58
  response = response.replace("<end_of_turn>", "").strip()
59
 
60
+ # --- สำคัญ: แก้ไขการแสดงผล Tag ---
61
+ # แปลง < เป็น &lt; เพื่อให้ Gradio ไม่มองว่าเป็น HTML tag ที่ต้องซ่อน
62
+ # หรือใช้วิธีใส่ Code Block ครอบ
63
+ return f"```xml\n{response}\n```"
64
 
65
  # --- Gradio UI ---
66
  examples = [
 
69
  ["Solve for x: 2x + 5 = 15"]
70
  ]
71
 
72
+ # เปลี่ยน Output เป็น Markdown เพื่อให้ render code block สวยๆ
73
  demo = gr.Interface(
74
  fn=generate,
75
  inputs=gr.Textbox(
 
77
  lines=3,
78
  placeholder="Ask a math or reasoning question..."
79
  ),
80
+ outputs=gr.Markdown( # เปลี่ยนจาก Textbox เป็น Markdown
81
+ label="Reasoning & Answer"
 
82
  ),
83
  title="Gemma-2-2B GRPO (Adapter Version)",
84
  description=f"Running Adapter: {ADAPTER_ID}\nBase Model: {BASE_MODEL_ID}",