"""Gradio app serving a LoRA fine-tuned TinyLlama math model on CPU.

Loads the TinyLlama-1.1B-Chat base model, merges the trained LoRA adapters
from the Hub, and exposes a simple text-in / text-out Gradio interface.
"""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --- Configuration ---
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_MODEL_ID = "Vivek16/Root_Math-TinyLlama-CPU"  # adapter repo on the Hub

# Instruction template used during fine-tuning (Step 5).
# It MUST match the training template exactly or generation quality degrades.
INSTRUCTION_TEMPLATE = "<|system|>\nSolve the following math problem:\n<|user|>\n{}\n<|assistant|>"

ASSISTANT_TAG = "<|assistant|>"  # marker that precedes the model's answer


def load_model():
    """Load the base model, merge the LoRA adapters, and return them.

    Returns:
        tuple[AutoTokenizer, AutoModelForCausalLM]: the tokenizer and the
        merged, eval-mode model (adapters folded into the base weights so
        inference needs no PEFT wrapper).
    """
    print("Loading base model...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,  # bfloat16 keeps the memory footprint small on CPU
        device_map="cpu",
    )

    print("Loading and merging PEFT adapters...")
    # Load the trained LoRA adapters from the adapter repo, then fold the
    # adapter weights into the base model so generation runs on plain weights.
    model = PeftModel.from_pretrained(model, ADAPTER_MODEL_ID)
    model = model.merge_and_unload()
    model.eval()

    # Ensure a pad token exists for generation (TinyLlama ships without one).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Model loaded and merged successfully!")
    return tokenizer, model


# Load once at import time so every request reuses the same weights.
tokenizer, model = load_model()


def generate_response(prompt):
    """Generate an answer for a math problem using the fine-tuned model.

    Args:
        prompt: the raw user question, inserted into the training template.

    Returns:
        str: the assistant's answer, or an error message if the output
        could not be parsed.
    """
    # 1. Format the user input using the exact chat template from training.
    formatted_prompt = INSTRUCTION_TEMPLATE.format(prompt)

    # 2. Tokenize the input.
    inputs = tokenizer(formatted_prompt, return_tensors="pt")

    # 3. Generate the response (on CPU).
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )

    # 4. Decode the output and strip the prompt. Special tokens are kept so
    # we can locate the assistant tag and the EOS boundary in the raw text.
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=False)

    # Extract only the assistant's response: everything after the LAST
    # assistant tag (the prompt itself also contains one).
    response_start = generated_text.rfind(ASSISTANT_TAG)
    if response_start != -1:
        tail = generated_text[response_start + len(ASSISTANT_TAG):]
        # BUG FIX: the original called tail.split('') with an EMPTY separator,
        # which raises "ValueError: empty separator" on every request. The
        # intent was to cut the answer at the end-of-sequence token.
        assistant_response = tail.split(tokenizer.eos_token)[0].strip()
    else:
        assistant_response = "Error: Could not parse model output."

    return assistant_response


# --- Gradio Interface ---
title = "Root Math TinyLlama 1.1B - CPU Fine-Tuned"
description = "A CPU-friendly TinyLlama model fine-tuned on the Big-Math-RL-Verified dataset using LoRA."
article = "Model repository: " + ADAPTER_MODEL_ID

demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=5, label="Enter your math problem here:"),
    outputs=gr.Textbox(label="Model Answer"),
    title=title,
    description=description,
    article=article,
    theme="soft",
)
demo.launch()