Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import requests | |
| from transformers import pipeline | |
| from sentence_transformers import SentenceTransformer | |
| from qdrant_client import QdrantClient | |
| from datetime import datetime | |
| import dspy | |
| import json | |
| import google.generativeai as genai | |
# Configure the Gemini API key from the environment.
# SECURITY: API keys must never be committed to source; the key that used to
# be hard-coded here is public and should be revoked.
import os

_gemini_api_key = os.environ.get("GEMINI_API_KEY")
if _gemini_api_key:
    genai.configure(api_key=_gemini_api_key)
else:
    print("Warning: GEMINI_API_KEY is not set; Gemini requests may fail.")
def output_guard(answer, *, min_length=20):
    """Return True if *answer* looks like a usable generated answer.

    An answer is rejected when it is ``None``/empty, or when it is shorter
    than ``min_length`` characters after stripping surrounding whitespace.

    Args:
        answer: The generated answer text (may be ``None``).
        min_length: Minimum number of characters required after stripping.
            Defaults to 20, the original hard-coded threshold.

    Returns:
        ``True`` if the answer passes the guard, ``False`` otherwise.
    """
    if not answer or len(answer.strip()) < min_length:
        print("Output guard triggered: answer too short or empty.")
        return False
    # Further checks (e.g. refusal detection) can be added here.
    return True
| import os | |
| from datetime import datetime | |
# Feedback records are appended to this file, one JSON object per line.
# The path is relative to the working directory; on Hugging Face Spaces the
# file does not survive a restart of the Space.
feedback_path: str = "feedback.json"
def store_feedback(question, answer, feedback, correct_answer):
    """Append a single feedback record to ``feedback_path`` as a JSON line.

    Any error while writing is printed and otherwise swallowed, so that
    collecting feedback can never break the UI.

    Args:
        question: The question the user asked.
        answer: The model answer that was shown to the user.
        feedback: The user's rating.
        correct_answer: Optional correction or free-form comment.
    """
    record = dict(
        question=question,
        model_answer=answer,
        feedback=feedback,
        correct_answer=correct_answer,
        timestamp=str(datetime.now()),
    )
    print("Attempting to store feedback:", record)
    try:
        with open(feedback_path, "a") as out:
            out.write(f"{json.dumps(record)}\n")
        print("โ Feedback saved at", feedback_path)
    except Exception as err:
        print("โ Error writing feedback:", err)
| import re | |
def latex_to_plain_math(latex_expr):
    r"""Convert a small subset of LaTeX math markup to plain-text math.

    Conversions, applied in this order:

    * ``\frac{a}{b}`` (also ``\dfrac`` / ``\tfrac``) -> ``(a) / (b)``
    * ``\sqrt{x}`` -> ``√(x)``
    * ``^2`` / ``^{2}`` -> ``²`` and ``^3`` / ``^{3}`` -> ``³``; longer
      exponents such as ``^23`` are left alone
    * ``\pm`` -> ``±`` and ``\cdot`` -> ``⋅``
    * all remaining ``{`` / ``}`` are removed

    The patterns are non-greedy and do not track nested braces, so deeply
    nested expressions may only be partially converted.

    Args:
        latex_expr: LaTeX source text; ``None`` or empty yields ``""``.

    Returns:
        The converted text with surrounding whitespace stripped.
    """
    if not latex_expr:
        return ""
    text = latex_expr.strip()
    text = re.sub(r"\\[dt]?frac\{(.+?)\}\{(.+?)\}", r"(\1) / (\2)", text)
    # Output symbols are written as escapes so they cannot be mangled by a
    # source re-encoding (they had been turned into Thai mojibake).
    text = re.sub(r"\\sqrt\{(.+?)\}", "\u221a(\\1)", text)
    superscripts = {"2": "\u00b2", "3": "\u00b3"}
    # Match ^2 / ^{2} (and 3) but not the leading digit of a longer exponent,
    # so that x^23 is not turned into "x²3".
    text = re.sub(
        r"\^(?:\{([23])\}|([23])(?!\d))",
        lambda m: superscripts[m.group(1) or m.group(2)],
        text,
    )
    text = text.replace("\\pm", "\u00b1").replace("\\cdot", "\u22c5")
    return text.replace("{", "").replace("}", "")
# === Load Models ===
# Loaded once at import time: the zero-shot classifier backs the math-question
# guard, the sentence embedder encodes queries for the Qdrant search.
_CLASSIFIER_MODEL_ID = "facebook/bart-large-mnli"
_EMBEDDING_MODEL_ID = "intfloat/e5-large"

print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model=_CLASSIFIER_MODEL_ID)

print("Loading embedding model...")
embedding_model = SentenceTransformer(_EMBEDDING_MODEL_ID)
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline | |
# === Qdrant Setup ===
# Local, file-backed Qdrant store in ./qdrant_data. The collection is not
# created here — presumably it is populated by a separate indexing step.
collection_name = "math_problems"
print("Connecting to Qdrant...")
qdrant_client = QdrantClient(path="qdrant_data")
# === Guard Function ===
def is_valid_math_question(text, *, threshold=0.7):
    """Return True if the zero-shot classifier labels *text* as math.

    The module-level ``classifier`` scores *text* against the labels
    ``"math"`` and ``"not math"``; the question is accepted only when
    ``"math"`` is the top label with a score above *threshold*. Blank input
    (``None``, ``""`` or whitespace only) is rejected without calling the
    classifier.

    Args:
        text: The user's question.
        threshold: Minimum score for the ``"math"`` label. Defaults to 0.7.

    Returns:
        ``True`` for an accepted math question, ``False`` otherwise.
    """
    if not text or not text.strip():
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    # The pipeline returns labels sorted by score, so index 0 is the winner.
    return result['labels'][0] == "math" and result['scores'][0] > threshold
# === Retrieval ===
def retrieve_from_qdrant(query, *, limit=1):
    """Return the payloads of the stored problems most similar to *query*.

    Encodes *query* with ``embedding_model`` and runs a vector search on
    ``collection_name`` through ``qdrant_client``.

    Args:
        query: Question text to embed.
        limit: Maximum number of hits to return (defaults to 1, the original
            behaviour).

    Returns:
        A list of payload dicts in the order Qdrant returns the hits; hits
        without a payload are skipped. Empty when nothing matches.
    """
    print("Retrieving context from Qdrant...")
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name, query_vector=query_vector, limit=limit
    )
    print("Retrieved hits:", hits)
    # Callers do `"solution" in item`, which raises on a None payload.
    return [hit.payload for hit in hits or [] if hit.payload is not None]
# === Web Search ===
def web_search_tavily(query):
    """Ask the Tavily search API for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable.
    Tavily only includes an ``answer`` field when ``include_answer`` is
    requested, so the request asks for it explicitly.

    Args:
        query: The question to search for.

    Returns:
        Tavily's answer text, or ``"No answer found from Tavily."`` when the
        key is missing, the request fails, or no answer comes back.
    """
    fallback = "No answer found from Tavily."
    print("Calling Tavily...")
    # SECURITY: never hard-code the key; the one previously committed here is
    # public and should be revoked.
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        print("TAVILY_API_KEY is not set; skipping web search.")
        return fallback
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                "include_answer": True,
            },
            timeout=30,  # never let a stalled request hang the UI forever
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as err:
        print("Tavily request failed:", err)
        return fallback
    # `answer` may be missing or null; always hand back a string.
    return data.get("answer") or fallback
# === DSPy Signature ===
# Input/output schema for a math-QA step: question + retrieved context -> answer.
# NOTE(review): this signature is not referenced anywhere in the visible code;
# the programs below call Gemini directly instead.
# Documented with comments rather than a docstring on purpose: DSPy uses a
# Signature's docstring as its task instructions, so adding one would change
# behaviour.
class MathAnswer(dspy.Signature):
    question = dspy.InputField()  # the user's math question
    retrieved_context = dspy.InputField()  # supporting text retrieved for the question
    answer = dspy.OutputField()  # the generated answer
| # === DSPy Programs === | |
| import google.generativeai as genai | |
# Configure Gemini from the environment (no hard-coded secrets; the key that
# used to be committed here is public and should be revoked). configure() is
# only called when a key is present, so an existing configuration is never
# replaced with an empty one.
_gemini_api_key = os.environ.get("GEMINI_API_KEY")
if _gemini_api_key:
    genai.configure(api_key=_gemini_api_key)
class MathRetrievalQA(dspy.Program):
    """Answer a math question with Gemini, grounded on a retrieved solution.

    ``forward`` fetches the closest stored solution from Qdrant, converts its
    LaTeX to plain-text math and asks Gemini to rewrite it as a textbook-style
    solution. When nothing is retrieved it returns an empty answer so the
    caller can fall back to another source.
    """

    def forward(self, question):
        """Return ``{"answer": str, "retrieved_context": str}`` for *question*.

        ``answer`` is empty when no context was retrieved. On a Gemini error it
        is a warning message, and ``retrieved_context`` still holds the context.
        """
        print("Inside MathRetrievalQA...")
        context_items = retrieve_from_qdrant(question)
        context = "\n".join(
            item["solution"] for item in context_items if "solution" in item
        )
        print("Context for generation:", context)
        if not context:
            # An empty answer tells the router to use the web fallback.
            return {"answer": "", "retrieved_context": ""}

        plain_context = latex_to_plain_math(context)
        print(plain_context)
        prompt = self._build_prompt(question, plain_context)
        try:
            model = genai.GenerativeModel('gemini-2.0-flash')  # or 'gemini-1.5-flash'
            formatted_answer = model.generate_content(prompt).text
        except Exception as e:  # API/network failures or a response without text
            print("Gemini generation error:", e)
            return {"answer": "⚠️ Gemini failed to generate an answer.", "retrieved_context": context}
        print("Gemini Answer:", formatted_answer)
        return {"answer": formatted_answer, "retrieved_context": context}

    @staticmethod
    def _build_prompt(question, plain_context):
        """Build the Gemini prompt asking for a plain-text, textbook-style solution."""
        # One consistent instruction set: the old prompt demanded LaTeX in one
        # sentence and forbade it in the next.
        return (
            "You are a math textbook author. Write a clear, professional, and "
            "well-formatted solution for the following math problem.\n"
            "Use the reference solution below as context if it helps. Present "
            "the steps clearly using numbered headings. Write all math in plain "
            "text with standard symbols (+, -, ×, /, =, ±, √), fractions with "
            "slashes (e.g. (a + b)/c) and exponents with ^. Do not use LaTeX "
            "syntax or backslashes, and do not wrap equations in dollar signs.\n\n"
            f"Problem: {question}\n\n"
            f"Reference solution:\n{plain_context}\n\n"
            "Write only the formatted solution, as it would appear in a math textbook."
        )
class WebFallbackQA(dspy.Program):
    """Fallback program that answers a question via Tavily web search."""

    def forward(self, question):
        """Return Tavily's answer, tagging the context source as ``"Tavily"``."""
        print("Fallback to Tavily...")
        web_answer = web_search_tavily(question)
        return dict(answer=web_answer, retrieved_context="Tavily")
class MathRouter(dspy.Program):
    """Route a question: reject non-math, try retrieval + Gemini, else the web.

    The retrieval answer is used only when it passes ``output_guard`` (not
    empty and long enough); otherwise the Tavily fallback answers instead.
    """

    def forward(self, question):
        """Return an ``{"answer", "retrieved_context"}`` dict for *question*."""
        print("Routing question:", question)
        if not is_valid_math_question(question):
            return {"answer": "โ Only math questions are accepted. Please rephrase.", "retrieved_context": ""}
        result = MathRetrievalQA().forward(question)
        # Reject empty or suspiciously short generations before showing them.
        if output_guard(result["answer"]):
            return result
        return WebFallbackQA().forward(question)
# Single shared router instance, used by the Gradio question handler.
router: MathRouter = MathRouter()
# === Gradio Functions ===
def ask_question(question):
    """Answer *question* through the router for the Gradio UI.

    Returns a 3-tuple ``(answer, question, answer)``. The first item fills the
    visible answer area; the last two fill the hidden question/answer fields
    that are later submitted along with the user's feedback.
    """
    print("ask_question() called with:", question)
    result = router.forward(question)
    print("Result:", result)
    answer = result["answer"]
    return answer, question, answer
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Persist the user's feedback and return a status message for the UI."""
    store_feedback(
        question=question,
        answer=model_answer,
        feedback=feedback,
        correct_answer=correct_answer,
    )
    return "โ Feedback received. Thank you!"
# === Gradio UI ===
def _saved_feedback_file():
    """Return ``feedback_path`` for the download widget, or None if it is missing.

    Avoids handing gr.File a path that does not exist, e.g. when the write in
    ``store_feedback`` failed.
    """
    return feedback_path if os.path.exists(feedback_path) else None


with gr.Blocks() as demo:
    gr.Markdown("## ๐งฎ Math Agent")
    with gr.Tab("Ask a Math Question & Submit Feedback"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
            submit_btn = gr.Button("Get Answer")
        gr.Markdown("### ๐ง Answer:")
        answer_output = gr.Markdown()
        # Hidden fields carry the last question/answer into the feedback handler.
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn.click(
            fn=ask_question,
            inputs=[question_input],
            outputs=[answer_output, hidden_q, hidden_a],
        )

        gr.Markdown("### ๐ Submit Feedback")
        # Two distinct choices: the old literals had been mis-encoded into two
        # identical strings, so "helpful" and "not helpful" were indistinguishable.
        fb_like = gr.Radio(["👍", "👎"], label="Was the answer helpful?")
        fb_correct = gr.Textbox(label="Correct Answer (optional) or Comments")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)
        feedback_file = gr.File(label="๐ Download Saved Feedback", interactive=False)
        # Chain with .then() so the download is refreshed only after the
        # feedback has been written; two separate .click() listeners on the
        # same button are not guaranteed to run in order.
        fb_submit_btn.click(
            fn=submit_feedback,
            inputs=[hidden_q, hidden_a, fb_like, fb_correct],
            outputs=[fb_status],
        ).then(
            fn=_saved_feedback_file,
            outputs=[feedback_file],
        )

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)