| | import gradio as gr |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from peft import PeftModel |
| | import torch |
| | import html |
| |
|
| | |
# Hub IDs for the base model and the GRPO-trained LoRA/PEFT adapter
# that is applied on top of it.
BASE_MODEL_ID = "google/gemma-2-2b-it"
ADAPTER_ID = "Phonsiri/gemma-2-2b-it-grpo-v6-checkpoints"

print(f"Loading base model: {BASE_MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Load the base model in half precision; device_map="auto" lets
# accelerate place layers on the available GPU(s)/CPU automatically.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16
)

print(f"Loading adapter: {ADAPTER_ID}...")
# Wrap the base model with the fine-tuned adapter weights; `model` is
# the object every downstream call (generate, device lookup) uses.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
| |
|
def generate(prompt):
    """Generate a model response for *prompt* and wrap it in an ```xml fence.

    The prompt is formatted with the tokenizer's chat template, sampled from
    the adapter-augmented model, and only the newly generated tokens are
    decoded, so the user's prompt can never leak into the output.

    Args:
        prompt: The user's question as a plain string.

    Returns:
        The model's reply wrapped in a Markdown ```xml code fence.
    """
    messages = [{"role": "user", "content": prompt}]

    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=2048,
            temperature=0.6,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # BUGFIX: the previous extraction split the decoded string on
    # "model\n" / "<start_of_turn>model", which misfires whenever the
    # user's prompt itself contains those markers, and its fallback
    # sliced by len(formatted_prompt) — wrong because
    # skip_special_tokens=True changes the decoded length relative to
    # the templated prompt. Slicing the generated token ids by the
    # input length is exact and marker-independent.
    prompt_len = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_len:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Belt-and-braces: drop the Gemma turn terminator in case the
    # tokenizer did not register it as a special token.
    response = response.replace("<end_of_turn>", "").strip()

    if not response:
        # Degenerate case (nothing new generated): show the full decode
        # rather than an empty box, mirroring the old fallback.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return f"```xml\n{response}\n```"
| |
|
| | |
# Sample questions surfaced as one-click examples in the Gradio UI; each
# row is a single-element list matching the interface's lone Textbox input.
_EXAMPLE_QUESTIONS = (
    "A store sells notebooks for $3 each and pens for $1.50 each. Sarah buys 4 notebooks and 6 pens. How much does she pay in total?",
    "If John is taller than Mike, and Mike is taller than Sarah, who is the tallest?",
    "Solve for x: 2x + 5 = 15",
)
examples = [[question] for question in _EXAMPLE_QUESTIONS]
| |
|
| | |
# Assemble the web UI: a single question box in, Markdown out (Markdown so
# the ```xml fence produced by generate() renders as a code block).
_question_box = gr.Textbox(
    label="Question",
    lines=3,
    placeholder="Ask a math or reasoning question..."
)
_answer_view = gr.Markdown(label="Reasoning & Answer")

demo = gr.Interface(
    fn=generate,
    inputs=_question_box,
    outputs=_answer_view,
    title="Gemma-2-2B GRPO (Adapter Version)",
    description=f"Running Adapter: {ADAPTER_ID}\nBase Model: {BASE_MODEL_ID}",
    examples=examples,
    theme=gr.themes.Soft()
)
| |
|
if __name__ == "__main__":
    # share=True additionally publishes a temporary public Gradio URL
    # alongside the local server.
    demo.launch(share=True)
| |
|