# Hugging Face Spaces page header (status: Paused) — scraped page residue, kept as a comment.
import spaces  # ZeroGPU Spaces helper; presumably used via @spaces.GPU elsewhere — TODO confirm
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset

# Model configuration: base checkpoint plus the LoRA adapter trained on GSM8K.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_PATH = "sumedh/tinyllama-lora-math-adapter-v3"

# Load tokenizer globally: it is lightweight, so it loads eagerly at import time.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load a 50-item slice of the GSM8K test split to back the example picker in the UI.
print("Loading test dataset...")
test_data = load_dataset("openai/gsm8k", "main", split="test[:50]")

# Model weights are heavy; they are loaded lazily on the first GPU call (see load_models).
base_model = None
tuned_model = None
def load_models():
    """Lazily instantiate both models, caching them in module globals.

    Returns:
        tuple: ``(base_model, tuned_model)`` — the vanilla TinyLlama and the
        LoRA-merged fine-tuned variant, both in eval mode at fp16.
    """
    global base_model, tuned_model

    if base_model is None:
        print("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        ).eval()

    if tuned_model is None:
        print("Loading fine-tuned model...")
        backbone = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        adapted = PeftModel.from_pretrained(backbone, ADAPTER_PATH)
        # Folding the adapter weights into the backbone removes the PEFT
        # wrapper overhead at inference time.
        tuned_model = adapted.merge_and_unload().eval()
        print("Models loaded!")

    return base_model, tuned_model
def generate_responses(question):
    """Generate answers to *question* from both the base and fine-tuned models.

    Args:
        question: The math word problem typed (or loaded) by the user.

    Returns:
        tuple[str, str]: ``(base_response, tuned_response)``.  If *question*
        is blank, returns a prompt-the-user message and an empty string.
    """
    # NOTE(review): the file imports `spaces` but no @spaces.GPU decorator is
    # visible; on ZeroGPU hardware this function likely needs one — confirm.
    if not question.strip():
        return "Please enter a question.", ""

    # Load models if not already loaded (first call pays the full load cost).
    base, tuned = load_models()

    prompt = f"### Instruction:\n{question}\n### Response:\n"
    # The generate/decode/strip sequence was duplicated verbatim for the two
    # models; it is factored into _generate_one to keep them in lockstep.
    base_response = _generate_one(base, prompt)
    tuned_response = _generate_one(tuned, prompt)
    return base_response, tuned_response


def _generate_one(model, prompt):
    """Run greedy generation on one model and return only the response text."""
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        output = model.generate(
            token_ids,
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the marker.
    if "### Response:" in text:
        text = text.split("### Response:")[-1].strip()
    return text
def load_example(idx):
    """Fetch question/answer pair *idx* from the GSM8K sample, or blanks."""
    index = int(idx)  # the slider delivers a number; dataset indexing needs int
    if index < 0 or index >= len(test_data):
        return "", ""
    record = test_data[index]
    return record["question"], record["answer"]
def run_comparison(question, reference):
    """Compare both models on *question*; pass *reference* through unchanged."""
    if not question.strip():
        return "Please enter a question.", "", ""
    responses = generate_responses(question)
    return responses[0], responses[1], reference
# Create Gradio interface: input row (question + optional reference, example
# picker), a compare button, and three side-by-side output panes.
with gr.Blocks(title="TinyLlama Math Fine-tuning Demo") as demo:
    # Intro copy shown at the top of the Space.
    gr.Markdown("""
# 🧮 TinyLlama Math Fine-tuning Demo
Compare the performance of **base TinyLlama** vs **fine-tuned TinyLlama** on math word problems.
- **Base Model**: TinyLlama-1.1B-Chat-v1.0 (no math training)
- **Fine-tuned Model**: LoRA adapter trained on GSM8K dataset (7,473 examples)
*Note: First run may take ~30s to load models.*
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Free-form question entry plus an optional reference answer.
            question_input = gr.Textbox(
                label="Math Question",
                placeholder="Enter a math word problem...",
                lines=4
            )
            reference_input = gr.Textbox(
                label="Reference Answer (optional)",
                placeholder="The correct answer will appear here when loading examples",
                lines=4
            )
        with gr.Column(scale=1):
            # Slider selects one of the 50 pre-loaded GSM8K test questions.
            gr.Markdown("### Load Example")
            example_slider = gr.Slider(
                minimum=0,
                maximum=49,
                step=1,
                value=0,
                label="Example Index (0-49)"
            )
            load_btn = gr.Button("Load Example", variant="secondary")
    # NOTE(review): nesting of this button was lost in the flattened paste;
    # placed at Blocks level (full-width primary action) — confirm layout.
    compare_btn = gr.Button("Compare Models", variant="primary", size="lg")
    gr.Markdown("---")
    # Result panes: base model, fine-tuned model, ground-truth answer.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🔴 Base Model Response")
            base_output = gr.Textbox(label="", lines=10, show_label=False)
        with gr.Column():
            gr.Markdown("### 🟢 Fine-tuned Model Response")
            tuned_output = gr.Textbox(label="", lines=10, show_label=False)
        with gr.Column():
            gr.Markdown("### ✅ Correct Answer")
            reference_output = gr.Textbox(label="", lines=10, show_label=False)
    # Event handlers
    load_btn.click(
        fn=load_example,
        inputs=[example_slider],
        outputs=[question_input, reference_input]
    )
    compare_btn.click(
        fn=run_comparison,
        inputs=[question_input, reference_input],
        outputs=[base_output, tuned_output, reference_output]
    )
    # Footer with dataset / method details.
    gr.Markdown("""
---
### About
This demo showcases the effect of fine-tuning a small language model (TinyLlama 1.1B) on math word problems.
- **Dataset**: [GSM8K](https://huggingface.co/datasets/openai/gsm8k) - Grade School Math 8K
- **Method**: LoRA (Low-Rank Adaptation)
- **Training**: 5 epochs on 7,473 examples
""")
if __name__ == "__main__":
    # Launch the Gradio server when the module is run directly.
    demo.launch()