File size: 6,147 Bytes
82a52cd
f66cd1f
 
82a52cd
f66cd1f
 
 
 
 
 
 
82a52cd
 
f66cd1f
 
 
82a52cd
f66cd1f
 
82a52cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f66cd1f
 
82a52cd
 
 
 
 
 
 
 
 
 
 
f66cd1f
82a52cd
f66cd1f
82a52cd
f66cd1f
82a52cd
f66cd1f
 
 
82a52cd
 
 
f66cd1f
82a52cd
 
 
 
 
 
 
 
 
 
 
 
f66cd1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82a52cd
f66cd1f
 
 
 
82a52cd
f66cd1f
 
 
 
 
 
 
82a52cd
 
f66cd1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82a52cd
f66cd1f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset

# Model configuration
# Base checkpoint on the Hugging Face Hub and the LoRA adapter trained on GSM8K.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_PATH = "sumedh/tinyllama-lora-math-adapter-v3"

# Load tokenizer globally
# NOTE: runs at import time — downloads from the Hub on first launch.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load test dataset for examples
# Only the first 50 test rows are kept; the example slider below indexes 0-49.
print("Loading test dataset...")
test_data = load_dataset("openai/gsm8k", "main", split="test[:50]")

# Models will be loaded lazily on first GPU call
# (kept in module globals so repeat @spaces.GPU calls reuse them)
base_model = None
tuned_model = None


def load_models():
    """Lazily load both models, caching them in module globals.

    Returns:
        tuple: ``(base_model, tuned_model)`` — the untouched base checkpoint
        and the LoRA-merged fine-tuned variant, both in eval mode.
    """
    global base_model, tuned_model

    if base_model is None:
        print("Loading base model...")
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        ).eval()

    if tuned_model is None:
        print("Loading fine-tuned model...")
        # Start from a fresh copy of the backbone so the cached base model
        # is never mutated by the adapter merge.
        backbone = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        # Fold the LoRA weights into the backbone for adapter-free inference.
        tuned_model = (
            PeftModel.from_pretrained(backbone, ADAPTER_PATH)
            .merge_and_unload()
            .eval()
        )

    print("Models loaded!")
    return base_model, tuned_model


@spaces.GPU
def generate_responses(question):
    """Generate responses from both models - runs on GPU.

    Args:
        question: The math word problem to pose to both models.

    Returns:
        tuple[str, str]: ``(base_response, tuned_response)``. If *question*
        is blank, returns a prompt-the-user message and an empty string.
    """
    if not question.strip():
        return "Please enter a question.", ""

    # Load models if not already loaded
    base, tuned = load_models()

    # Same instruction-style prompt the adapter was trained on.
    prompt = f"### Instruction:\n{question}\n### Response:\n"

    # Identical decoding settings for both models keeps the comparison fair.
    return _generate(base, prompt), _generate(tuned, prompt)


def _generate(model, prompt):
    """Run greedy generation on one model and return the response text.

    Tokenizes on the model's own device (the two models may be placed
    differently by ``device_map="auto"``), then strips the echoed prompt
    by splitting on the "### Response:" marker.
    """
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        output = model.generate(
            token_ids,
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1
        )
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    if "### Response:" in text:
        text = text.split("### Response:")[-1].strip()
    return text


def load_example(idx):
    """Fetch the (question, answer) pair at *idx* from the test split.

    Out-of-range indices yield a pair of empty strings instead of raising.
    """
    i = int(idx)  # Gradio sliders can deliver floats
    if not 0 <= i < len(test_data):
        return "", ""
    row = test_data[i]
    return row["question"], row["answer"]


def run_comparison(question, reference):
    """Generate both model answers and pass the reference through unchanged.

    Returns a 3-tuple ``(base_response, tuned_response, reference)`` for the
    three output textboxes; blank questions short-circuit with a hint.
    """
    if not question.strip():
        return "Please enter a question.", "", ""
    return (*generate_responses(question), reference)


# Create Gradio interface
# Layout: header markdown, an input row (question/reference + example picker),
# a compare button, then three side-by-side output columns.
with gr.Blocks(title="TinyLlama Math Fine-tuning Demo") as demo:
    gr.Markdown("""
    # 🧮 TinyLlama Math Fine-tuning Demo
    
    Compare the performance of **base TinyLlama** vs **fine-tuned TinyLlama** on math word problems.
    
    - **Base Model**: TinyLlama-1.1B-Chat-v1.0 (no math training)
    - **Fine-tuned Model**: LoRA adapter trained on GSM8K dataset (7,473 examples)
    
    *Note: First run may take ~30s to load models.*
    """)
    
    with gr.Row():
        # Left (wider) column: free-form question plus optional reference answer.
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Math Question",
                placeholder="Enter a math word problem...",
                lines=4
            )
            reference_input = gr.Textbox(
                label="Reference Answer (optional)",
                placeholder="The correct answer will appear here when loading examples",
                lines=4
            )
        
        # Right column: slider + button to pull one of the 50 cached GSM8K examples.
        with gr.Column(scale=1):
            gr.Markdown("### Load Example")
            example_slider = gr.Slider(
                minimum=0,
                maximum=49,
                step=1,
                value=0,
                label="Example Index (0-49)"
            )
            load_btn = gr.Button("Load Example", variant="secondary")
    
    compare_btn = gr.Button("Compare Models", variant="primary", size="lg")
    
    gr.Markdown("---")
    
    # Output row: base response / tuned response / ground-truth answer.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🔴 Base Model Response")
            base_output = gr.Textbox(label="", lines=10, show_label=False)
        
        with gr.Column():
            gr.Markdown("### 🟢 Fine-tuned Model Response")
            tuned_output = gr.Textbox(label="", lines=10, show_label=False)
        
        with gr.Column():
            gr.Markdown("### ✅ Correct Answer")
            reference_output = gr.Textbox(label="", lines=10, show_label=False)
    
    # Event handlers
    # "Load Example" fills the question and reference boxes from the dataset.
    load_btn.click(
        fn=load_example,
        inputs=[example_slider],
        outputs=[question_input, reference_input]
    )
    
    # "Compare Models" runs both generations and echoes the reference answer.
    compare_btn.click(
        fn=run_comparison,
        inputs=[question_input, reference_input],
        outputs=[base_output, tuned_output, reference_output]
    )
    
    gr.Markdown("""
    ---
    ### About
    
    This demo showcases the effect of fine-tuning a small language model (TinyLlama 1.1B) on math word problems.
    
    - **Dataset**: [GSM8K](https://huggingface.co/datasets/openai/gsm8k) - Grade School Math 8K
    - **Method**: LoRA (Low-Rank Adaptation)
    - **Training**: 5 epochs on 7,473 examples
    """)

# Launch only when run as a script (Spaces imports and serves `demo` directly).
if __name__ == "__main__":
    demo.launch()