File size: 4,804 Bytes
7316cfa
 
 
 
 
 
 
 
48cd902
7316cfa
 
 
 
 
 
 
 
 
 
 
 
3c6efdb
7316cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c6efdb
 
7316cfa
3c6efdb
 
 
 
 
 
 
 
 
 
 
7316cfa
 
 
 
3c6efdb
 
 
7316cfa
 
 
 
 
 
 
 
 
 
 
3c6efdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7316cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17f6861
 
 
34068d5
0fcc475
 
34068d5
 
0fcc475
34068d5
17f6861
7316cfa
 
3c6efdb
17f6861
7316cfa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from collections import Counter
import re

# --- CONFIGURATION ---
# Hugging Face Hub repo ID of the fine-tuned model. Point this at your own
# repo if you re-train / re-upload the model under a different account.
MODEL_ID = "justhariharan/Qwen2.5-Math-1.5B-Solver"

# Load tokenizer and weights once at import time; every request handled by
# the functions below reuses these module-level `tokenizer` / `model` objects.
# Weights are loaded in full float32 precision and placed on the CPU.
# NOTE(review): `device_map` is usually backed by the `accelerate` package —
# confirm it is listed in the deployment requirements.
print(f"⏳ Loading {MODEL_ID}... (CPU Mode)")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu"
)

# --- HELPER FUNCTIONS ---
def extract_answer(text):
    """Return the final numeric answer in ``text`` as a normalized string.

    If ``text`` contains the ``####`` answer marker, only the part after the
    last marker is searched; otherwise the whole text is searched. The last
    number found wins. Thousands separators and ``$`` signs are removed, e.g.
    ``"$1,234.50"`` -> ``"1234.50"``.

    Args:
        text: Model output (may be ``None`` or empty).

    Returns:
        The answer string, or ``None`` if ``text`` is empty or has no number.
    """
    if not text:
        return None
    if "####" in text:
        text = text.split("####")[-1]
    # A number must contain at least one digit, and a decimal point only counts
    # when digits follow it. This keeps sentence punctuation ("... is 42.")
    # out of the match and never returns a lone "." or "," as an answer.
    matches = re.findall(r"-?\$?(?:\d[\d,]*(?:\.\d+)?|\.\d+)", text)
    if not matches:
        return None
    return matches[-1].replace(",", "").replace("$", "")

def _last_exchange(history):
    """Return ``(user_text, assistant_text)`` of the latest completed turn, or None.

    Understands both Gradio chat-history formats:

    * "tuples":   ``[[user, assistant], ...]`` (extra trailing elements ignored)
    * "messages": ``[{"role": "user" | "assistant", "content": str}, ...]``

    Returns ``None`` for empty/``None`` history or when the most recent entry
    cannot be read as a completed user/assistant exchange.
    """
    if not history:
        return None
    last_turn = history[-1]

    # Tuples format: keep only (user, assistant); skip turns with no reply yet.
    if isinstance(last_turn, (list, tuple)):
        if len(last_turn) >= 2 and last_turn[0] is not None and last_turn[1] is not None:
            return last_turn[0], last_turn[1]
        return None

    # Messages format: walk backwards to the latest assistant reply, then to
    # the user message that preceded it.
    if isinstance(last_turn, dict):
        assistant_text = None
        for message in reversed(history):
            if not isinstance(message, dict):
                return None
            role = message.get("role")
            if assistant_text is None:
                if role == "assistant":
                    assistant_text = message.get("content")
            elif role == "user":
                user_text = message.get("content")
                # Non-text content (files, components) cannot go into the prompt.
                if isinstance(user_text, str) and isinstance(assistant_text, str):
                    return user_text, assistant_text
                return None
        return None

    return None


def format_prompt(current_question, history):
    """Build a ChatML prompt for the math-tutor model.

    Args:
        current_question: The user's new message.
        history: Gradio chat history in "tuples" or "messages" format (or
            ``None``). Only the most recent completed exchange is included,
            as short-term memory.

    Returns:
        The prompt string: system message, optional previous exchange, and the
        new user turn, ending with an open assistant turn for the model to fill.
    """
    # System Prompt: Friendly Teacher
    system_prompt = """<|im_start|>system
You are a patient and friendly math teacher. 
1. Solve the problem step-by-step.
2. Explain the 'logic' simply so a student can understand.
3. Always end your final result with '#### Number'.<|im_end|>"""

    # Short-term memory: at most the last completed turn.
    history_context = ""
    exchange = _last_exchange(history)
    if exchange is not None:
        last_q, last_a = exchange
        history_context = f"""
<|im_start|>user
{last_q}<|im_end|>
<|im_start|>assistant
{last_a}<|im_end|>"""

    # Current input, leaving the assistant turn open for generation.
    user_input = f"""
<|im_start|>user
{current_question}<|im_end|>
<|im_start|>assistant
"""
    return system_prompt + history_context + user_input

def solve_single(question, history, temperature=0.6, max_new_tokens=512):
    """Generate one sampled answer from the model.

    Args:
        question: The user's new message.
        history: Gradio chat history, forwarded to ``format_prompt``.
        temperature: Sampling temperature (sampling is always enabled).
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The assistant's reply text. If tokenization or generation fails, an
        ``"Error generating response: ..."`` string is returned instead, so the
        chat UI shows the error rather than crashing.
    """
    try:
        prompt = format_prompt(question, history)
        inputs = tokenizer(prompt, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
            )

        # Decode only the newly generated tokens. Decoding the full sequence
        # with skip_special_tokens=True strips the <|im_start|> markers, so
        # splitting on "<|im_start|>assistant" can never match, and the whole
        # prompt (system message + question) would leak into the reply.
        generated = outputs[0][prompt_length:]
        return tokenizer.decode(generated, skip_special_tokens=True).strip()
    except Exception as e:
        # UI boundary: surface the failure as chat text instead of raising.
        return f"Error generating response: {str(e)}"

def solve_majority_vote(question, history, num_samples=3, temperature=0.8):
    """Smart Mode: sample several answers and return one from the majority.

    Args:
        question: The user's new message.
        history: Gradio chat history, forwarded to ``solve_single``.
        num_samples: How many independent samples to draw (must be >= 1).
        temperature: Sampling temperature; higher than the single-shot default
            to get more varied candidates.

    Returns:
        If no sample contains a numeric answer, the first raw response.
        Otherwise the first response whose answer won the vote, prefixed with a
        badge showing how many samples agreed. The badge is high-confidence only
        when at least two samples agree.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    # Keep each response paired with its extracted answer (None if absent).
    samples = []
    for _ in range(num_samples):
        response = solve_single(question, history, temperature=temperature)
        samples.append((response, extract_answer(response)))

    answers = [answer for _, answer in samples if answer]
    if not answers:
        # Nothing to vote on; fall back to the first raw response.
        return samples[0][0]

    # Counter.most_common orders equal counts by first occurrence, so ties go
    # to the earliest sampled answer.
    winner, votes = Counter(answers).most_common(1)[0]
    best_response = next(resp for resp, ans in samples if ans == winner)

    if votes >= 2:
        return f"🏆 **High Confidence Answer (Verified {votes}x)**\n\n{best_response}"
    return f"⚠️ **Low Confidence Answer (no majority agreement)**\n\n{best_response}"

# --- MAIN CHAT LOGIC ---
def chat_logic(message, history, smart_mode):
    """Gradio callback that dispatches one chat message to a solver.

    A truthy ``smart_mode`` (the Smart Mode checkbox) selects the slower
    majority-vote solver; otherwise a single sampled answer is generated.
    """
    solver = solve_majority_vote if smart_mode else solve_single
    return solver(message, history)

# --- UI SETUP ---
# Chat UI: Gradio calls chat_logic(message, history, <checkbox value>) for
# each user message; the checkbox in additional_inputs supplies `smart_mode`.
demo = gr.ChatInterface(
    fn=chat_logic,
    additional_inputs=[
        # NOTE(review): the "82% Accurate" figure is not computed anywhere in
        # this file — confirm its source before relying on it.
        gr.Checkbox(label="🔥 Enable Smart Mode (Slow but 82% Accurate)", value=False)
    ],
    title="🧮 AI Math Tutor (Qwen-1.5B Fine-Tuned)",
    # HTML shown under the title (links to a Colab notebook for GPU runs).
    description="""
    <b>Portfolio Project:</b> A specialized math solver fine-tuned on GSM8K using LoRA.
    <br><br>
    <a href='https://colab.research.google.com/github/HariHaran9597/Math-solver/blob/main/Try_Math_Solver.ipynb' target='_blank'>
        <img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/>
    </a>
    <br>
    (Click the badge above to run this on a Free GPU for faster speeds!)
    <br><br>
    <b>⚠️ Performance Note:</b> This current page runs on a slow CPU.
    """,
    # Each example row is [message, smart_mode], matching additional_inputs.
    examples=[
        ["If I have 30 candies and eat 12, then buy 5 more, how many do I have?", False],
        ["It takes 5 machines 5 minutes to make 5 widgets. How long for 100 machines?", True]
    ]
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()