import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --- Configuration (Verified) ---
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Ensure this is correct for your model repository
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"

# Define the instruction template components (TinyLlama-Chat style tags).
SYSTEM_INSTRUCTION = "Solve the following math problem:"
USER_TEMPLATE = "<|user|>\n{}"
ASSISTANT_TEMPLATE = "<|assistant|>\n{}"

# Marker that opens the assistant turn; also used to locate the model's
# reply inside the decoded generation output.
ASSISTANT_TAG = "<|assistant|>"


# --- Model Loading Function ---
def load_model():
    """Load the base model, merge the LoRA adapters, and return (tokenizer, model).

    Returns:
        tuple[AutoTokenizer, AutoModelForCausalLM]: the tokenizer and the
        merged, eval-mode model placed on CPU.
    """
    print("Loading base model...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    # bfloat16 keeps memory low while remaining usable on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("Loading and merging PEFT adapters...")
    # Load the trained LoRA adapters, then fold them into the base weights
    # so inference runs without the PEFT wrapper overhead.
    model = PeftModel.from_pretrained(model, ADAPTER_MODEL_ID)
    model = model.merge_and_unload()
    model.eval()

    # Ensure a pad token exists so generate() can pad/stop cleanly.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Model loaded and merged successfully!")
    return tokenizer, model


# Load the model once at import time, outside the prediction function,
# so每 request reuses the same weights.
tokenizer, model = load_model()


# --- Prediction Function for gr.ChatInterface ---
def generate_response(message, history):
    """Generate a reply from the fine-tuned model, conditioned on chat history.

    Args:
        message: The current user message (str).
        history: Prior turns as (user, assistant) pairs, as supplied by
            gr.ChatInterface in the tuples format.

    Returns:
        str: The assistant's reply, or an error string if the output
        could not be parsed.
    """
    # 1. Build the full prompt: system instruction, history, current message.
    full_prompt = f"<|system|>\n{SYSTEM_INSTRUCTION}\n"

    # Append the chat history (if any).
    for user_msg, assistant_msg in history:
        full_prompt += USER_TEMPLATE.format(user_msg) + "\n"
        full_prompt += ASSISTANT_TEMPLATE.format(assistant_msg) + "\n"

    # Append the current user message and open the assistant's turn.
    full_prompt += USER_TEMPLATE.format(message) + "\n"
    full_prompt += ASSISTANT_TAG + "\n"

    print(f"--- Full Prompt ---\n{full_prompt}")

    # 2. Tokenize the input.
    inputs = tokenizer(full_prompt, return_tensors="pt")

    # 3. Generate the response (on CPU); no_grad avoids autograd overhead.
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )

    # 4. Decode the output. Special tokens are kept so we can cut the
    #    reply off at the EOS marker below.
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)

    # 5. Extract only the model's new response: take the text after the
    #    LAST assistant tag (the turn we just opened).
    response_start = generated_text.rfind(ASSISTANT_TAG)
    if response_start != -1:
        raw_response = generated_text[response_start + len(ASSISTANT_TAG):].strip()
        # BUGFIX: the original called raw_response.split('') with an empty
        # separator, which raises ValueError at runtime. Strip the trailing
        # EOS marker (e.g. "</s>") instead.
        eos = tokenizer.eos_token or "</s>"
        assistant_response = raw_response.split(eos)[0].strip()
    else:
        assistant_response = "Error: Could not parse model output."

    return assistant_response


# --- Gradio Chat Interface ---
title = "Root Math TinyLlama 1.1B - Gemini-Like Chat Demo"
description = "A conversational interface for the CPU-friendly TinyLlama model fine-tuned for math problems. Ask follow-up questions!"

gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(height=500),  # Makes the chat history window taller
    textbox=gr.Textbox(placeholder="Enter your math problem or follow-up question...", scale=7),
    title=title,
    description=description,
    submit_btn="Ask Model",
    clear_btn="Start New Chat",
    undo_btn="Undo Last Message",
    theme="soft",
).queue().launch()