|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face ID of the base model the LoRA adapters were trained on.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


# Hugging Face ID of the LoRA adapter weights merged into the base model.
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"




# Chat prompt template in the TinyLlama chat format
# (<|system|>/<|user|>/<|assistant|> tags, </s> as turn terminator).
# The {} placeholder is filled with the user's math problem.
INSTRUCTION_TEMPLATE = "<|system|>\nSolve the following math problem:</s>\n<|user|>\n{}</s>\n<|assistant|>"
|
|
|
|
|
|
|
|
def load_model():
    """Load the base TinyLlama model, merge in the LoRA adapters, and
    return the result ready for CPU inference.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in eval mode with
        the adapters already merged (no PEFT wrapper remains).
    """
    print("Loading base model...")
    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("Loading and merging PEFT adapters...")
    # Folding the adapter weights into the base model removes the PEFT
    # indirection, so inference runs on a plain transformers model.
    merged = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID).merge_and_unload()
    merged.eval()

    # Some tokenizers ship without a pad token; fall back to EOS so that
    # padded generation does not fail.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    print("Model loaded and merged successfully!")
    return tok, merged
|
|
|
|
|
|
|
|
# Load once at import time so every Gradio request reuses the same
# tokenizer/model pair instead of reloading per call.
tokenizer, model = load_model()
|
|
|
|
|
|
|
|
def generate_response(prompt):
    """Generate an answer to a math problem with the fine-tuned model.

    Args:
        prompt: Raw math problem text entered by the user.

    Returns:
        The assistant's decoded answer as a string.
    """
    formatted_prompt = INSTRUCTION_TEMPLATE.format(prompt)

    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    # Number of prompt tokens; everything after this index in the output
    # is newly generated text.
    prompt_length = inputs["input_ids"].shape[1]

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing by the prompt's token
    # length is more robust than decoding the whole sequence and searching
    # for the '<|assistant|>' marker (which also appears in the prompt) and
    # splitting on '</s>'; skip_special_tokens drops EOS/pad tokens for us.
    new_tokens = output_tokens[0][prompt_length:]
    assistant_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    if not assistant_response:
        # Model produced nothing but special tokens; surface a clear error.
        assistant_response = "Error: Could not parse model output."

    return assistant_response
|
|
|
|
|
|
|
|
# Copy shown on the Gradio page header/footer.
title = "Root Math TinyLlama 1.1B - CPU Fine-Tuned"
description = "A CPU-friendly TinyLlama model fine-tuned on the Big-Math-RL-Verified dataset using LoRA."
article = f"Model repository: {ADAPTER_MODEL_ID}"
|
|
|
|
|
# Build the web UI: one textbox in (the math problem), one textbox out
# (the model's answer), then start serving.
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your math problem here:"),
    outputs=gr.Textbox(label="Model Answer"),
    title=title,
    description=description,
    article=article,
    theme="soft",
)
demo.launch()