File size: 3,924 Bytes
8f1d4ac
76572f3
 
 
 
e4cb5da
76572f3
e4cb5da
76572f3
 
e4cb5da
 
 
 
e860614
76572f3
b2c590c
76572f3
 
 
e4cb5da
76572f3
 
 
e860614
76572f3
 
 
 
e4cb5da
76572f3
 
 
e860614
e4cb5da
76572f3
 
e860614
76572f3
 
 
e4cb5da
76572f3
 
e860614
e4cb5da
e860614
e4cb5da
e860614
e4cb5da
e860614
e4cb5da
 
 
 
 
 
 
 
 
 
 
 
 
f94bcb7
e4cb5da
afa0209
e4cb5da
 
76572f3
 
 
 
 
e4cb5da
76572f3
e4cb5da
76572f3
e860614
e4cb5da
 
e860614
e4cb5da
 
 
 
 
 
 
b2c590c
e4cb5da
e860614
76572f3
 
 
e4cb5da
 
 
e860614
 
76572f3
e4cb5da
 
76572f3
 
e4cb5da
 
 
76572f3
8193117
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --- Configuration (Verified) ---
# Hugging Face Hub id of the chat-tuned base model the LoRA adapters were trained on.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Ensure this is correct for your model repository
# (LoRA adapter weights produced by fine-tuning; loaded on top of the base model).
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU" 

# Define the instruction template components.
# These follow the TinyLlama/Zephyr chat format: each turn is
# "<|role|>\n<content></s>" — the templates below are .format()-ed per turn.
SYSTEM_INSTRUCTION = "Solve the following math problem:"
USER_TEMPLATE = "<|user|>\n{}</s>"
ASSISTANT_TEMPLATE = "<|assistant|>\n{}</s>"


# --- Model Loading Function ---
def load_model():
    """Build the inference stack: load the TinyLlama base model, fold the
    LoRA adapters into it, and return a ready (tokenizer, model) pair."""
    print("Loading base model...")
    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    # Generation needs a pad token; fall back to EOS if the tokenizer
    # doesn't define one.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # bfloat16 keeps the memory footprint down while remaining CPU-friendly.
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("Loading and merging PEFT adapters...")
    # Fold the LoRA weights into the base so inference runs on one plain model
    # (no adapter indirection at generate time), then switch to eval mode.
    merged = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID).merge_and_unload()
    merged.eval()

    print("Model loaded and merged successfully!")
    return tok, merged

# Load once at import time so every chat request reuses the same
# tokenizer/model pair instead of reloading per call.
tokenizer, model = load_model()


# --- Prediction Function for gr.ChatInterface ---
def generate_response(message, history):
    """Generates a response using chat history and the fine-tuned model."""
    
    # 1. Build the full prompt including System Instruction, History, and current Message
    
    # Start with the system instruction
    full_prompt = f"<|system|>\n{SYSTEM_INSTRUCTION}</s>\n"
    
    # Append the chat history (if any)
    for user_msg, assistant_msg in history:
        full_prompt += USER_TEMPLATE.format(user_msg) + "\n"
        full_prompt += ASSISTANT_TEMPLATE.format(assistant_msg) + "\n"
        
    # Append the current user message and the start of the assistant's turn
    full_prompt += USER_TEMPLATE.format(message) + "\n"
    full_prompt += "<|assistant|>\n"
    
    print(f"--- Full Prompt ---\n{full_prompt}")

    # 2. Tokenize the input
    inputs = tokenizer(full_prompt, return_tensors="pt")

    # 3. Generate the response (on CPU)
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id 
        )
    
    # 4. Decode the output
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)
    
    # 5. Extract only the model's new response
    # Find the start of the assistant's turn in the output and everything after it
    response_start = generated_text.rfind('<|assistant|>')
    if response_start != -1:
        # Get the text after <|assistant|> and strip the trailing </s>
        raw_response = generated_text[response_start + len('<|assistant|>'):].strip()
        assistant_response = raw_response.split('</s>')[0].strip()
    else:
        assistant_response = "Error: Could not parse model output."
        
    return assistant_response


# --- Gradio Chat Interface ---
title = "Root Math TinyLlama 1.1B - Gemini-Like Chat Demo"
description = "A conversational interface for the CPU-friendly TinyLlama model fine-tuned for math problems. Ask follow-up questions!"

# Multi-turn chat UI: gr.ChatInterface calls generate_response(message, history)
# on each submit. .queue() serializes requests so concurrent users don't
# contend for the CPU-bound model; .launch() starts the web server.
# NOTE(review): the clear_btn/undo_btn keyword arguments were removed in
# Gradio 5 — confirm the pinned gradio version is 4.x, or drop them.
gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(height=500), # Makes the chat history window taller
    textbox=gr.Textbox(placeholder="Enter your math problem or follow-up question...", scale=7),
    title=title,
    description=description,
    submit_btn="Ask Model",
    clear_btn="Start New Chat",
    undo_btn="Undo Last Message",
    theme="soft"
).queue().launch()