# Jommarn / app.py — Hugging Face Space by Phonsiri
# Last update: "Update app.py" (commit ff7b2fd, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import html  # stdlib html module (escaping helper for tag display)
# --- Configuration ---
# Base instruction-tuned model and the GRPO LoRA adapter layered on top of it.
BASE_MODEL_ID = "google/gemma-2-2b-it"
ADAPTER_ID = "Phonsiri/gemma-2-2b-it-grpo-v6-checkpoints"
# --- Load Tokenizer & Model ---
# Runs once at import time (Spaces startup); downloads weights on first run.
print(f"Loading base model: {BASE_MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
device_map="auto",  # let accelerate place layers on GPU/CPU automatically
torch_dtype=torch.float16  # half precision to fit the 2B model in less memory
)
print(f"Loading adapter: {ADAPTER_ID}...")
# Wrap the base model with the PEFT adapter; `model` is what inference uses.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
def generate(prompt):
    """Generate a reasoned answer for *prompt* with the adapter-tuned model.

    Args:
        prompt: The user's question as plain text.

    Returns:
        The model's reply wrapped in a fenced ```xml code block so Gradio's
        Markdown component renders XML-like reasoning tags (e.g.
        <reasoning>...</reasoning>) literally instead of treating them as
        HTML and hiding them.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=2048,
            temperature=0.6,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # BUG FIX: strip the prompt at the *token* level before decoding.
    # The previous code decoded the full sequence with
    # skip_special_tokens=True and then fell back to slicing by
    # len(formatted_prompt); because special tokens are removed during
    # decoding, those character offsets no longer line up, so the slice
    # could truncate the answer or leak prompt text. Slicing the output
    # token ids by the input length is exact and also makes the fragile
    # "model\n" / "<start_of_turn>model" string-splitting unnecessary.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Remove any Gemma turn marker the tokenizer left behind.
    response = response.replace("<end_of_turn>", "").strip()
    if not response:
        # Degenerate generation: fall back to the whole decoded sequence
        # rather than returning an empty answer.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Fence the reply so Markdown shows raw tags instead of parsing them.
    return f"```xml\n{response}\n```"
# --- Gradio UI ---
# Example prompts shown beneath the input box.
examples = [
    ["A store sells notebooks for $3 each and pens for $1.50 each. Sarah buys 4 notebooks and 6 pens. How much does she pay in total?"],
    ["If John is taller than Mike, and Mike is taller than Sarah, who is the tallest?"],
    ["Solve for x: 2x + 5 = 15"],
]

# Input/output components are named up front for readability. The output is
# a Markdown component (not a Textbox) so the fenced code block returned by
# generate() is rendered nicely.
question_box = gr.Textbox(
    label="Question",
    lines=3,
    placeholder="Ask a math or reasoning question...",
)
answer_view = gr.Markdown(label="Reasoning & Answer")

demo = gr.Interface(
    fn=generate,
    inputs=question_box,
    outputs=answer_view,
    title="Gemma-2-2B GRPO (Adapter Version)",
    description=f"Running Adapter: {ADAPTER_ID}\nBase Model: {BASE_MODEL_ID}",
    examples=examples,
    theme=gr.themes.Soft(),
)
# Entry point: start the Gradio server; share=True also creates a public
# *.gradio.live tunnel link in addition to the local URL.
if __name__ == "__main__":
    demo.launch(share=True)