Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from collections import Counter | |
| import re | |
# --- CONFIGURATION ---
# REPLACE WITH YOUR USERNAME
MODEL_ID = "justhariharan/Qwen2.5-Math-1.5B-Solver"

print(f"⏳ Loading {MODEL_ID}... (CPU Mode)")

# Load the fine-tuned checkpoint on CPU in full precision
# (free Spaces hardware has no GPU, and CPUs gain nothing from fp16).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu",
)
# --- HELPER FUNCTIONS ---
def extract_answer(text):
    """Extract the numeric final answer from a model response.

    Prefers the GSM8K-style '#### <number>' marker; otherwise falls back
    to the last number-like token anywhere in the text.

    Args:
        text: Raw model output (may be None or empty).

    Returns:
        The answer as a normalized string (no '$', thousands commas, or
        trailing period), or None if no number is present.
    """
    if not text:
        return None
    # Everything after the final '####' marker is the declared answer.
    if "####" in text:
        text = text.split("####")[-1]
    # Anchor on a digit so stray '.' or ',' tokens can never match
    # (the old character-class pattern matched bare punctuation and
    # kept sentence-ending periods, breaking vote equality checks).
    matches = re.findall(r"-?\$?\d[\d,]*(?:\.\d+)?", text)
    if not matches:
        return None
    # Normalize: strip currency symbols and thousands separators.
    return matches[-1].replace(",", "").replace("$", "").strip()
def format_prompt(current_question, history):
    """Build a ChatML prompt: system persona + last turn of history + new question.

    Args:
        current_question: The user's new math question.
        history: Gradio chat history — either a list of (user, assistant)
            pairs (tuple format) or a list of {'role', 'content'} dicts
            (the newer "messages" format). Both are supported.

    Returns:
        A single ChatML-formatted prompt string ending with an open
        assistant turn for the model to complete.
    """
    # System Prompt: Friendly Teacher
    system_prompt = """<|im_start|>system
You are a patient and friendly math teacher.
1. Solve the problem step-by-step.
2. Explain the 'logic' simply so a student can understand.
3. Always end your final result with '#### Number'.<|im_end|>"""
    # Short-term memory: include only the most recent exchange.
    history_context = ""
    if history:
        try:
            last_turn = history[-1]
            last_q = last_a = None
            if isinstance(last_turn, (list, tuple)):
                # Tuple format: (user_message, assistant_message, *metadata)
                last_q, last_a = last_turn[0], last_turn[1]
            elif isinstance(last_turn, dict) and len(history) >= 2:
                # Messages format: [..., {'role': 'user', ...}, {'role': 'assistant', ...}]
                prev = history[-2]
                if (isinstance(prev, dict)
                        and prev.get("role") == "user"
                        and last_turn.get("role") == "assistant"):
                    last_q = prev.get("content")
                    last_a = last_turn.get("content")
            if last_q is not None and last_a is not None:
                history_context = f"""
<|im_start|>user
{last_q}<|im_end|>
<|im_start|>assistant
{last_a}<|im_end|>"""
        except Exception:
            # If history format is weird, just ignore it and continue safely.
            pass
    # Current input: open an assistant turn for the model to complete.
    user_input = f"""
<|im_start|>user
{current_question}<|im_end|>
<|im_start|>assistant
"""
    return system_prompt + history_context + user_input
def solve_single(question, history, temperature=0.6):
    """Generate one model answer for the question, given the chat history.

    Returns the assistant's reply text, or an error message string if
    generation fails (kept as a string so the chat UI never crashes).
    """
    try:
        full_prompt = format_prompt(question, history)
        encoded = tokenizer(full_prompt, return_tensors="pt")
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=512,
                temperature=temperature,
                do_sample=True,
            )
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        # The decode includes the prompt; keep only the assistant's turn.
        marker = "<|im_start|>assistant"
        if marker in decoded:
            return decoded.split(marker)[-1].strip()
        return decoded
    except Exception as e:
        return f"Error generating response: {str(e)}"
def solve_majority_vote(question, history):
    """Smart Mode: sample three solutions and return one agreeing with the majority answer."""
    raw_responses = []
    candidates = []
    for _ in range(3):
        # Higher temperature gives the samples enough variety to vote on.
        attempt = solve_single(question, history, temperature=0.8)
        raw_responses.append(attempt)
        parsed = extract_answer(attempt)
        if parsed:
            candidates.append(parsed)
    # No parseable answers at all — fall back to the first raw response.
    if not candidates:
        return raw_responses[0]
    winner_ans = Counter(candidates).most_common(1)[0][0]
    # Return the first response whose extracted answer matches the winner.
    winning = next(
        (resp for resp in raw_responses if extract_answer(resp) == winner_ans),
        None,
    )
    if winning is not None:
        return f"🏆 **High Confidence Answer (Verified 3x)**\n\n{winning}"
    return raw_responses[0]
# --- MAIN CHAT LOGIC ---
def chat_logic(message, history, smart_mode):
    """Route a chat message to Smart Mode (3x majority vote) or a single generation."""
    solver = solve_majority_vote if smart_mode else solve_single
    return solver(message, history)
# --- UI SETUP ---
# ChatInterface wires chat_logic to a chatbox; the checkbox is passed
# through as the third argument (smart_mode).
demo = gr.ChatInterface(
    fn=chat_logic,
    title="🧮 AI Math Tutor (Qwen-1.5B Fine-Tuned)",
    description="""
<b>Portfolio Project:</b> A specialized math solver fine-tuned on GSM8K using LoRA.
<br><br>
<a href='https://colab.research.google.com/github/HariHaran9597/Math-solver/blob/main/Try_Math_Solver.ipynb' target='_blank'>
<img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/>
</a>
<br>
(Click the badge above to run this on a Free GPU for faster speeds!)
<br><br>
<b>⚠️ Performance Note:</b> This current page runs on a slow CPU.
""",
    additional_inputs=[
        gr.Checkbox(label="🔥 Enable Smart Mode (Slow but 82% Accurate)", value=False)
    ],
    examples=[
        ["If I have 30 candies and eat 12, then buy 5 more, how many do I have?", False],
        ["It takes 5 machines 5 minutes to make 5 widgets. How long for 100 machines?", True],
    ],
)

if __name__ == "__main__":
    demo.launch()