File size: 3,924 Bytes
8f1d4ac 76572f3 e4cb5da 76572f3 e4cb5da 76572f3 e4cb5da e860614 76572f3 b2c590c 76572f3 e4cb5da 76572f3 e860614 76572f3 e4cb5da 76572f3 e860614 e4cb5da 76572f3 e860614 76572f3 e4cb5da 76572f3 e860614 e4cb5da e860614 e4cb5da e860614 e4cb5da e860614 e4cb5da f94bcb7 e4cb5da afa0209 e4cb5da 76572f3 e4cb5da 76572f3 e4cb5da 76572f3 e860614 e4cb5da e860614 e4cb5da b2c590c e4cb5da e860614 76572f3 e4cb5da e860614 76572f3 e4cb5da 76572f3 e4cb5da 76572f3 8193117 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# --- Configuration (Verified) ---
# Hugging Face repo id of the base chat model.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Ensure this is correct for your model repository
# Hugging Face repo id holding the trained LoRA adapter weights.
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"
# Define the instruction template components.
# TinyLlama chat format: <|system|> / <|user|> / <|assistant|> turns,
# each terminated by the </s> end-of-turn token.
SYSTEM_INSTRUCTION = "Solve the following math problem:"
# '{}' is filled with the turn's text via str.format in generate_response.
USER_TEMPLATE = "<|user|>\n{}</s>"
ASSISTANT_TEMPLATE = "<|assistant|>\n{}</s>"
# --- Model Loading Function ---
def load_model():
    """Load the base model, merge the LoRA adapters into it, and return
    a ``(tokenizer, model)`` pair ready for CPU inference."""
    print("Loading base model...")
    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,  # bfloat16 halves the memory footprint vs fp32
        device_map="cpu",
    )

    print("Loading and merging PEFT adapters...")
    # Merge the trained LoRA weights into the base model and drop the PEFT
    # wrapper so inference runs on a plain transformers model.
    merged = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID).merge_and_unload()
    merged.eval()

    # Generation needs a pad token; fall back to EOS when none is defined.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    print("Model loaded and merged successfully!")
    return tok, merged
# Load the model outside the prediction function for efficiency:
# done once at import time so every chat request reuses the same instance.
tokenizer, model = load_model()
# --- Prediction Function for gr.ChatInterface ---
def generate_response(message, history):
    """Generate a reply from the fine-tuned model for one chat turn.

    Args:
        message: The current user message (str).
        history: Gradio chat history. Either a list of
            ``(user_msg, assistant_msg)`` pairs (tuples format, where the
            assistant slot may be ``None`` for an unanswered turn) or a list
            of ``{"role": ..., "content": ...}`` dicts (messages format).

    Returns:
        The assistant's response text with the end-of-turn marker stripped.
    """
    # 1. Build the full prompt: system instruction, prior turns, current turn.
    full_prompt = f"<|system|>\n{SYSTEM_INSTRUCTION}</s>\n"

    for turn in history:
        if isinstance(turn, dict):
            # gradio "messages" format: one dict per role turn.
            content = turn.get("content") or ""
            if turn.get("role") == "user":
                full_prompt += USER_TEMPLATE.format(content) + "\n"
            elif turn.get("role") == "assistant":
                full_prompt += ASSISTANT_TEMPLATE.format(content) + "\n"
        else:
            # gradio "tuples" format: (user_msg, assistant_msg) pairs.
            user_msg, assistant_msg = turn
            full_prompt += USER_TEMPLATE.format(user_msg) + "\n"
            # assistant_msg is None when the turn has no reply yet — skip it
            # instead of formatting the literal string "None" into the prompt.
            if assistant_msg is not None:
                full_prompt += ASSISTANT_TEMPLATE.format(assistant_msg) + "\n"

    # Append the current user message and open the assistant's turn.
    full_prompt += USER_TEMPLATE.format(message) + "\n"
    full_prompt += "<|assistant|>\n"
    print(f"--- Full Prompt ---\n{full_prompt}")

    # 2. Tokenize the input; remember the prompt length in tokens.
    inputs = tokenizer(full_prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    # 3. Generate the response (on CPU).
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )

    # 4. Decode ONLY the newly generated tokens. Slicing by prompt length is
    # more robust than searching the decoded text for '<|assistant|>', which
    # would truncate wrongly if the model emitted that marker in its reply.
    new_tokens = output_tokens[0][prompt_len:]
    raw_response = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # 5. Keep everything up to the first end-of-turn marker.
    assistant_response = raw_response.split('</s>')[0].strip()
    return assistant_response
# --- Gradio Chat Interface ---
title = "Root Math TinyLlama 1.1B - Gemini-Like Chat Demo"
description = "A conversational interface for the CPU-friendly TinyLlama model fine-tuned for math problems. Ask follow-up questions!"
# NOTE(review): the clear_btn / undo_btn string kwargs exist in gradio 4.x
# but were removed in gradio 5 — confirm the pinned gradio version before
# upgrading this dependency.
gr.ChatInterface(
    fn=generate_response,  # called with (message, history) on each user turn
    chatbot=gr.Chatbot(height=500),  # Makes the chat history window taller
    textbox=gr.Textbox(placeholder="Enter your math problem or follow-up question...", scale=7),
    title=title,
    description=description,
    submit_btn="Ask Model",
    clear_btn="Start New Chat",
    undo_btn="Undo Last Message",
    theme="soft"
).queue().launch()  # queue() serializes requests so CPU generation isn't overrun
|