# app.py — Hugging Face Space entry point for the Root Math TinyLlama chat demo.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# --- Configuration (Verified) ---
# Base checkpoint the LoRA adapters were trained on.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Ensure this is correct for your model repository (Hub repo holding the adapters).
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
# Define the instruction template components.
# These follow the TinyLlama-Chat prompt convention: each turn is opened with a
# role tag (<|user|> / <|assistant|>) and terminated with the </s> end-of-turn token.
SYSTEM_INSTRUCTION = "Solve the following math problem:"
USER_TEMPLATE = "<|user|>\n{}</s>"
ASSISTANT_TEMPLATE = "<|assistant|>\n{}</s>"
# --- Model Loading Function ---
def load_model():
    """Load the base TinyLlama checkpoint, apply the LoRA adapters, and merge.

    Returns:
        tuple: ``(tokenizer, model)`` — the tokenizer and the merged,
        eval-mode model, both resident on CPU.
    """
    print("Loading base model...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    # bfloat16 halves the memory footprint versus float32; everything stays on CPU.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("Loading and merging PEFT adapters...")
    # Folding the adapters into the base weights removes the PEFT wrapper
    # overhead at inference time.
    merged = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID).merge_and_unload()
    merged.eval()

    # generate() needs a pad token id; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Model loaded and merged successfully!")
    return tokenizer, merged
# Load the model once at import time (outside the prediction function) so every
# chat request reuses the same weights instead of reloading per call.
tokenizer, model = load_model()
# --- Prediction Function for gr.ChatInterface ---
def generate_response(message, history):
    """Generate an assistant reply for the Gradio chat interface.

    Args:
        message: The latest user message (str).
        history: Prior turns. Gradio may pass either a list of
            ``(user_msg, assistant_msg)`` tuples (classic format) or, in
            newer "messages" mode, a list of ``{"role": ..., "content": ...}``
            dicts; both are accepted.

    Returns:
        str: The model's reply text, or an error string when the output
        cannot be parsed.
    """
    # 1. Build the full prompt: system instruction, history, current message.
    full_prompt = _build_prompt(message, history)
    print(f"--- Full Prompt ---\n{full_prompt}")

    # 2. Tokenize the input (attention_mask comes along via **inputs).
    inputs = tokenizer(full_prompt, return_tensors="pt")

    # 3. Generate the response (on CPU); no_grad avoids building autograd state.
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id
        )

    # 4. Decode with special tokens kept so the </s> turn delimiter survives
    #    for parsing in step 5.
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)

    # 5. Extract only the model's new response.
    return _extract_response(generated_text)


def _build_prompt(message, history):
    """Assemble the TinyLlama chat prompt: system turn, history, new message."""
    parts = [f"<|system|>\n{SYSTEM_INSTRUCTION}</s>\n"]
    for turn in history or []:
        if isinstance(turn, dict):
            # Gradio "messages" format: one dict per turn with role/content.
            template = USER_TEMPLATE if turn.get("role") == "user" else ASSISTANT_TEMPLATE
            parts.append(template.format(turn.get("content", "")) + "\n")
        else:
            # Classic tuple format: (user_msg, assistant_msg).
            user_msg, assistant_msg = turn
            parts.append(USER_TEMPLATE.format(user_msg) + "\n")
            parts.append(ASSISTANT_TEMPLATE.format(assistant_msg) + "\n")
    parts.append(USER_TEMPLATE.format(message) + "\n")
    # Open the assistant turn so generation continues from here.
    parts.append("<|assistant|>\n")
    return "".join(parts)


def _extract_response(generated_text):
    """Return the text of the final assistant turn, stripped of the trailing </s>."""
    marker = '<|assistant|>'
    response_start = generated_text.rfind(marker)
    if response_start == -1:
        return "Error: Could not parse model output."
    raw_response = generated_text[response_start + len(marker):].strip()
    return raw_response.split('</s>')[0].strip()
# --- Gradio Chat Interface ---
# Page title and subtitle shown above the chat window.
title = "Root Math TinyLlama 1.1B - Gemini-Like Chat Demo"
description = "A conversational interface for the CPU-friendly TinyLlama model fine-tuned for math problems. Ask follow-up questions!"
# NOTE(review): passing button labels via submit_btn/clear_btn/undo_btn matches
# the Gradio 4.x ChatInterface signature; later Gradio releases changed/removed
# some of these kwargs — confirm the pinned gradio version in requirements.txt.
gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(height=500), # Makes the chat history window taller
    textbox=gr.Textbox(placeholder="Enter your math problem or follow-up question...", scale=7),
    title=title,
    description=description,
    submit_btn="Ask Model",
    clear_btn="Start New Chat",
    undo_btn="Undo Last Message",
    theme="soft"
).queue().launch()  # .queue() serializes requests — important for CPU inference