import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

def load_model(model_id, model_type="base"):
    try:
        if model_type == "base":
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            return tokenizer, model
        else:  # finetuned model with PEFT adapters
            # Load the base model first
            base_model_id = "satyanayak/gemma-3-base"
            tokenizer = AutoTokenizer.from_pretrained(base_model_id)
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            # Attach the PEFT adapters on top of the base model
            # (this wraps the model; it does not merge the weights)
            model = PeftModel.from_pretrained(
                base_model,
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            return tokenizer, model
    except Exception as e:
        print(f"Error loading {model_type} model: {str(e)}")
        return None, None
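
# Optional helper (a sketch, not called by this app): for lower inference
# latency, LoRA-style adapters can be folded into the base weights with
# peft's merge_and_unload(), which returns a plain transformers model.
def merge_adapters(peft_model):
    # Folds the adapter weights into the base model and drops the PEFT wrapper
    return peft_model.merge_and_unload()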

# Load base model and tokenizer
base_model_id = "satyanayak/gemma-3-base"
base_tokenizer, base_model = load_model(base_model_id, "base")

# Load finetuned model and tokenizer
finetuned_model_id = "satyanayak/gemma-3-GRPO"
finetuned_tokenizer, finetuned_model = load_model(finetuned_model_id, "finetuned")
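
# Debugging aid (an assumption: handy when the Space shows "Runtime error"):
# log where device_map="auto" actually placed each model.
for _name, _model in [("base", base_model), ("finetuned", finetuned_model)]:
    if _model is not None:
        print(f"{_name} model loaded on: {next(_model.parameters()).device}")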

def _generate(tokenizer, model, prompt, max_new_tokens=512):
    """Shared generation helper: returns only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # max_new_tokens (rather than max_length) so long prompts still
        # leave room for a full-length completion
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt tokens so the UI shows only the completion
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def generate_base_response(prompt, max_new_tokens=512):
    if base_model is None or base_tokenizer is None:
        return "Error: Base model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        return _generate(base_tokenizer, base_model, prompt, max_new_tokens)
    except Exception as e:
        return f"Error generating response with base model: {str(e)}"


def generate_finetuned_response(prompt, max_new_tokens=512):
    if finetuned_model is None or finetuned_tokenizer is None:
        return "Error: Finetuned model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        return _generate(finetuned_tokenizer, finetuned_model, prompt, max_new_tokens)
    except Exception as e:
        return f"Error generating response with finetuned model: {str(e)}"
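
# Note: do_sample=True with temperature=0.7 makes outputs non-deterministic;
# for a repeatable comparison you could seed first, e.g. torch.manual_seed(0).
# Local smoke test (commented out; uses this file's own functions):
# print(generate_base_response("What is 2 + 2?", max_new_tokens=32))
# print(generate_finetuned_response("What is 2 + 2?", max_new_tokens=32))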

# Example prompts
examples = [
    ["What is the sqrt of 101"],
    ["How many r's are in strawberry?"],
    ["If Tom has 3 more apples than Jerry and Jerry has 5 apples, how many apples does Tom have?"],
]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma-3 Model Comparison Demo")
    gr.Markdown("Compare responses between the base model and the GRPO-finetuned model.")

    with gr.Row():
        # Base Model Column
        with gr.Column(scale=1):
            gr.Markdown("## Base Model (Gemma-3)")
            base_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5,
            )
            base_generate_btn = gr.Button("Generate with Base Model")
            base_output = gr.Textbox(label="Base Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=base_input,
                outputs=base_output,
                fn=generate_base_response,
                # Caching would run generation at build time; with large
                # models this can stall or crash the Space at startup
                cache_examples=False,
            )

        # Finetuned Model Column
        with gr.Column(scale=1):
            gr.Markdown("## GRPO-Finetuned Model")
            finetuned_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5,
            )
            finetuned_generate_btn = gr.Button("Generate with Finetuned Model")
            finetuned_output = gr.Textbox(label="Finetuned Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=finetuned_input,
                outputs=finetuned_output,
                fn=generate_finetuned_response,
                cache_examples=False,
            )
    # Connect buttons to their respective functions
    base_generate_btn.click(
        fn=generate_base_response,
        inputs=base_input,
        outputs=base_output,
    )
    finetuned_generate_btn.click(
        fn=generate_finetuned_response,
        inputs=finetuned_input,
        outputs=finetuned_output,
    )

if __name__ == "__main__":
    demo.launch()
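    # On a GPU Space, request queuing serializes generation calls so that
    # concurrent users don't exhaust memory (a suggestion, not required):
    # demo.queue().launch()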