Spaces:
Sleeping
Sleeping
File size: 4,804 Bytes
7316cfa 48cd902 7316cfa 3c6efdb 7316cfa 3c6efdb 7316cfa 3c6efdb 7316cfa 3c6efdb 7316cfa 3c6efdb 7316cfa 17f6861 34068d5 0fcc475 34068d5 0fcc475 34068d5 17f6861 7316cfa 3c6efdb 17f6861 7316cfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from collections import Counter
import re
# --- CONFIGURATION ---
# REPLACE WITH YOUR USERNAME
MODEL_ID = "justhariharan/Qwen2.5-Math-1.5B-Solver"
print(f"⏳ Loading {MODEL_ID}... (CPU Mode)")
# Load tokenizer and model weights from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # full precision — presumably because the Space runs on CPU; confirm if fp16/bf16 is viable
    device_map="cpu"  # pin all weights to CPU (no GPU on the free Spaces tier, per the UI note below)
)
# --- HELPER FUNCTIONS ---
def extract_answer(text):
    """Extract the final numeric answer from a model response.

    Prefers the number following the '#### ' marker (the format the system
    prompt asks for); otherwise falls back to the last number found
    anywhere in the text.

    Args:
        text: Raw model output. May be None or empty.

    Returns:
        The answer as a plain string with '$' and ',' separators removed
        (e.g. '$1,200.50' -> '1200.50'), or None if no number is present.
    """
    if not text:
        return None
    # Everything after the last '####' marker is the declared final answer.
    if "####" in text:
        text = text.split("####")[-1]
    # Require at least one digit so bare punctuation ('.', '$', ',') can
    # never match on its own (the old [$0-9.,]+ class allowed that, which
    # produced empty-string "answers" that broke majority voting).
    matches = re.findall(r"-?\$?\d[\d.,]*", text)
    if not matches:
        return None
    # Drop currency symbol and thousands separators; rstrip('.') removes a
    # trailing sentence period ('5.' -> '5') so equal answers compare equal.
    return matches[-1].replace(",", "").replace("$", "").rstrip(".").strip()
def format_prompt(current_question, history):
    """Assemble the ChatML prompt: system message, optional last turn, question.

    Only the most recent (user, assistant) pair from `history` is replayed,
    keeping the prompt short enough for CPU inference. Malformed history
    entries are silently skipped.
    """
    system_prompt = """<|im_start|>system
You are a patient and friendly math teacher.
1. Solve the problem step-by-step.
2. Explain the 'logic' simply so a student can understand.
3. Always end your final result with '#### Number'.<|im_end|>"""

    # Short-term memory: replay just the previous turn, defensively.
    history_context = ""
    if len(history) > 0:
        try:
            previous = history[-1]
            # Accept only list/tuple turns; indexing past a short tuple
            # raises inside the try and leaves history_context empty.
            if isinstance(previous, (list, tuple)):
                history_context = f"""
<|im_start|>user
{previous[0]}<|im_end|>
<|im_start|>assistant
{previous[1]}<|im_end|>"""
        except Exception:
            # Unknown history shapes are ignored rather than crashing chat.
            pass

    user_input = f"""
<|im_start|>user
{current_question}<|im_end|>
<|im_start|>assistant
"""
    return system_prompt + history_context + user_input
def solve_single(question, history, temperature=0.6, max_new_tokens=512):
    """Generate one step-by-step solution for `question`.

    Args:
        question: The math problem to solve.
        history: Chat history forwarded to the prompt builder (only the
            last turn is used as context).
        temperature: Sampling temperature; majority-vote mode passes a
            higher value for answer diversity.
        max_new_tokens: Upper bound on generated tokens per response
            (previously hard-coded; default preserves old behavior).

    Returns:
        The assistant's reply text, or an "Error generating response: ..."
        string on failure so the chat UI shows it instead of crashing.
    """
    try:
        prompt = format_prompt(question, history)
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # decode() echoes the full prompt; keep only the text after the
        # final assistant marker.
        if "<|im_start|>assistant" in response:
            return response.split("<|im_start|>assistant")[-1].strip()
        return response
    except Exception as e:
        # Surface the failure in-chat rather than raising into Gradio.
        return f"Error generating response: {str(e)}"
def solve_majority_vote(question, history, num_samples=3, temperature=0.8):
    """Smart mode: self-consistency decoding via majority vote.

    Samples `num_samples` independent solutions at a higher temperature,
    extracts each final numeric answer, and returns the first full
    response whose answer matches the most common one.

    Args:
        question: The math problem to solve.
        history: Chat history forwarded to `solve_single`.
        num_samples: Number of candidate solutions to generate
            (previously hard-coded to 3; default preserves old behavior).
        temperature: Sampling temperature used for variety.

    Returns:
        The winning response with a confidence banner, or the first raw
        response when no numeric answer could be extracted at all.
    """
    candidates = []
    raw_responses = []
    for _ in range(num_samples):
        resp = solve_single(question, history, temperature=temperature)
        raw_responses.append(resp)
        ans = extract_answer(resp)
        if ans:
            candidates.append(ans)
    # No parseable answers (e.g. every sample errored): show first attempt.
    if not candidates:
        return raw_responses[0]
    winner_ans = Counter(candidates).most_common(1)[0][0]
    # Return the first full explanation that reached the winning answer.
    for resp in raw_responses:
        if extract_answer(resp) == winner_ans:
            return f"🏆 **High Confidence Answer (Verified {num_samples}x)**\n\n{resp}"
    # Defensive fallback; normally unreachable since winner_ans came from
    # one of raw_responses.
    return raw_responses[0]
# --- MAIN CHAT LOGIC ---
def chat_logic(message, history, smart_mode):
    """Route a chat turn to smart (majority-vote) or standard generation."""
    solver = solve_majority_vote if smart_mode else solve_single
    return solver(message, history)
# --- UI SETUP ---
# Gradio chat UI; the checkbox below is passed to chat_logic as `smart_mode`.
demo = gr.ChatInterface(
    fn=chat_logic,
    additional_inputs=[
        # Third positional argument delivered to chat_logic on every turn.
        gr.Checkbox(label="🔥 Enable Smart Mode (Slow but 82% Accurate)", value=False)
    ],
    title="🧮 AI Math Tutor (Qwen-1.5B Fine-Tuned)",
    description="""
<b>Portfolio Project:</b> A specialized math solver fine-tuned on GSM8K using LoRA.
<br><br>
<a href='https://colab.research.google.com/github/HariHaran9597/Math-solver/blob/main/Try_Math_Solver.ipynb' target='_blank'>
<img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/>
</a>
<br>
(Click the badge above to run this on a Free GPU for faster speeds!)
<br><br>
<b>⚠️ Performance Note:</b> This current page runs on a slow CPU.
""",
    # Each example is [message, smart_mode] matching the additional input.
    examples=[
        ["If I have 30 candies and eat 12, then buy 5 more, how many do I have?", False],
        ["It takes 5 machines 5 minutes to make 5 widgets. How long for 100 machines?", True]
    ]
)
# Launch the web server when run as a script (Spaces runs this directly).
if __name__ == "__main__":
    demo.launch()