import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
BASE_MODEL_ID = "google/gemma-2-2b-it"
ADAPTER_ID = "Phonsiri/gemma-2-2b-it-grpo-v6-checkpoints"

# --- Load Tokenizer & Model ---
print(f"Loading base model: {BASE_MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)

print(f"Loading adapter: {ADAPTER_ID}...")
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)


def generate(prompt):
    """Generate a model answer for *prompt*, returned as a Markdown code block.

    The prompt is wrapped in the Gemma chat template, sampled from the
    adapter-augmented model, and the generated continuation (prompt tokens
    excluded) is fenced as ```xml so the Gradio Markdown output renders any
    XML-style reasoning tags literally instead of treating them as HTML.

    Args:
        prompt: The user's question as plain text.

    Returns:
        The model's response wrapped in a fenced ```xml code block.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=2048,
            temperature=0.6,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode ONLY the newly generated tokens. The previous version decoded
    # the full sequence and split on the literal string "model" to drop the
    # prompt, which corrupts the output whenever the question itself
    # contains that word. Slicing by prompt length is exact.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Defensive cleanup: strip Gemma's end-of-turn marker in case the
    # tokenizer ever leaves it in. (The original code called
    # response.replace("", "") here — a no-op; the token text had been
    # lost from the source.)
    response = response.replace("<end_of_turn>", "").strip()

    if not response:
        # Never show an empty box: fall back to the full decoded sequence.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Fence the answer so Gradio's Markdown component shows XML-style tags
    # (e.g. <reasoning>) verbatim instead of hiding them as HTML tags.
    return f"```xml\n{response}\n```"


# --- Gradio UI ---
examples = [
    [
        "A store sells notebooks for $3 each and pens for $1.50 each. "
        "Sarah buys 4 notebooks and 6 pens. How much does she pay in total?"
    ],
    ["If John is taller than Mike, and Mike is taller than Sarah, who is the tallest?"],
    ["Solve for x: 2x + 5 = 15"],
]

# Markdown output (instead of Textbox) so the fenced code block renders nicely.
demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(
        label="Question",
        lines=3,
        placeholder="Ask a math or reasoning question...",
    ),
    outputs=gr.Markdown(label="Reasoning & Answer"),
    title="Gemma-2-2B GRPO (Adapter Version)",
    description=f"Running Adapter: {ADAPTER_ID}\nBase Model: {BASE_MODEL_ID}",
    examples=examples,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch(share=True)