|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
# Hugging Face repo id of the frozen base model (chat-tuned TinyLlama 1.1B).
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# LoRA adapter weights fine-tuned for math problems; merged onto the base
# model at load time (see load_model).
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"

# Prompt pieces in the TinyLlama-Chat style template: each turn is tagged
# with a role marker (<|system|>/<|user|>/<|assistant|>) and terminated by
# the </s> end-of-sequence token. The {} placeholder takes the turn's text.
SYSTEM_INSTRUCTION = "Solve the following math problem:"

USER_TEMPLATE = "<|user|>\n{}</s>"

ASSISTANT_TEMPLATE = "<|assistant|>\n{}</s>"
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Load the base model, fold in the LoRA adapters, and return an
    inference-ready (tokenizer, model) pair.

    Returns:
        tuple: (tokenizer, model) where the model has the adapters merged
        into its weights and is set to eval mode on CPU.
    """
    print("Loading base model...")

    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("Loading and merging PEFT adapters...")

    # Merging bakes the LoRA deltas into the base weights, so inference
    # no longer pays the adapter indirection cost.
    merged = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID).merge_and_unload()
    merged.eval()

    # Some tokenizers ship without a pad token; reuse EOS so padded
    # generation does not fail.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    print("Model loaded and merged successfully!")
    return tok, merged
|
|
|
|
|
|
|
|
tokenizer, model = load_model() |
|
|
|
|
|
|
|
|
|
|
|
def generate_response(message, history):
    """Generates a response using chat history and the fine-tuned model.

    Args:
        message: The latest user message (plain string from the textbox).
        history: Prior turns as (user_msg, assistant_msg) pairs, as
            supplied by gr.ChatInterface in tuple mode.

    Returns:
        The assistant's reply as a plain string.
    """
    # Rebuild the full TinyLlama-chat prompt from scratch every turn:
    # <|system|> ... then alternating <|user|>/<|assistant|> turns.
    full_prompt = f"<|system|>\n{SYSTEM_INSTRUCTION}</s>\n"

    for user_msg, assistant_msg in history:
        full_prompt += USER_TEMPLATE.format(user_msg) + "\n"
        # Gradio can pass None for an assistant turn that never completed;
        # skip it rather than embedding the literal string "None".
        if assistant_msg is not None:
            full_prompt += ASSISTANT_TEMPLATE.format(assistant_msg) + "\n"

    full_prompt += USER_TEMPLATE.format(message) + "\n"
    # Open the assistant tag WITHOUT a closing </s> so the model continues
    # generating from here.
    full_prompt += "<|assistant|>\n"

    print(f"--- Full Prompt ---\n{full_prompt}")

    inputs = tokenizer(full_prompt, return_tensors="pt")

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep special tokens so the <|assistant|> marker survives for parsing.
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)

    # generate() echoes the prompt, so the new reply follows the LAST
    # <|assistant|> tag and ends at the first </s> after it.
    response_start = generated_text.rfind('<|assistant|>')
    if response_start == -1:
        return "Error: Could not parse model output."

    raw_response = generated_text[response_start + len('<|assistant|>'):].strip()
    assistant_response = raw_response.split('</s>')[0].strip()
    # Guard against an empty completion (e.g. immediate EOS) so the UI
    # never shows a blank bubble.
    return assistant_response or "(no response generated)"
|
|
|
|
|
|
|
|
|
|
|
# UI copy shown above the chat widget.
title = "Root Math TinyLlama 1.1B - Gemini-Like Chat Demo"

description = "A conversational interface for the CPU-friendly TinyLlama model fine-tuned for math problems. Ask follow-up questions!"

# Build and launch the chat UI. .queue() serializes requests, which matters
# on CPU where a single generation can take tens of seconds.
# NOTE(review): passing label strings to submit_btn/clear_btn/undo_btn
# matches the Gradio 4.x ChatInterface API; Gradio 5 removed the
# clear_btn/undo_btn kwargs — confirm the pinned gradio version.
gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Enter your math problem or follow-up question...", scale=7),
    title=title,
    description=description,
    submit_btn="Ask Model",
    clear_btn="Start New Chat",
    undo_btn="Undo Last Message",
    theme="soft"
).queue().launch()
|
|
|