# app.py — Hugging Face Space entry point (V2, uploaded by Vivek16, commit e2451da)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# --- Configuration ---
# Base checkpoint the LoRA adapters were trained on.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Hub repo holding the fine-tuned LoRA adapter weights.
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
# Chat template used during fine-tuning; {} is filled with the user's problem.
# It must match the training template exactly or generation quality degrades.
INSTRUCTION_TEMPLATE = "<|system|>\nSolve the following math problem:</s>\n<|user|>\n{}</s>\n<|assistant|>"
# --- Model Loading Function ---
def load_model():
    """Load the base TinyLlama model, fold in the LoRA adapters, and return
    the tokenizer plus the merged model ready for CPU inference."""
    print("Loading base model...")
    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,  # bfloat16 keeps the memory footprint low on CPU
        device_map="cpu",
    )
    print("Loading and merging PEFT adapters...")
    # Attach the fine-tuned LoRA weights, then merge them into the base
    # weights so inference runs on a single plain model (no PEFT wrapper).
    merged = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID).merge_and_unload()
    merged.eval()
    # Generation needs a pad token; fall back to EOS when none is defined.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    print("Model loaded and merged successfully!")
    return tok, merged
# Load the model once at module import so every Gradio request reuses the
# same in-memory weights instead of reloading per call.
tokenizer, model = load_model()
# --- Prediction Function ---
def generate_response(prompt):
    """Return the fine-tuned model's answer for a single math *prompt*."""
    # Wrap the question in the exact chat template used during fine-tuning,
    # then tokenize it for the model.
    model_inputs = tokenizer(INSTRUCTION_TEMPLATE.format(prompt), return_tensors="pt")

    # Sampled generation on CPU; gradients are never needed at inference time.
    with torch.no_grad():
        generated = model.generate(
            **model_inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep special tokens so the <|assistant|> marker survives decoding.
    full_text = tokenizer.decode(generated[0], skip_special_tokens=False)

    # The answer is whatever follows the LAST <|assistant|> tag.
    marker = '<|assistant|>'
    marker_pos = full_text.rfind(marker)
    if marker_pos == -1:
        return "Error: Could not parse model output."
    # Drop the prompt prefix, cut at the first </s> terminator, and trim.
    tail = full_text[marker_pos + len(marker):].strip()
    return tail.split('</s>')[0].strip()
# --- Gradio Interface ---
title = "Root Math TinyLlama 1.1B - CPU Fine-Tuned"
description = "A CPU-friendly TinyLlama model fine-tuned on the Big-Math-RL-Verified dataset using LoRA."
article = "Model repository: " + ADAPTER_MODEL_ID

# Build the web UI around the prediction function, then start serving.
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your math problem here:"),
    outputs=gr.Textbox(label="Model Answer"),
    title=title,
    description=description,
    article=article,
    theme="soft",
)
demo.launch()