Spaces:
Running
Running
| """ | |
| Gradio UI for testing the Multiplication LoRA model. | |
| Deployable to HuggingFace Spaces. | |
| """ | |
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| # Configuration - can be overridden by environment variables | |
| BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct") | |
| LORA_ADAPTER = os.environ.get( | |
| "LORA_ADAPTER", None | |
| ) # HF Hub path, e.g., "username/lora-multiplicator" | |
| SYSTEM_PROMPT = os.environ.get( | |
| "SYSTEM_PROMPT", | |
| "You are a helpful calculator that multiplies two numbers. Answer only a number. No preamble.", | |
| ) | |
| # Global model cache - base and lora need separate model instances | |
| # because PeftModel.from_pretrained wraps the model in place | |
| _cache = { | |
| "base_model": None, | |
| "lora_model": None, | |
| "tokenizer": None, | |
| "lora_path": None, | |
| } | |
| def get_lora_path(): | |
| """Determine the LoRA adapter path.""" | |
| if _cache["lora_path"] is not None: | |
| return _cache["lora_path"] | |
| lora_path = LORA_ADAPTER | |
| if lora_path is None: | |
| # Try local path for development | |
| local_path = os.path.join( | |
| os.path.dirname(__file__), "output", "lora-multiplicator", "final" | |
| ) | |
| if os.path.exists(local_path): | |
| lora_path = local_path | |
| else: | |
| raise ValueError( | |
| "No LoRA adapter found. Set LORA_ADAPTER environment variable " | |
| "or place adapter in output/lora-multiplicator/final/" | |
| ) | |
| _cache["lora_path"] = lora_path | |
| return lora_path | |
| def load_tokenizer(): | |
| """Load and cache the tokenizer.""" | |
| if _cache["tokenizer"] is None: | |
| print(f"Loading tokenizer from {BASE_MODEL}...") | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| _cache["tokenizer"] = tokenizer | |
| return _cache["tokenizer"] | |
| def load_base_model(): | |
| """Load and cache the base model (without LoRA).""" | |
| if _cache["base_model"] is None: | |
| print(f"Loading base model (no LoRA): {BASE_MODEL}...") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| model.eval() | |
| _cache["base_model"] = model | |
| print("Base model loaded successfully!") | |
| return _cache["base_model"] | |
| def load_lora_model(): | |
| """Load and cache the model with LoRA adapter (separate instance from base).""" | |
| if _cache["lora_model"] is None: | |
| # Load a NEW base model instance for LoRA (don't reuse the base model) | |
| # This is important because PeftModel wraps the model in place | |
| print(f"Loading base model for LoRA: {BASE_MODEL}...") | |
| base_for_lora = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| lora_path = get_lora_path() | |
| print(f"Loading LoRA adapter from: {lora_path}...") | |
| model = PeftModel.from_pretrained(base_for_lora, lora_path) | |
| model.eval() | |
| _cache["lora_model"] = model | |
| print("LoRA model loaded successfully!") | |
| return _cache["lora_model"] | |
| def generate_answer(number: int, use_lora: bool) -> tuple[str, str, bool]: | |
| """ | |
| Generate multiplication answer. | |
| Args: | |
| number: The 6-digit number to multiply by 7 | |
| use_lora: Whether to use the LoRA adapter | |
| Returns: | |
| Tuple of (predicted_answer, expected_answer, is_correct) | |
| """ | |
| print(f"use_lora: {use_lora}") | |
| tokenizer = load_tokenizer() | |
| model = load_lora_model() if use_lora else load_base_model() | |
| # Calculate expected result | |
| expected = number * 7 | |
| # Format as chat message | |
| query = f"{number} * 7" | |
| messages = [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": query}, | |
| ] | |
| # Apply chat template | |
| prompt = tokenizer.apply_chat_template( | |
| messages, tokenize=False, add_generation_prompt=True | |
| ) | |
| # Tokenize | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| # Generate | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=32, | |
| do_sample=False, | |
| pad_token_id=tokenizer.pad_token_id, | |
| ) | |
| # Decode only the generated part | |
| generated_ids = outputs[0][inputs["input_ids"].shape[1] :] | |
| answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() | |
| # Try to extract numeric prediction | |
| import re | |
| predicted_numbers = re.findall(r"\d+", answer) | |
| if predicted_numbers: | |
| predicted = int(predicted_numbers[0]) | |
| is_correct = predicted == expected | |
| return str(predicted), str(expected), is_correct | |
| else: | |
| return answer, str(expected), False | |
| def predict(number_input: str, use_lora: bool) -> tuple[str, str]: | |
| """ | |
| Main prediction function for Gradio interface. | |
| Returns formatted HTML for predicted and expected values. | |
| """ | |
| # Validate input | |
| try: | |
| number = int(number_input.strip()) | |
| except ValueError: | |
| return ( | |
| '<span style="color: red; font-size: 24px; font-weight: bold;">Invalid input</span>', | |
| '<span style="color: gray; font-size: 24px;">-</span>', | |
| ) | |
| if not (100000 <= number <= 999999): | |
| return ( | |
| '<span style="color: red; font-size: 24px; font-weight: bold;">Must be 6 digits (100000-999999)</span>', | |
| '<span style="color: gray; font-size: 24px;">-</span>', | |
| ) | |
| # Generate prediction | |
| predicted, expected, is_correct = generate_answer(number, use_lora) | |
| # Format output with colors | |
| if is_correct: | |
| predicted_html = f'<span style="color: green; font-size: 32px; font-weight: bold;">{predicted}</span>' | |
| else: | |
| predicted_html = f'<span style="color: red; font-size: 32px; font-weight: bold;">{predicted}</span>' | |
| expected_html = f'<span style="color: green; font-size: 32px; font-weight: bold;">{expected}</span>' | |
| return predicted_html, expected_html | |
| def create_demo(): | |
| """Create the Gradio demo interface.""" | |
| with gr.Blocks(title="Multiplication LoRA Demo") as demo: | |
| gr.Markdown( | |
| """ | |
| # Multiplication LoRA Demo | |
| A fun experiment in LoRA fine-tuning on a tiny model using a simple arithmetic task (multiplication by 7). | |
| **LoRA Adapter**: [nlac/multiplication-lora-demo-adapter](https://huggingface.co/nlac/multiplication-lora-demo-adapter) | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| number_input = gr.Textbox( | |
| label="Enter a 6-digit number to multiply it by 7", | |
| placeholder="e.g. 123456", | |
| max_lines=1, | |
| ) | |
| use_lora = gr.Checkbox( | |
| label="Use LoRA adapter", | |
| value=True, | |
| info="Uncheck to see base model performance (hint: it's much worse!)", | |
| ) | |
| submit_btn = gr.Button("Send", variant="primary", size="lg") | |
| with gr.Column(scale=3): | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("### Predicted") | |
| predicted_output = gr.HTML( | |
| value='<span style="color: gray; font-size: 24px;">-</span>', | |
| elem_classes=["result-box", "predicted-box"], | |
| ) | |
| with gr.Column(): | |
| gr.Markdown("### Expected") | |
| expected_output = gr.HTML( | |
| value='<span style="color: gray; font-size: 24px;">-</span>', | |
| elem_classes=["result-box", "expected-box"], | |
| ) | |
| # Wire up the interface | |
| submit_btn.click( | |
| fn=predict, | |
| inputs=[number_input, use_lora], | |
| outputs=[predicted_output, expected_output], | |
| ) | |
| # Also trigger on Enter key | |
| number_input.submit( | |
| fn=predict, | |
| inputs=[number_input, use_lora], | |
| outputs=[predicted_output, expected_output], | |
| ) | |
| gr.Examples( | |
| examples=[ | |
| ["123456", True], | |
| ["999999", False], | |
| ["100000", True], | |
| ["123456", False], | |
| ], | |
| inputs=[number_input, use_lora], | |
| outputs=[predicted_output, expected_output], | |
| fn=predict, | |
| cache_examples=False, | |
| ) | |
| gr.Markdown( | |
| """ | |
| ## Results | |
| | Model | Accuracy | | |
| |-------|----------| | |
| | Base Qwen2.5-0.5B | ~3% | | |
| | With LoRA adapter | ~94% | | |
| The LoRA adapter adds only **~2MB of parameters** but improves accuracy by **31x**! | |
| """ | |
| ) | |
| gr.Markdown( | |
| """ | |
| ## Why this project? | |
| This is an experiment to learn LoRA fine-tuning. Arithmetic makes an ideal test case: | |
| - **Easy data generation** - examples generated programmatically, no manual labeling | |
| - **Objective evaluation** - answers are either correct or wrong | |
| The training completed in under an hour on a consumer laptop, using 20,000 generated examples using 6-digit numbers, in 3 epochs: that means 2% of all 6-digit numbers used for training. Increasing the number of samples and the epochs would likely result even higher accuracy. | |
| A typical training example was: [{"role":"system", "assistant": "You are a helpful calculator that multiplies two numbers. Answer only a number. No preamble."}, {"role": "user", "content": "772694* 7?"}, {"role": "assistant", "content": "5408858"} | |
| """ | |
| ) | |
| return demo | |
| # Create and launch the demo | |
| demo = create_demo() | |
| if __name__ == "__main__": | |
| demo.launch( | |
| ssr_mode=False, | |
| theme=gr.themes.Soft(), | |
| css=""" | |
| .result-box { | |
| padding: 20px; | |
| border-radius: 10px; | |
| text-align: center; | |
| min-height: 80px; | |
| } | |
| .predicted-box, .expected-box { | |
| background-color: #f0f0f0; | |
| } | |
| """, | |
| ) | |