"""
Gradio UI for testing the Multiplication LoRA model.
Deployable to HuggingFace Spaces.
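
Configuration comes from environment variables (or a local .env file):
    BASE_MODEL, LORA_ADAPTER, SYSTEM_PROMPT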
"""
import os
import re
from dotenv import load_dotenv
load_dotenv()
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Configuration - can be overridden by environment variables
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
LORA_ADAPTER = os.environ.get(
"LORA_ADAPTER", None
) # HF Hub path, e.g., "username/lora-multiplicator"
SYSTEM_PROMPT = os.environ.get(
"SYSTEM_PROMPT",
"You are a helpful calculator that multiplies two numbers. Answer only a number. No preamble.",
)
# Global model cache - base and lora need separate model instances
# because PeftModel.from_pretrained wraps the model in place
_cache = {
"base_model": None,
"lora_model": None,
"tokenizer": None,
"lora_path": None,
}
def get_lora_path():
"""Determine the LoRA adapter path."""
if _cache["lora_path"] is not None:
return _cache["lora_path"]
lora_path = LORA_ADAPTER
if lora_path is None:
# Try local path for development
local_path = os.path.join(
os.path.dirname(__file__), "output", "lora-multiplicator", "final"
)
if os.path.exists(local_path):
lora_path = local_path
else:
raise ValueError(
"No LoRA adapter found. Set LORA_ADAPTER environment variable "
"or place adapter in output/lora-multiplicator/final/"
)
_cache["lora_path"] = lora_path
return lora_path
def load_tokenizer():
"""Load and cache the tokenizer."""
if _cache["tokenizer"] is None:
print(f"Loading tokenizer from {BASE_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
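        # Some base models ship without a pad token; fall back to EOS for padding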
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
_cache["tokenizer"] = tokenizer
return _cache["tokenizer"]
def load_base_model():
"""Load and cache the base model (without LoRA)."""
if _cache["base_model"] is None:
print(f"Loading base model (no LoRA): {BASE_MODEL}...")
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
)
model.eval()
_cache["base_model"] = model
print("Base model loaded successfully!")
return _cache["base_model"]
def load_lora_model():
"""Load and cache the model with LoRA adapter (separate instance from base)."""
if _cache["lora_model"] is None:
# Load a NEW base model instance for LoRA (don't reuse the base model)
# This is important because PeftModel wraps the model in place
print(f"Loading base model for LoRA: {BASE_MODEL}...")
base_for_lora = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
)
lora_path = get_lora_path()
print(f"Loading LoRA adapter from: {lora_path}...")
model = PeftModel.from_pretrained(base_for_lora, lora_path)
model.eval()
_cache["lora_model"] = model
print("LoRA model loaded successfully!")
return _cache["lora_model"]
def generate_answer(number: int, use_lora: bool) -> tuple[str, str, bool]:
"""
Generate multiplication answer.
Args:
number: The 6-digit number to multiply by 7
use_lora: Whether to use the LoRA adapter
Returns:
Tuple of (predicted_answer, expected_answer, is_correct)
"""
print(f"use_lora: {use_lora}")
tokenizer = load_tokenizer()
model = load_lora_model() if use_lora else load_base_model()
# Calculate expected result
expected = number * 7
# Format as chat message
query = f"{number} * 7"
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": query},
]
# Apply chat template
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Tokenize
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate with greedy decoding (do_sample=False) so the answer is deterministic
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=32,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
)
# Decode only the generated part
generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    # Try to extract a numeric prediction from the reply
    predicted_numbers = re.findall(r"\d+", answer)
if predicted_numbers:
predicted = int(predicted_numbers[0])
is_correct = predicted == expected
return str(predicted), str(expected), is_correct
else:
return answer, str(expected), False
def predict(number_input: str, use_lora: bool) -> tuple[str, str]:
"""
Main prediction function for Gradio interface.
Returns formatted HTML for predicted and expected values.
"""
# Validate input
try:
number = int(number_input.strip())
except ValueError:
return (
'<span style="color: red; font-size: 24px; font-weight: bold;">Invalid input</span>',
'<span style="color: gray; font-size: 24px;">-</span>',
)
if not (100000 <= number <= 999999):
return (
'<span style="color: red; font-size: 24px; font-weight: bold;">Must be 6 digits (100000-999999)</span>',
'<span style="color: gray; font-size: 24px;">-</span>',
)
# Generate prediction
predicted, expected, is_correct = generate_answer(number, use_lora)
# Format output with colors
if is_correct:
predicted_html = f'<span style="color: green; font-size: 32px; font-weight: bold;">{predicted}</span>'
else:
predicted_html = f'<span style="color: red; font-size: 32px; font-weight: bold;">{predicted}</span>'
expected_html = f'<span style="color: green; font-size: 32px; font-weight: bold;">{expected}</span>'
return predicted_html, expected_html
def create_demo():
"""Create the Gradio demo interface."""
    with gr.Blocks(
        title="Multiplication LoRA Demo",
        theme=gr.themes.Soft(),
        css="""
        .result-box {
            padding: 20px;
            border-radius: 10px;
            text-align: center;
            min-height: 80px;
        }
        .predicted-box, .expected-box {
            background-color: #f0f0f0;
        }
        """,
    ) as demo:
gr.Markdown(
"""
# Multiplication LoRA Demo
A fun experiment in LoRA fine-tuning on a tiny model using a simple arithmetic task (multiplication by 7).
**LoRA Adapter**: [nlac/multiplication-lora-demo-adapter](https://huggingface.co/nlac/multiplication-lora-demo-adapter)
"""
)
with gr.Row():
with gr.Column(scale=2):
number_input = gr.Textbox(
label="Enter a 6-digit number to multiply it by 7",
placeholder="e.g. 123456",
max_lines=1,
)
use_lora = gr.Checkbox(
label="Use LoRA adapter",
value=True,
info="Uncheck to see base model performance (hint: it's much worse!)",
)
submit_btn = gr.Button("Send", variant="primary", size="lg")
with gr.Column(scale=3):
with gr.Row():
with gr.Column():
gr.Markdown("### Predicted")
predicted_output = gr.HTML(
value='<span style="color: gray; font-size: 24px;">-</span>',
elem_classes=["result-box", "predicted-box"],
)
with gr.Column():
gr.Markdown("### Expected")
expected_output = gr.HTML(
value='<span style="color: gray; font-size: 24px;">-</span>',
elem_classes=["result-box", "expected-box"],
)
# Wire up the interface
submit_btn.click(
fn=predict,
inputs=[number_input, use_lora],
outputs=[predicted_output, expected_output],
)
# Also trigger on Enter key
number_input.submit(
fn=predict,
inputs=[number_input, use_lora],
outputs=[predicted_output, expected_output],
)
gr.Examples(
examples=[
["123456", True],
["999999", False],
["100000", True],
["123456", False],
],
inputs=[number_input, use_lora],
outputs=[predicted_output, expected_output],
fn=predict,
cache_examples=False,
)
gr.Markdown(
"""
## Results
| Model | Accuracy |
|-------|----------|
| Base Qwen2.5-0.5B | ~3% |
| With LoRA adapter | ~94% |
The LoRA adapter adds only **~2MB of parameters** but improves accuracy by **31x**!
"""
)
gr.Markdown(
"""
## Why this project?
This is an experiment to learn LoRA fine-tuning. Arithmetic makes an ideal test case:
- **Easy data generation** - examples generated programmatically, no manual labeling
- **Objective evaluation** - answers are either correct or wrong
            Training took under an hour on a consumer laptop: 20,000 programmatically generated examples of 6-digit numbers, trained for 3 epochs. That covers only about 2% of the 900,000 possible 6-digit numbers. More samples and more epochs would likely yield even higher accuracy.
            A typical training example:
            ```json
            [
              {"role": "system", "content": "You are a helpful calculator that multiplies two numbers. Answer only a number. No preamble."},
              {"role": "user", "content": "772694 * 7?"},
              {"role": "assistant", "content": "5408858"}
            ]
            ```
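
            Generating the data takes only a few lines of Python. A minimal sketch (illustrative, not the actual training script; the `train.jsonl` output name is assumed):
            ```python
            import json
            import random

            SYSTEM = "You are a helpful calculator that multiplies two numbers. Answer only a number. No preamble."

            # Sample 20,000 distinct 6-digit numbers (~2% of the 900,000 possible)
            examples = [
                [
                    {"role": "system", "content": SYSTEM},
                    {"role": "user", "content": f"{n} * 7?"},
                    {"role": "assistant", "content": str(n * 7)},
                ]
                for n in random.sample(range(100_000, 1_000_000), 20_000)
            ]

            with open("train.jsonl", "w") as f:
                for ex in examples:
                    f.write(json.dumps(ex) + "\n")
            ```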
"""
)
return demo
# Create and launch the demo
demo = create_demo()
if __name__ == "__main__":
    demo.launch(ssr_mode=False)