import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from collections import Counter
import re
# --- CONFIGURATION ---
# Hugging Face Hub id of the fine-tuned solver checkpoint.
# REPLACE WITH YOUR USERNAME
MODEL_ID = "justhariharan/Qwen2.5-Math-1.5B-Solver"
print(f"⏳ Loading {MODEL_ID}... (CPU Mode)")
# Load tokenizer and model once at import time (module-level side effect).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Full float32 weights pinned to CPU — this deployment has no GPU available.
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
device_map="cpu"
)
# --- HELPER FUNCTIONS ---
def extract_answer(text):
    """Extract the final numeric answer from a model response.

    Prefers the number following the last '####' marker (the GSM8K
    convention); otherwise falls back to the last number found anywhere
    in the text.

    Args:
        text: Model output string (may be None or empty).

    Returns:
        The answer as a plain string with '$' and ',' removed, or None
        if no number is present.
    """
    if not text:
        return None
    if "####" in text:
        # Only consider whatever follows the final marker.
        text = text.split("####")[-1]
    # Require at least one digit per match: the old pattern [$0-9.,]{1,}
    # happily matched bare punctuation, so a sentence-ending '.' or a
    # lone '$' could become "the last number" and the real answer was lost.
    matches = re.findall(r"-?\$?\d[\d.,]*", text)
    if not matches:
        return None
    answer = matches[-1].replace(",", "").replace("$", "").strip()
    # Drop a trailing sentence period ("72." -> "72").
    return answer.rstrip(".")
def format_prompt(current_question, history):
    """Assemble a ChatML prompt: system message, the most recent chat
    turn (short-term memory), then the current question."""
    # System Prompt: Friendly Teacher
    parts = ["""<|im_start|>system
You are a patient and friendly math teacher.
1. Solve the problem step-by-step.
2. Explain the 'logic' simply so a student can understand.
3. Always end your final result with '#### Number'.<|im_end|>"""]
    # Short-term memory: replay only the last (user, assistant) exchange.
    if len(history) > 0:
        try:
            previous = history[-1]
            # Some UIs attach metadata; only list/tuple pairs are usable,
            # and only their first two elements matter here.
            if isinstance(previous, (list, tuple)):
                prev_user, prev_assistant = previous[0], previous[1]
                parts.append(f"""
<|im_start|>user
{prev_user}<|im_end|>
<|im_start|>assistant
{prev_assistant}<|im_end|>""")
        except Exception:
            # Unexpected history shapes are ignored rather than crashing.
            pass
    # Current question, ending with the open assistant tag for generation.
    parts.append(f"""
<|im_start|>user
{current_question}<|im_end|>
<|im_start|>assistant
""")
    return "".join(parts)
def solve_single(question, history, temperature=0.6):
    """Generate one step-by-step solution for `question`.

    Args:
        question: The math problem to solve.
        history: Chat history, forwarded to the prompt builder.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The model's answer text, or an "Error generating response: ..."
        string if anything fails.
    """
    try:
        prompt = format_prompt(question, history)
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=temperature,
                do_sample=True
            )
        # BUG FIX: the previous code decoded the full sequence and split on
        # the literal "<|im_start|>assistant" marker — but with
        # skip_special_tokens=True the ChatML control tokens are stripped
        # from the decoded text, so the split could fail and the entire
        # prompt (system message included) leaked into the reply.
        # Decoding only the newly generated token slice avoids that entirely.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"
def solve_majority_vote(question, history):
    """Smart mode: sample three solutions and return one whose extracted
    answer matches the majority vote (self-consistency)."""
    # Higher temperature so the three samples actually differ.
    attempts = [solve_single(question, history, temperature=0.8) for _ in range(3)]
    answers = [a for a in map(extract_answer, attempts) if a]
    if not answers:
        # No parseable number anywhere — just show the first response.
        return attempts[0]
    winner = Counter(answers).most_common(1)[0][0]
    # Return the first full response that produced the winning answer.
    for attempt in attempts:
        if extract_answer(attempt) == winner:
            return f"🏆 **High Confidence Answer (Verified 3x)**\n\n{attempt}"
    return attempts[0]
# --- MAIN CHAT LOGIC ---
def chat_logic(message, history, smart_mode):
    """Route one chat turn: majority-vote sampling when smart mode is on,
    otherwise a single generation pass."""
    solver = solve_majority_vote if smart_mode else solve_single
    return solver(message, history)
# --- UI SETUP ---
# Gradio chat UI; chat_logic receives (message, history, smart_mode).
demo = gr.ChatInterface(
fn=chat_logic,
additional_inputs=[
# Toggles between single-shot generation and 3-sample majority voting.
gr.Checkbox(label="🔥 Enable Smart Mode (Slow but 82% Accurate)", value=False)
],
title="🧮 AI Math Tutor (Qwen-1.5B Fine-Tuned)",
description="""
Portfolio Project: A specialized math solver fine-tuned on GSM8K using LoRA.
(Click the badge above to run this on a Free GPU for faster speeds!)
⚠️ Performance Note: This current page runs on a slow CPU.
""",
# Each example is [message, smart_mode checkbox value].
examples=[
["If I have 30 candies and eat 12, then buy 5 more, how many do I have?", False],
["It takes 5 machines 5 minutes to make 5 widgets. How long for 100 machines?", True]
]
)
if __name__ == "__main__":
demo.launch()