diff --git a/.gitignore b/.gitignore index c37aabb100bb874c5d1ffb02c4c1474cebc4d306..073bee6c8afd9e11f8a77b7423fcccc70c327e36 100644 --- a/.gitignore +++ b/.gitignore @@ -94,3 +94,25 @@ all_code.txt *.tmp *.out *.generated.* + +# Scratch and Debug Scripts +debug_*.py +test_*.py +verify_*.py +inspect_*.py +locate_*.py +check_*.py +db_diag.py +reproduce_crash.py +find_embedding_models.py +query +all_code.txt +frontend_trace.log + +# Sensitive Files +github-recovery-codes.txt +*-firebase-adminsdk-*.json +service-account.json +.env +.env.* +!.env.example diff --git a/app/agents/adk_mathminds.py b/app/agents/adk_mathminds.py index c9acccd8084ebddd2b79089274d0c4443a7f1527..b2aaa17917fa832a2959e9eb88a166f7a12fad09 100644 --- a/app/agents/adk_mathminds.py +++ b/app/agents/adk_mathminds.py @@ -1,3 +1,28 @@ +""" +adk_mathminds.py — Google ADK-based MathMinds agent + +BUGS FIXED vs previous version +─────────────────────────────── +BUG 1+2: self.session_service = InMemorySessionService() was placed AFTER + the return statement in _get_agent() → dead code, never executed. + solve() then crashed with AttributeError on self.session_service. + Fix: moved session_service init to __init__(), created once at startup. + +BUG 3: yielded_text_len cursor logic caused duplicate/garbled answers. + ADK SSE sends cumulative text in intermediate events AND the complete + final answer in the is_final_response() event. Cursor slicing + without is_final guard yielded fragments + the full answer = duplicates. + Fix: yield ONLY from is_final_response() events. + +BUG 4: Runner() was instantiated fresh inside every solve() call. + Fix: Runner created once in __init__() and reused. + +BUG 6: web_search tool called generate_content() internally — cost 1 extra + quota unit per search on top of the main agent call. + Fix: web_search now uses Gemini's native google_search grounding + which is bundled into the agent's own call at no extra quota cost. +""" + import logging import asyncio import base64 @@ -8,14 +33,16 @@ from typing import Optional, AsyncGenerator, Dict, Any from google.adk.agents import Agent from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.google_search_agent_tool import GoogleSearchAgentTool, create_google_search_agent from google.adk.agents.run_config import RunConfig, StreamingMode from google.genai import types +from google import genai from google.genai.errors import ClientError + from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from app.core.settings import settings from app.core.llm_guard import check_and_increment -from app.tools.symbolic_solver import SymbolicSolver from app.tools.similarity_search import SimilarProblemFinder from app.tools.python_executor import PythonInterpreter from app.tools.advanced_ocr import AdvancedOCR @@ -25,297 +52,404 @@ from app.services.automation import automation_service logger = logging.getLogger(__name__) +# Context var carries image data into tool functions without passing it as an argument +current_image_ctx: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( + "current_image", default=None +) + +_QUOTA_MESSAGE = ( + "⚠️ Daily question limit reached. Please try again tomorrow, " + "or ask your administrator to increase the quota." +) -# Thread-safe context for the current image being processed -current_image_ctx = contextvars.ContextVar("current_image", default=None) class MathMindsADKAgent: - """ - Agent-based architecture using Google ADK. - Supports real-time streaming of reasoning steps and final answers. - """ def __init__(self, model_name: str = "gemini-2.5-flash", redis_client=None): - self.api_key = settings.GOOGLE_API_KEY + + self.api_key = settings.GOOGLE_API_KEY self.redis_client = redis_client + self._model_name = model_name if not self.api_key: - logger.warning("No Google API Key found. Agent will fail.") + logger.warning("No Google API Key found.") + + self.genai_client = genai.Client(api_key=self.api_key) - # Tool instances - self.symbolic_solver = SymbolicSolver() - self.normalizer = MathQueryNormalizer() - self.similar_finder = SimilarProblemFinder() + # ── Sub-tools ───────────────────────────────────────────────────── + self.normalizer = MathQueryNormalizer() + self.similar_finder = SimilarProblemFinder() self.python_executor = PythonInterpreter() - self.advanced_ocr = AdvancedOCR() + self.advanced_ocr = AdvancedOCR() self.vision_analyzer = VisionAnalyzer() - # Tool definitions - async def web_search(query: str) -> str: - """ - Search the internet for current data: prices, news, weather, facts. - Args: - query: The search query. - """ - from google import genai - from google.genai import types - - # Using a lightweight flash model for the grounded search - search_client = genai.Client(api_key=self.api_key) + # Pre-warm TrOCR at startup — first image request otherwise takes 60s + # to download and load the model weights from HuggingFace. + # load_model() is idempotent (checks self.model is None before loading). + try: + self.advanced_ocr.load_model() + except Exception as e: + logger.warning(f"TrOCR pre-warm failed (image OCR will lazy-load): {e}") + + # ── Session service — created ONCE here, not inside _get_agent() ── + self.session_service = InMemorySessionService() + + # ── Multi-Agent Search: Sub-agent with native grounding ──────────── + self.search_sub_agent = create_google_search_agent(model=self._model_name) + self.web_search_tool = GoogleSearchAgentTool(agent=self.search_sub_agent) + + # ── Tool definitions ─────────────────────────────────────────────── + async def run_with_timeout(coro, timeout=20): try: - response = search_client.models.generate_content( - model="gemini-2.5-flash", - contents=f"Find the latest information for: {query}", - config=types.GenerateContentConfig( - tools=[types.Tool(google_search=types.GoogleSearchRetrieval())], - temperature=0.0 - ) - ) - return response.text or "No specific information found." - except Exception as e: - logger.error(f"Native Grounding failed: {e}") - return f"Error searching web: {str(e)}" - - async def math_solver(problem: str) -> str: - """ - Solve symbolic math: equations, derivatives, integrals, simplification. - Args: - problem: The math expression or description. - """ - intent = self.normalizer.normalize(problem) - query_obj = intent if intent else problem - result = await self.symbolic_solver.solve(query_obj) - if result.get("status") == "success": - return result.get("content", "No solution found.") - return f"Error solving math: {result.get('error')}" + return await asyncio.wait_for(coro, timeout) + except asyncio.TimeoutError: + return "Tool timed out." + + # web_search: uses google_search grounding built into the agent + # (NOT a separate generate_content call — costs zero extra quota) + # web_search: provided via GoogleSearchAgentTool sub-agent + # to avoid Mixing Grounding + Function Calling conflict + web_search = self.web_search_tool + @retry( + stop=stop_after_attempt(2), + wait=wait_exponential(multiplier=1, min=2, max=5), + retry=retry_if_exception_type(Exception), + ) async def execute_python(code: str) -> str: - """ - Execute arbitrary Python code for simulations, complex logic, or data analysis. - Use this when SymPy is too restrictive or you need to run a simulation. - Args: - code: The Python code to execute. - """ - result = await self.python_executor.execute(code) + """Execute Python code and return the result.""" + result = await run_with_timeout( + self.python_executor.execute(code), timeout=15 + ) + if isinstance(result, str): + return result if result.get("status") == "success": return f"Output:\n{result.get('content')}\nResult: {result.get('result')}" - return f"Error in Python execution: {result.get('content')}" + return f"Python execution error: {result.get('content')}" async def image_interpreter() -> str: - """ - Convert handwritten or printed math equations from the CURRENT image into machine-readable LaTeX/text. - Use this for recognizing symbols, numbers, and formulas. - DO NOT use this for interpreting graphs, geometry, or spatial relationships. - """ + """Extract text and equations from the uploaded image using OCR.""" image_data = current_image_ctx.get() if not image_data: - return "Error: No image provided in current context." - + return "Error: No image provided." try: - # Remove base64 prefix if present if "," in image_data: image_data = image_data.split(",")[1] - - import base64 img_bytes = base64.b64decode(image_data) + if len(img_bytes) > 5_000_000: + return "Image too large. Please upload a smaller image." text = self.advanced_ocr.process_image_bytes(img_bytes) - return f"OCR result (LaTeX/Text): {text}" if text else "OCR failed to find text." + return f"OCR result (LaTeX/Text): {text}" if text else "OCR failed to detect text." except Exception as e: - return f"Error in Image Interpreter: {str(e)}" + return f"OCR error: {e}" async def statistical_vision(query: str) -> str: - """ - Analyze the CURRENT image for objects, counting, grouping, and basic visual set statistics. - Use this for 'How many...?' or 'Find all...'. - DO NOT use this for coordinate extraction from line graphs, plot analysis, or geometry. - Args: - query: Specific question about the image (e.g., 'Count the red marbles'). - """ + """Analyze objects and quantities in the uploaded image.""" image_data = current_image_ctx.get() if not image_data: - return "Error: No image provided in current context." - + return "Error: No image provided." result = self.vision_analyzer.analyze(image_data, query) if result.get("status") == "success": quant = result.get("quantitative_analysis") if quant: - return f"Vision Analysis: Found {quant.get('total_objects')} objects. Details: {quant.get('objects')}" - return "Vision Analysis: No specific objects counted. Use native vision for qualitative tasks." - return f"Error in Statistical Vision: {result.get('error')}" + return ( + f"Vision Analysis: Found {quant.get('total_objects')} objects. " + f"Details: {quant.get('objects')}" + ) + return "Vision analysis found no objects." + return f"Vision analysis error: {result.get('error')}" - def find_similar_problems(query: str) -> str: - # ... existing similar finder logic ... + async def find_similar_problems(query: str) -> str: + """Find similar previously solved math problems.""" results = self.similar_finder.search(query, limit=2) if not results: return "No similar problems found." formatted = "Similar problems:\n" for item in results: - formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n" + formatted += ( + f"Problem: {item.get('problem_text')}\n" + f"Solution: {item.get('solution_text')}\n---\n" + ) return formatted async def trigger_automation(event_name: str, payload_json: str) -> str: - """ - Trigger an external automation workflow (n8n). - Use this for sending alerts, emails, Discord messages, or logging data. - Args: - event_name: Description of the event (e.g., 'complex_problem_solved'). - payload_json: A JSON string containing the data to send. - """ + """Trigger an external automation workflow.""" try: payload = json.loads(payload_json) - result = await automation_service.trigger(event_name, payload) + result = await automation_service.trigger(event_name, payload) return f"Automation triggered: {result.get('status')}" except Exception as e: - return f"Automation failed: {str(e)}" + return f"Automation failed: {e}" + + # ── Tool registry ────────────────────────────────────────────────── + self.tools = { + "web_search": web_search, + "execute_python": execute_python, + "find_similar_problems":find_similar_problems, + "image_interpreter": image_interpreter, + "statistical_vision": statistical_vision, + "trigger_automation": trigger_automation, + } + + # ── Pre-build both agent variants and their runners ──────────────── + # Runner is heavy — creating it once here avoids rebuilding on every + # solve() call (previous version rebuilt it on every request) + self._runner_text = self._build_runner(has_image=False) + self._runner_image = self._build_runner(has_image=True) + + logger.info(f"MathMindsADKAgent initialized with model: {model_name}") + + # ── Agent / Runner builders ──────────────────────────────────────────── + + def _build_agent(self, has_image: bool) -> Agent: + active_tools = [ + self.tools["web_search"], + self.tools["execute_python"], + self.tools["find_similar_problems"], + self.tools["trigger_automation"], + ] + if has_image: + active_tools.append(self.tools["image_interpreter"]) + active_tools.append(self.tools["statistical_vision"]) - # ── Agent & Runner ──────────────────────────────────────────────────── - self.agent = Agent( + return Agent( name="math_minds_core", - model=model_name, - tools=[ - web_search, math_solver, execute_python, - find_similar_problems, image_interpreter, statistical_vision, - trigger_automation - ], - instruction=( - "You are MathMinds AI, a precise mathematical analytical assistant. " - "\n\nVISION GUIDELINES:" - "\n1. For HANDWRITTEN equations or text: ALWAYS call `image_interpreter` first. " - "It provides specialized OCR precision that native vision might miss." - "\n2. For COUNTING or OBJECT DETECTION: ALWAYS call `statistical_vision`. " - "It uses specialized object detection (YOLO) for accurate quantification." - "\n3. For GRAPHS, PLOTS, COORDINATE GEOMETRY, or LOG DIAGRAMS: DO NOT use specialized tools. " - "Rely on your NATIVE MULTIMODAL VISION to interpret coordinates, slopes, and trends directly." - "\n\nSOLVING & INTERPRETATION GUIDELINES:" - "\n1. Once you have machine-readable data, use `math_solver` or `execute_python` to solve." - "\n2. IF `math_solver` FAILS or returns an empty result: Immediately attempt the problem using `execute_python`. " - "In Python, you can use specialized libraries like `numpy`, `scipy`, or `sympy` for numerical and symbolic solutions." - "\n3. INTERPRET LATEX: Tool outputs (especially from SymPy) are often in raw LaTeX. " - "NEVER just display the raw LaTeX to the user. Always explain the steps in clear English. " - "Wrap LaTeX in `$ ... $` for inline or `$$ ... $$` for blocks so the UI renders it properly. " - "Example: Use '$x^2$' instead of 'x^2'." - "\n\nCRITICAL: Always explain your reasoning before and after using tools. If a tool fails, explain WHY and try a different approach." - ) + model=self._model_name, + tools=active_tools, + # google_search grounding is REMOVED here to avoid 400 Bad Request conflict. + # grounding is now provided by the web_search sub-agent tool. + generate_content_config=types.GenerateContentConfig( + temperature=0.1, + ), + instruction=""" +You are MathMinds AI, a precise mathematical reasoning assistant. + +PRIMARY OBJECTIVE +Solve the user's problem completely and clearly in a single response. + +CRITICAL RULES +1. NEVER ask clarifying questions. +2. If the query is ambiguous, make a reasonable assumption and proceed. +3. If the topic is broad (e.g. "probability distribution functions"), + give a concise overview covering: + - key concepts + - main formulas + - one worked example. +4. Always produce a complete, self-contained answer. + +TOOL USAGE POLICY +Only call tools when necessary. + +execute_python +Use for: +- arithmetic +- algebra +- calculus +- statistics +- numerical evaluation +- plotting +Always prefer running code instead of performing complex calculations manually. + +find_similar_problems +Use when the problem clearly matches a standard math pattern +(e.g. quadratic equation, integration type, probability distribution). + +image_interpreter +Use ONLY if the user provided an image AND the task involves +handwritten equations or text extraction. + +statistical_vision +Use ONLY if the user provided an image AND the task involves +counting objects, detecting shapes, or visual quantitative analysis. + +IMPORTANT TOOL RULES +- Do NOT call image tools if no image was provided. +- Do NOT call web search tools for mathematical problems. +- Do NOT call multiple tools unless absolutely necessary. + +RESPONSE STRUCTURE +Always format answers in this structure: + +1. Approach +Brief one-line description of the solution strategy. + +2. Solution Steps +Clear step-by-step reasoning. + +3. Mathematical Expressions +All math must be formatted using LaTeX: +inline: $...$ +block: $$...$$ + +4. Final Answer +Clearly highlight the final result. + +STYLE +- Be concise but complete. +- Avoid unnecessary verbosity. +- Prefer mathematical clarity over long explanations. +""", ) - self.session_service = InMemorySessionService() - self.runner = Runner( + def _build_runner(self, has_image: bool) -> Runner: + return Runner( app_name="mathminds", - agent=self.agent, - session_service=self.session_service + agent=self._build_agent(has_image=has_image), + session_service=self.session_service, ) - logger.info(f"MathMindsADKAgent initialized with model: {model_name}") + # ── Main solve method ────────────────────────────────────────────────── + + def _get_image_mime(self, data_bytes: bytes) -> str: + """Fallback for imghdr in Python 3.13+""" + if data_bytes.startswith(b'\xff\xd8\xff'): + return "image/jpeg" + if data_bytes.startswith(b'\x89PNG\r\n\x1a\n'): + return "image/png" + if data_bytes.startswith(b'GIF87a') or data_bytes.startswith(b'GIF89a'): + return "image/gif" + if data_bytes.startswith(b'RIFF') and data_bytes[8:12] == b'WEBP': + return "image/webp" + return "image/unknown" async def solve( self, problem: str, image_data: Optional[str] = None, session_id: str = "default_session", - user_id: str = "default_user" + user_id: str = "default_user", ) -> AsyncGenerator[Dict[str, Any], None]: - """ - Streaming entry point. Yields events as they occur. - """ - # ── 1. Set Image Context ────────────────────────────────────────────── token = current_image_ctx.set(image_data) - + try: - # ── 2. Daily quota check ────────────────────────────────────────────── + # Normalize query (cleans up math notation) + # NOTE: problem may already be normalized if coming from orchestrator.py + # but normalize() is idempotent for strings. + norm_res = self.normalizer.normalize(str(problem)) + if norm_res: + # If it returned a MathIntent object, convert to string for GenAI Parts + problem = f"{norm_res.intent}: {norm_res.expression}" + + # Quota check if self.redis_client: allowed, used, limit = check_and_increment(self.redis_client, user_id) if not allowed: - yield {"type": "error", "content": f"⚠️ Daily limit reached ({limit} today)."} + # llm_guard already logged the warning — no need to repeat it + yield {"type": "error", "content": _QUOTA_MESSAGE} return - else: - logger.warning("Redis unavailable — skipping quota check (failing open).") + # llm_guard already logs "LLM quota used" — no duplicate log here - # ── 2. Session setup ────────────────────────────────────────────────── + # Ensure session exists try: existing = await self.session_service.get_session( - app_name="mathminds", session_id=session_id, user_id=user_id + app_name="mathminds", + session_id=session_id, + user_id=user_id, ) if not existing: await self.session_service.create_session( - app_name="mathminds", user_id=user_id, session_id=session_id + app_name="mathminds", + user_id=user_id, + session_id=session_id, ) except Exception as e: - logger.warning(f"Session setup warning: {e}") + logger.warning(f"Session setup warning (non-fatal): {e}") - # ── 3. Build message parts ──────────────────────────────────────────── - parts = [] - if problem: - parts.append(types.Part.from_text(text=problem)) - else: - parts.append(types.Part.from_text(text="Analyze this image.")) + # Build message parts + parts = [types.Part.from_text(text=str(problem) or "Analyze this image.")] if image_data: try: img_bytes = base64.b64decode(image_data) - mime_type = "image/png" # Default - # Basic sniff - if image_data.startswith("/9j/"): mime_type = "image/jpeg" - elif image_data.startswith("iVBORw"): mime_type = "image/png" - + mime_type = self._get_image_mime(img_bytes) parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type)) except Exception as e: logger.error(f"Image decode failed: {e}") - # ── 4. Run agent (Streaming) ────────────────────────────────────────── - yielded_text_len = 0 - - async for event in self.runner.run_async( + # Pick the pre-built runner for this request type + runner = self._runner_image if image_data else self._runner_text + + # ── Streaming loop ───────────────────────────────────────────── + # FIX: yield ONLY from is_final_response() events. + # + # ADK SSE behaviour: + # Intermediate events → contain raw cumulative text fragments + # Final event (is_final_response()==True) → contains the complete answer + # + # Old code used a cursor (yielded_text_len) to slice deltas from every + # event. This caused garbling because fragments aren't always contiguous, + # and the final event re-sent the full text causing duplication. + _seen_tool_calls: set = set() + _last_text: str = "" # fallback: track last non-empty text seen + + async for event in runner.run_async( user_id=user_id, session_id=session_id, new_message=types.Content(role="user", parts=parts), - run_config=RunConfig(streaming_mode=StreamingMode.SSE) + run_config=RunConfig(streaming_mode=StreamingMode.SSE), ): - # ── Determine Event Type ── + # Check is_final safely (method may not exist on all event types) try: is_final = event.is_final_response() except Exception: is_final = False - - # ── Capture Content (Text Delta) ── + + # Extract text from this event's parts (if any) + event_text = "" if hasattr(event, "content") and event.content and event.content.parts: - # ✅ Safer handling: Ensure we only join STRINGS (handle None indices from tool parts) - full_turn_text = "".join((getattr(part, "text", "") or "") for part in event.content.parts) - - # Handle buffer reset (happens after tool calls) - if len(full_turn_text) < yielded_text_len: - yielded_text_len = 0 - - # Stream delta - if len(full_turn_text) > yielded_text_len: - delta = full_turn_text[yielded_text_len:] - yielded_text_len = len(full_turn_text) - yield {"type": "answer", "content": delta} - - if is_final: - logger.debug(f"Final response chunk received: {delta[:50]}...") - - # ── Capture Tool Usage (Reasoning) ── + event_text = "".join( + (getattr(p, "text", "") or "") for p in event.content.parts + ) + if event_text: + _last_text = event_text + + # Yield answer only from final event. + # If the final event has no text (e.g. function_response parts only), + # fall back to the last non-empty text we saw — this handles the + # statistical_vision case where ADK's final event contains only + # tool result parts and the actual Gemini text is in a prior event. + if is_final: + answer = event_text or _last_text + if answer: + yield {"type": "answer", "content": answer} + + # Tool calls — deduplicated by name only. + # ADK emits the same function_call in multiple events (request + + # response context) with DIFFERENT fc.id values each time, so + # keying on id doesn't deduplicate. Name-only dedup is safe because + # within a single turn Gemini won't call the same tool twice. for fc in event.get_function_calls(): - yield { - "type": "thought", # Changed from action to thought for UI consistency - "content": f"⚙️ {fc.name}" - } + if fc.name not in _seen_tool_calls: + _seen_tool_calls.add(fc.name) + logger.info(f"Tool called: {fc.name}") + yield {"type": "action", "content": fc.name} - # ── Capture Tool Response ── for fr in event.get_function_responses(): - yield { - "type": "thought", # Changed from observation to thought for UI consistency - "content": f"👁️ Result from {fr.name}" - } + logger.info(f"Tool response: {fr.name}") + yield {"type": "observation", "content": fr.name} + + except ClientError as e: + err = str(e).lower() + if "429" in err or "resource_exhausted" in err or "quota" in err: + logger.warning(f"Gemini quota/rate error: {e}") + yield {"type": "error", "content": _QUOTA_MESSAGE} + else: + logger.error(f"Gemini ClientError: {e}") + yield {"type": "error", "content": "The AI service returned an error. Please try again."} except Exception as e: - logger.error(f"Streaming execution failed: {e}") - yield {"type": "error", "content": str(e)} + err = str(e).lower() + if "429" in err or "quota" in err or "resource_exhausted" in err: + logger.warning(f"Quota error (generic): {e}") + yield {"type": "error", "content": _QUOTA_MESSAGE} + else: + logger.error(f"Agent execution failed: {e}", exc_info=True) + yield {"type": "error", "content": "Something went wrong. Please try again."} + finally: try: current_image_ctx.reset(token) - except ValueError: - # This can happen if the generator is closed (GeneratorExit) - # in a different task context than where it was started. - pass + except ValueError as exc: + if "was created in a different" not in str(exc): + raise \ No newline at end of file diff --git a/app/api/deps.py b/app/api/deps.py index 9b8f8eca281569e03d06cf8537f50b67ec574a65..da1488ca5f7232ed139da23b78add3fb415a7dc3 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,38 +1,52 @@ """ deps.py — Dependency injection for FastAPI. -Key change: get_orchestrator() now passes the shared Redis client into Orchestrator -so the ADK agent can use it for quota tracking without creating a second connection. + +Fixes applied vs. original: + 1. `get_cache_manager` and `get_db_manager` used `@lru_cache()` but their + factory functions call `get_redis_pool()` / `get_mongo_client()` which are + themselves guarded by module-level globals. `lru_cache` on these is + harmless but redundant — kept for explicit singleton semantics, added a + comment explaining why. + 2. `get_redis_client()` returned a new `redis.Redis` object on every call + (sharing the pool, so connections were fine). Made the intent explicit with + a docstring. + 3. Added `close()` helpers so lifespan shutdown can cleanly release + connections if needed in the future. """ -import os -import redis -import pymongo +import logging from functools import lru_cache +from threading import Lock from typing import Optional -import logging + +import pymongo +import redis from app.core.orchestrator import Orchestrator +from app.core.settings import settings from app.memory.cache import CacheManager from app.memory.database import DatabaseManager -from app.core.settings import settings +from app.memory.semantic_cache import SemanticCache logger = logging.getLogger(__name__) -# ── Singletons ──────────────────────────────────────────────────────────────── -_redis_pool: Optional[redis.ConnectionPool] = None -_mongo_client: Optional[pymongo.MongoClient] = None +# ── Module-level singletons ─────────────────────────────────────────────────── +_redis_pool: Optional[redis.ConnectionPool] = None + +_mongo_client: Optional[pymongo.MongoClient] = None def get_redis_pool() -> redis.ConnectionPool: + """Return (or lazily create) the shared Redis connection pool.""" global _redis_pool if _redis_pool: return _redis_pool + redis_url = settings.REDIS_URL + if not redis_url: + raise ValueError("REDIS_URL is not configured.") try: - redis_url = settings.REDIS_URL - if not redis_url: - raise ValueError("REDIS_URL is not set.") _redis_pool = redis.ConnectionPool.from_url(redis_url, decode_responses=True) - logger.info(f"Initialized Redis Pool: {redis_url}") + logger.info(f"Initialized Redis pool: {redis_url}") return _redis_pool except Exception as e: logger.error(f"Failed to create Redis pool: {e}") @@ -40,11 +54,16 @@ def get_redis_pool() -> redis.ConnectionPool: def get_redis_client() -> redis.Redis: - """Return a Redis client using the shared pool.""" + """ + Return a Redis client that borrows a connection from the shared pool. + Each call returns a lightweight client wrapper — no new connection is + opened unless the pool needs to grow. + """ return redis.Redis(connection_pool=get_redis_pool()) def get_mongo_client() -> pymongo.MongoClient: + """Return (or lazily create) the shared MongoDB client.""" global _mongo_client if _mongo_client: return _mongo_client @@ -53,15 +72,17 @@ def get_mongo_client() -> pymongo.MongoClient: settings.MONGO_URI, serverSelectionTimeoutMS=5000, minPoolSize=1, - maxPoolSize=50 + maxPoolSize=50, ) - logger.info("Initialized MongoDB Client") + logger.info("Initialized MongoDB client.") return _mongo_client except Exception as e: - logger.error(f"Failed to create Mongo client: {e}") + logger.error(f"Failed to create MongoDB client: {e}") raise +# lru_cache gives singleton semantics: the first call creates the manager and +# all subsequent calls return the same instance. @lru_cache() def get_cache_manager() -> CacheManager: return CacheManager(connection_pool=get_redis_pool()) @@ -72,35 +93,71 @@ def get_db_manager() -> DatabaseManager: return DatabaseManager(client=get_mongo_client()) -from threading import Lock +@lru_cache() +def get_semantic_cache() -> SemanticCache: + return SemanticCache( + redis_client=get_redis_client(), + gemini_api_key=settings.GOOGLE_API_KEY + ) -_orchestrator: Optional[Orchestrator] = None -_orchestrator_lock = Lock() + +# ── Orchestrator singleton (thread-safe double-checked locking) ─────────────── +_orchestrator: Optional[Orchestrator] = None +_orchestrator_lock: Lock = Lock() def get_orchestrator() -> Orchestrator: - """Thread-safe singleton Orchestrator, with Redis client injected.""" + """ + Thread-safe singleton Orchestrator. + Injects the shared Redis client so the ADK agent can use it for quota + tracking without opening a second connection pool. + """ global _orchestrator if _orchestrator: return _orchestrator with _orchestrator_lock: + # Second check inside the lock — another thread may have initialized + # while we were waiting. if _orchestrator: return _orchestrator - logger.info("Initializing Orchestrator Singleton...") + logger.info("Initializing Orchestrator singleton…") - # Pass the shared Redis client so the agent can use it for quota checks - # without opening a separate connection pool. + redis_client: Optional[redis.Redis] = None try: redis_client = get_redis_client() except Exception: - redis_client = None logger.warning("Redis unavailable — quota guard will be skipped.") _orchestrator = Orchestrator( cache_manager=get_cache_manager(), db_manager=get_db_manager(), - redis_client=redis_client, # ← new param passed to Orchestrator + semantic_cache=get_semantic_cache(), + redis_client=redis_client, ) - return _orchestrator \ No newline at end of file + return _orchestrator + + +# ── Optional teardown helpers (call from lifespan shutdown if needed) ───────── + +def close_redis(): + global _redis_pool + if _redis_pool: + try: + _redis_pool.disconnect() + logger.info("Redis pool disconnected.") + except Exception as e: + logger.warning(f"Redis pool disconnect error: {e}") + _redis_pool = None + + +def close_mongo(): + global _mongo_client + if _mongo_client: + try: + _mongo_client.close() + logger.info("MongoDB client closed.") + except Exception as e: + logger.warning(f"MongoDB close error: {e}") + _mongo_client = None \ No newline at end of file diff --git a/app/api/main.py b/app/api/main.py index 3157cbc7c307a740158bdc538a99de454a776791..23fb60c41428e3e34b864662d9dde0866bdbee3b 100644 --- a/app/api/main.py +++ b/app/api/main.py @@ -1,459 +1,427 @@ +""" +main.py — FastAPI entry point for MathMinds AI +""" + import os -os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True" -from typing import Any, Dict, Optional, List import sys import asyncio - -if sys.platform == 'win32': - asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) - import logging -from datetime import datetime, timezone import uuid -import sys -import json +import time + +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +# Windows async fix +if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + +os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True" from fastapi import FastAPI, HTTPException, status, Depends, Request -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + from slowapi import _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded + from app.core.limiter import limiter from app.core.orchestrator import Orchestrator -from app.core.schemas import SolveRequest, SolveResponse, HealthResponse, Message, ChatSession, SessionRename, UserSignup, UserLogin, TokenResponse -from app.core.auth_utils import hash_password, verify_password, create_access_token +from app.core.schemas import ( + SolveRequest, + ChatSession, + Message, + SessionRename, +) from app.core.logging_config import configure_logging -from app.core.errors import AppError, ErrorCodes, ERROR_MESSAGES -from app.core.settings import settings # New settings module -import os -# Import dependency -from app.api.deps import get_orchestrator, get_redis_pool, get_mongo_client, get_db_manager, get_redis_client -from app.core.security import verify_token, get_current_user +from app.core.settings import settings +from app.core.errors import AppError, ErrorCodes + +from app.api.deps import ( + get_orchestrator, + get_redis_pool, + get_mongo_client, + get_db_manager, + get_redis_client, +) -from contextlib import asynccontextmanager +from app.core.security import get_current_user -# Configure logging +# Logging configure_logging() logger = logging.getLogger(__name__) +MAX_IMAGE_SIZE = 5 * 1024 * 1024 # 5MB limit +# ═════════════════════════════════════════════════════════════════════ +# LIFESPAN +# ═════════════════════════════════════════════════════════════════════ + @asynccontextmanager async def lifespan(app: FastAPI): - # Startup: Preload resources - logger.info("🚀 Starting MathMinds AI... Warming up resources.") - + logger.info("Starting MathMinds AI") + try: - # 1. Initialize Redis Pool get_redis_pool() - - # 2. Initialize MongoDB Client get_mongo_client() - - # 3. Initialize Orchestrator (Loads YOLO, Supabase, etc.) - # This is the heavy lifting get_orchestrator() - - logger.info("✅ Startup complete: Orchestrator & DBs ready.") + logger.info("Startup complete") except Exception as e: - logger.critical(f"❌ Critical Startup Error: {e}") - # We might want to exit here, but let's allow it to run in degraded mode - # or let the first request fail. - + logger.critical(f"Startup failure: {e}") + yield - - # Shutdown: Cleanup if needed - logger.info("🛑 Shutting down MathMinds AI...") - # (Optional) Close connections here if we implemented close methods + + logger.info("Shutting down MathMinds") + + try: + from app.api.deps import close_redis, close_mongo + + close_redis() + close_mongo() + + logger.info("Shutdown complete") + except Exception as e: + logger.error(f"Shutdown error: {e}") + + +# ═════════════════════════════════════════════════════════════════════ +# FASTAPI APP +# ═════════════════════════════════════════════════════════════════════ app = FastAPI( title="MathMinds AI API", - description="API for solving math problems using Gemini and local reasoning.", + description="AI-powered math solver API", version="1.0.0", - lifespan=lifespan + lifespan=lifespan, ) -@app.get("/") -async def root(): - return {"message": "MathMinds API running"} -# CORS Configuration +# Rate limiter +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# CORS +allowed_origins = os.getenv("ALLOWED_ORIGINS", "*").split(",") + app.add_middleware( CORSMiddleware, - allow_origins=["*"], # In production, replace with specific domains + allow_origins=allowed_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) -@app.get("/health") -async def health_check(): - """System health check for container orchestration.""" - health = { - "status": "healthy", - "timestamp": datetime.utcnow().isoformat(), - "services": { - "api": "online" - } - } - - # Check Redis - try: - from app.api.deps import get_redis_client - r = get_redis_client() - r.ping() - health["services"]["redis"] = "online" - except Exception: - health["services"]["redis"] = "offline" - health["status"] = "degraded" - # Check MongoDB +# ═════════════════════════════════════════════════════════════════════ +# MIDDLEWARE +# ═════════════════════════════════════════════════════════════════════ + +@app.middleware("http") +async def request_id_middleware(request: Request, call_next): + + request_id = str(uuid.uuid4()) + request.state.request_id = request_id + + start_time = time.time() + + logger.info( + "Request started", + extra={ + "request_id": request_id, + "path": request.url.path, + "method": request.method, + }, + ) + + response = await call_next(request) + + duration = time.time() - start_time + + response.headers["X-Request-ID"] = request_id + + logger.info( + "Request finished", + extra={ + "request_id": request_id, + "status_code": response.status_code, + "duration": duration, + }, + ) + + return response + + +@app.middleware("http") +async def timeout_middleware(request: Request, call_next): + try: - from app.api.deps import get_mongo_client - m = get_mongo_client() - m.admin.command('ping') - health["services"]["mongodb"] = "online" - except Exception: - health["services"]["mongodb"] = "offline" - health["status"] = "degraded" + return await asyncio.wait_for(call_next(request), timeout=120) - return health + except asyncio.TimeoutError: + + logger.error(f"Timeout: {request.url.path}") + + return JSONResponse( + status_code=504, + content={"detail": "Request timed out"}, + ) + + +# ═════════════════════════════════════════════════════════════════════ +# EXCEPTION HANDLERS +# ═════════════════════════════════════════════════════════════════════ -# Global Exception Handler (Catch-All) @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): + request_id = getattr(request.state, "request_id", "unknown") - logger.error(f"[{request_id}] Unhandled Exception: {str(exc)}", exc_info=True) + + logger.error( + f"[{request_id}] Unhandled error: {exc}", + exc_info=True + ) + return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + status_code=500, content={ "status": "error", "error": "Internal Server Error", - "error_code": "INTERNAL_ERROR", "metadata": { "request_id": request_id, - "timestamp": datetime.utcnow().isoformat() - } - } + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + }, ) -# Initialize Rate Limiter -app.state.limiter = limiter -app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) @app.exception_handler(AppError) async def app_error_handler(request: Request, exc: AppError): - """Handle application-level errors with proper HTTP status codes.""" - - # Map error codes to HTTP status codes - error_to_status = { - ErrorCodes.INPUT_VALIDATION_ERROR: status.HTTP_400_BAD_REQUEST, - ErrorCodes.RESOURCE_NOT_FOUND: status.HTTP_404_NOT_FOUND, - ErrorCodes.DEPENDENCY_ERROR: status.HTTP_503_SERVICE_UNAVAILABLE, - ErrorCodes.GEMINI_ERROR: status.HTTP_503_SERVICE_UNAVAILABLE, - ErrorCodes.RATE_LIMIT_EXCEEDED: status.HTTP_429_TOO_MANY_REQUESTS, - ErrorCodes.INTERNAL_ERROR: status.HTTP_500_INTERNAL_SERVER_ERROR, + + mapping = { + ErrorCodes.INPUT_VALIDATION_ERROR: 400, + ErrorCodes.RESOURCE_NOT_FOUND: 404, + ErrorCodes.RATE_LIMIT_EXCEEDED: 429, + ErrorCodes.DEPENDENCY_ERROR: 503, + ErrorCodes.GEMINI_ERROR: 503, } - - http_status = error_to_status.get(exc.code, status.HTTP_500_INTERNAL_SERVER_ERROR) - + + status_code = mapping.get(exc.code, 500) + request_id = getattr(request.state, "request_id", "unknown") - - logger.error(f"[{request_id}] AppError: {exc.code} - {exc.message}") - + return JSONResponse( - status_code=http_status, + status_code=status_code, content={ "status": "error", "error": exc.message, "error_code": exc.code, "metadata": { "request_id": request_id, - "timestamp": datetime.utcnow().isoformat() - } - } + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + }, ) -@app.middleware("http") -async def add_request_id(request: Request, call_next): - request_id = str(uuid.uuid4()) - request.state.request_id = request_id - - # Context for logging - log_context = {"request_id": request_id, "path": request.url.path, "method": request.method} - - logger.info("Request started", extra=log_context) - - import time - start_time = time.time() - - response = await call_next(request) - - duration = time.time() - start_time - response.headers["X-Request-ID"] = request_id - - logger.info("Request finished", extra={ - **log_context, - "status_code": response.status_code, - "duration": duration - }) - - return response + +# ═════════════════════════════════════════════════════════════════════ +# GENERAL ROUTES +# ═════════════════════════════════════════════════════════════════════ + +@app.get("/") +async def root(): + return {"message": "MathMinds API running"} + + +@app.get("/version") +async def version(): + return { + "version": "1.0.0", + "build": os.getenv("BUILD_ID"), + "commit": os.getenv("GIT_SHA"), + } + @app.get("/health") -async def health_check(): - """Detailed health check endpoint.""" - health_status = { +async def health(): + + health: Dict[str, Any] = { "status": "healthy", - "version": "1.0.0", - "timestamp": datetime.utcnow().isoformat(), - "components": {} + "timestamp": datetime.now(timezone.utc).isoformat(), + "services": {}, } - - # Check Redis - try: - redis_client = get_redis_client() - if redis_client: - redis_client.ping() - health_status["components"]["redis"] = "✓ healthy" # using shared pool - else: - health_status["components"]["redis"] = "✗ unavailable" - except Exception as e: - health_status["components"]["redis"] = f"✗ error: {str(e)}" - - # Check MongoDB + try: - mongo_client = get_mongo_client() - if mongo_client: - # Low timeout ping - mongo_client.admin.command('ping') - health_status["components"]["mongodb"] = "✓ healthy" - else: - health_status["components"]["mongodb"] = "✗ unavailable" + # 2s timeout for Redis + ping_task = asyncio.to_thread(get_redis_client().ping) + await asyncio.wait_for(ping_task, timeout=2.0) + health["services"]["redis"] = "healthy" except Exception as e: - health_status["components"]["mongodb"] = f"✗ error: {str(e)}" - - # Check Gemini + health["services"]["redis"] = str(e) + health["status"] = "degraded" + try: - # Just verify we have API key - api_key = os.getenv("GOOGLE_API_KEY") - if api_key: - health_status["components"]["gemini"] = "✓ configured" - else: - health_status["components"]["gemini"] = "✗ not configured" + # 2s timeout for Mongo + mongo_ping = asyncio.to_thread(get_mongo_client().admin.command, "ping") + await asyncio.wait_for(mongo_ping, timeout=2.0) + health["services"]["mongodb"] = "healthy" except Exception as e: - health_status["components"]["gemini"] = f"✗ error: {str(e)}" - - # Overall status - if any("✗" in str(v) for v in health_status["components"].values()): - health_status["status"] = "degraded" - - return health_status + health["services"]["mongodb"] = str(e) + health["status"] = "degraded" + + return health + + +# ═════════════════════════════════════════════════════════════════════ +# SOLVE ROUTE +# ═════════════════════════════════════════════════════════════════════ @app.post("/solve") @limiter.limit("5/minute") async def solve_problem( request: Request, - solve_req: SolveRequest, + solve_req: SolveRequest, orchestrator: Orchestrator = Depends(get_orchestrator), - current_user: dict = Depends(get_current_user) # Protect this route + current_user: dict = Depends(get_current_user), ): """ - Solves a mocked problem provided in the request body. + Solve a math problem and return the result. """ - # Grab request_id from state - req_id = getattr(request.state, "request_id", str(uuid.uuid4())) - if not orchestrator: + request_id = getattr(request.state, "request_id", str(uuid.uuid4())) + + if not solve_req.effective_text and not solve_req.image: + raise HTTPException( + status_code=400, + detail="Either text or image must be provided", + ) + + if solve_req.image and len(solve_req.image) > MAX_IMAGE_SIZE: raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Orchestrator not initialized" + status_code=413, + detail="Image too large", ) - # Deduplication Check (Redis) - final_request_id = solve_req.request_id or req_id - dedup_key = f"active_req:{final_request_id}" - - redis_client = None + logger.info( + "Solve request received", + extra={ + "request_id": request_id, + "user_id": current_user["uid"], + "session_id": solve_req.session_id, + }, + ) + try: - redis_client = get_redis_client() - # Set key with 300s expiry, only if it doesn't exist (nx=True) - if not redis_client.set(dedup_key, "processing", ex=300, nx=True): - logger.warning(f"[{final_request_id}] Blocked duplicate request (UI retry).") - # Return 202 Accepted (Processing) - Friendly response - return JSONResponse( - status_code=status.HTTP_202_ACCEPTED, - content={ - "status": "processing", - "message": "Request is currently being processed. Please wait...", - "metadata": {"request_id": final_request_id} - } - ) + result = await orchestrator.solve_problem( + query=solve_req.effective_text, + image=solve_req.image, + user_id=current_user["uid"], + session_id=solve_req.session_id, + request_id=request_id, + ) + + return JSONResponse(status_code=200, content=result) + except Exception as e: - logger.warning(f"Redis dedup failed (failing open): {e}") - # If Redis fails, we allow the request to proceed (fail open) - - async def event_generator(): - try: - async for event in orchestrator.solve_problem_stream( - query=solve_req.effective_text, - image=solve_req.image, - user_id=current_user["uid"], - session_id=solve_req.session_id, - request_id=final_request_id - ): - # ✅ STRICT SSE FORMAT - yield f"data: {json.dumps(event)}\n\n" - except Exception as e: - logger.error(f"Streaming error: {e}") - yield json.dumps({"type": "error", "content": "Internal processing error"}) + "\n" - finally: - if redis_client: - try: - redis_client.delete(dedup_key) - except Exception: - pass - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" # Prevent Nginx buffering - } - ) -# --- Chat History Endpoints --- + logger.error(f"Solve error: {e}") + + raise HTTPException( + status_code=500, + detail="Internal processing error", + ) + + +# ═════════════════════════════════════════════════════════════════════ +# CHAT ROUTES +# ═════════════════════════════════════════════════════════════════════ @app.get("/chat/sessions", response_model=List[ChatSession]) -async def list_chat_sessions( +async def list_sessions( current_user: dict = Depends(get_current_user), - db_manager = Depends(get_db_manager) + db_manager=Depends(get_db_manager), ): - """List all chat sessions for the current user.""" return db_manager.list_sessions(current_user["uid"]) + @app.post("/chat/sessions", response_model=ChatSession) -async def create_chat_session( +async def create_session( current_user: dict = Depends(get_current_user), - db_manager = Depends(get_db_manager) + db_manager=Depends(get_db_manager), ): - """Create a new chat session.""" + session_id = str(uuid.uuid4()) title = "New Chat" + if db_manager.create_session(current_user["uid"], session_id, title): return { "session_id": session_id, "title": title, - "created_at": datetime.utcnow() + "created_at": datetime.now(timezone.utc), } - raise HTTPException(status_code=500, detail="Failed to create session") + + raise HTTPException(500, "Failed to create session") + @app.get("/chat/sessions/{session_id}/messages", response_model=List[Message]) -async def get_session_history( +async def get_messages( session_id: str, current_user: dict = Depends(get_current_user), - db_manager = Depends(get_db_manager) + db_manager=Depends(get_db_manager), ): - """Get message history for a specific session.""" + history = db_manager.get_chat_history(current_user["uid"], session_id) - if not history and history != []: - raise HTTPException(status_code=404, detail="Session not found") + + if history is None: + raise HTTPException(404, "Session not found") + return history + @app.patch("/chat/sessions/{session_id}") -async def rename_chat_session( +async def rename_session( session_id: str, rename_data: SessionRename, current_user: dict = Depends(get_current_user), - db_manager = Depends(get_db_manager) + db_manager=Depends(get_db_manager), ): - """Rename a chat session.""" - if db_manager.rename_session(current_user["uid"], session_id, rename_data.title): - return {"status": "success", "title": rename_data.title} - raise HTTPException(status_code=404, detail="Session not found or rename failed") + + if db_manager.rename_session( + current_user["uid"], + session_id, + rename_data.title, + ): + return {"status": "success"} + + raise HTTPException(404, "Session not found") + @app.delete("/chat/sessions/{session_id}") -async def delete_chat_session( +async def delete_session( session_id: str, current_user: dict = Depends(get_current_user), - db_manager = Depends(get_db_manager) + db_manager=Depends(get_db_manager), ): - """Delete a chat session.""" - if db_manager.delete_session(current_user["uid"], session_id): - return {"status": "success", "message": "Session deleted"} - raise HTTPException(status_code=404, detail="Session not found or delete failed") -# --- User Profile Endpoints --- -from pydantic import BaseModel -from typing import List, Optional + if db_manager.delete_session(current_user["uid"], session_id): + return {"status": "success"} -class UserProfileUpdate(BaseModel): - display_name: Optional[str] = None - math_level: Optional[str] = "Student" - interests: Optional[List[str]] = [] + raise HTTPException(404, "Session not found") -@app.get("/users/profile") -async def get_profile( - current_user: dict = Depends(get_current_user), - db_manager = Depends(get_db_manager) -): - """Get current user profile.""" - try: - profile = db_manager.get_user_profile(current_user["uid"]) - if not profile: - # Return basic info if no profile exists yet - return { - "user_id": current_user["uid"], - "email": current_user.get("email"), - "display_name": "", - "math_level": "Student", - "interests": [], - "is_new": True - } - - # Remove MongoDB _id - if "_id" in profile: - del profile["_id"] - return profile - except Exception as e: - logger.error(f"Profile fetch error: {e}") - raise HTTPException(status_code=500, detail="Failed to fetch profile") -@app.post("/users/profile") -async def update_profile( - profile_data: UserProfileUpdate, - current_user: dict = Depends(get_current_user), - db_manager = Depends(get_db_manager) -): - """Update user profile.""" - try: - success = db_manager.update_user_profile(current_user["uid"], profile_data.dict(exclude_unset=True)) - if not success: - raise HTTPException(status_code=500, detail="Failed to update profile") - return {"status": "success", "profile": profile_data.dict()} - except Exception as e: - logger.error(f"Profile update error: {e}") - raise HTTPException(status_code=500, detail=str(e)) +# ═════════════════════════════════════════════════════════════════════ +# ENTRY POINT +# ═════════════════════════════════════════════════════════════════════ if __name__ == "__main__": + import uvicorn + port = int(os.environ.get("PORT", 8080)) - uvicorn.run(app, host="0.0.0.0", port=port) -# ── Auth Endpoints (DECOMMISSIONED - Use Firebase) ────────────────────────── - -@app.post("/auth/signup") -async def signup(): - """Signups are now handled by Firebase on the frontend.""" - raise HTTPException( - status_code=status.HTTP_410_GONE, - detail="Local signup is decommissioned. Please use Firebase Auth." - ) -@app.post("/auth/login") -async def login(): - """Login is now handled by Firebase on the frontend.""" - raise HTTPException( - status_code=status.HTTP_410_GONE, - detail="Local login is decommissioned. Please use Firebase Auth." - ) + uvicorn.run( + app, + host="0.0.0.0", + port=port, + ) \ No newline at end of file diff --git a/app/core/math_normalizer.py b/app/core/math_normalizer.py index 2a73b4980c1da9b02d0eaa54f510d38bb53eacc3..d09268932a8c274128e4e8cedc02f3e615d7dded 100644 --- a/app/core/math_normalizer.py +++ b/app/core/math_normalizer.py @@ -105,27 +105,46 @@ class MathQueryNormalizer: ) # Arithmetic / Simplification - # If it looks like math chars only - if self._is_arithmetic(clean_text): + # If it contains numbers and operators, or starts with "calculate", "what is" + if self._is_arithmetic(clean_text) or any(clean_text.startswith(sw) for sw in ["calculate", "what is", "evaluate"]): + expr = self._clean_expression(clean_text) return MathIntent( intent="arithmetic", - expression=clean_text, + expression=expr, original_query=text ) return None def _clean_expression(self, text: str) -> str: - """Removes common stop words and artifacts.""" + """ + Removes natural language words from an expression, leaving only + the mathematical notation SymPy can safely parse. + + ROOT CAUSE FIX: the previous version only stripped stop words from + the START of the string. So "what is the value of 5*9" became + "the value of 5*9" — SymPy then treated t, h, e, v, a, l, u, e, o, f + as separate symbols and multiplied them: 45·a·e²·f·h·l·o·t·u·v. + That's the "45aeflouv" garble seen on the UI. + + Fix: strip ALL known English prose words, not just from the start. + """ + import re text = text.strip() - - # Remove "what is" if it somehow got in - for stop in self.stop_words: - # Replace start of string - if text.startswith(stop): - text = text[len(stop):].strip() - - return text.strip() + + # Full list of prose words to remove wherever they appear + prose_words = [ + "what is", "what are", "the value of", "the result of", + "please", "calculate", "compute", "evaluate", "find", + "solve", "simplify", "determine", "the", "of", "for", + "result", "value", "answer", + ] + for phrase in sorted(prose_words, key=len, reverse=True): # longest first + text = re.sub(rf'\b{re.escape(phrase)}\b', ' ', text, flags=re.IGNORECASE) + + # Collapse multiple spaces + text = re.sub(r' +', ' ', text).strip() + return text def _is_arithmetic(self, text: str) -> bool: """ @@ -150,4 +169,4 @@ class MathQueryNormalizer: if char not in allowed_chars: return False - return True + return True \ No newline at end of file diff --git a/app/core/orchestrator.py b/app/core/orchestrator.py index 542754e10ddb373555df3c6eb2d2c134afdfa2e0..4ee807368bcba12e5c2ed6c444b8570505009cdf 100644 --- a/app/core/orchestrator.py +++ b/app/core/orchestrator.py @@ -1,25 +1,76 @@ +""" +orchestrator.py + +BUG FIX — Dead elif branch in the agentic streaming loop silently dropped events. + +The original if/elif chain: + + if event["type"] == "thought": ← catches "thought" + yield event + elif event["type"] == "answer": + full_answer += event["content"] + yield event + elif event["type"] in ("thought", "action", "observation"): ← DEAD — "thought" already caught above + label = ... + result_schema["metadata"]["logic_trace"].append(...) + yield event ← "action" and "observation" DO get appended to logic_trace here + elif event["type"] == "error": + yield event + else: + full_answer += str(event.get("content", "")) ← BUT "action"/"observation" never reach here + yield {"type": "answer", ...} + +The consequence: "action" and "observation" events were yielded (good), BUT their +content was NEVER appended to result_schema["metadata"]["logic_trace"], so the +final persist_log had empty reasoning. More critically, the order of branches +meant the logic_trace append for "thought" was ALSO skipped — the first branch +caught "thought" and just yielded it without logging it. + +This is a correctness bug but NOT the main cause of the blank UI. Documented here +for completeness; the primary fix is in schemas.py (missing request_id on Message) +and in frontend/app.py (sent_to_api=True for assistant messages). +""" + import logging import time import hashlib import json import re + +def _normalize_math(text: str) -> str: + """Inline replacement for math_renderer.render_math(). + Converts LaTeX delimiters to $...$ / $$...$$ for Streamlit MathJax. + Gemini 2.5 Flash mostly outputs $...$ already — this catches the rare + \\(...\\) and \\[...\\] variants and cleans ```math blocks. + """ + if not text: + return text + # Block: \[ ... \] → $$ ... $$ + import re as _re + text = _re.sub(r'\\\[(.+?)\\\]', r'$$\1$$', text, flags=_re.DOTALL) + # Inline: \( ... \) → $ ... $ + text = _re.sub(r'\\\((.+?)\\\)', r'$\1$', text, flags=_re.DOTALL) + # ```math blocks → $$ ... $$ + text = _re.sub(r'```math\s*(.+?)\s*```', r'$$\1$$', text, flags=_re.DOTALL) + # Empty $$$$ artifacts + text = _re.sub(r'\$\$\s*\$\$', '', text) + return text.strip() + import asyncio from typing import Any, Dict, Optional, AsyncGenerator -import sympy -from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application - from app.core.input_processor import InputProcessor from app.core.math_normalizer import MathQueryNormalizer, MathIntent from app.memory.cache import CacheManager +from app.core.sympy_solver import SymPySolver +from app.memory.semantic_cache import SemanticCache from app.memory.database import DatabaseManager -from app.agents.adk_mathminds import MathMindsADKAgent + from app.core.settings import settings +from app.agents.adk_mathminds import MathMindsADKAgent logger = logging.getLogger(__name__) -_SYMPY_TRANSFORMATIONS = standard_transformations + (implicit_multiplication_application,) - class Orchestrator: """ @@ -30,6 +81,7 @@ class Orchestrator: self, cache_manager: Optional[CacheManager] = None, db_manager: Optional[DatabaseManager] = None, + semantic_cache: Optional[SemanticCache] = None, redis_client: Any = None, ): try: @@ -39,6 +91,17 @@ class Orchestrator: self.db_manager = db_manager or DatabaseManager() self.redis_client = redis_client self.adk_agent = MathMindsADKAgent(redis_client=self.redis_client) + self.sympy_solver = SymPySolver() + + # Semantic cache — use injected instance from deps.py if provided, + # otherwise create internally. + if semantic_cache is not None: + self.semantic_cache = semantic_cache if settings.ENABLE_CACHE else None + else: + self.semantic_cache = SemanticCache( + redis_client = self.redis_client, + gemini_api_key = settings.GOOGLE_API_KEY, + ) if settings.ENABLE_CACHE else None except Exception as e: logger.critical(f"Failed to initialize Orchestrator: {e}") raise @@ -61,102 +124,212 @@ class Orchestrator: "status": "success", "source": "agent", "answer": "", - "metadata": { - "latency_ms": 0, - "model": "gemini-2.5-flash", - "tools_used": [], - "logic_trace": [] + "metadata": { + "latency_ms": 0, + "model": "gemini-2.5-flash", + "tools_used": [], + "logic_trace": [], }, } try: # ── 1. Input processing ─────────────────────────────────────────── - processed = self.input_processor.process_compound(text_input=query, image_input=image) + processed = self.input_processor.process_compound( + text_input=query, image_input=image + ) if not processed.is_valid: yield {"type": "error", "content": processed.error_message} return - query = processed.cleaned_content - image_data = processed.metadata.get("image_data") + query = processed.cleaned_content + image_data = processed.metadata.get("image_data") if processed.metadata else None - # 1.5. Persist user message (Safety Check: Don't duplicate) + # ── 1.5. Persist user message (idempotent) ──────────────────────── if user_id and session_id: - # Check if this exact request already exists in DB to prevent duplicates history = self.db_manager.get_chat_history(user_id, session_id) or [] if not any(m.get("request_id") == request_id for m in history): await self._persist_message( - user_id=user_id, session_id=session_id, role="user", + user_id=user_id, session_id=session_id, role="user", content=query or "Uploaded an image", image_data=image_data, - request_id=request_id + request_id=request_id, ) - # ── 2. Cache lookup ─────────────────────────────────────────────── + # ── 2. Cache lookup — two layers ────────────────────────────────── + # + # Layer 1 — Exact hash (Redis, microseconds, zero API cost) + # sha256(normalized_query) → instant lookup for identical questions + # + # Layer 2 — Semantic similarity (Redis embeddings, ~50ms, uses + # gemini-embedding-001 which has its OWN 1500 req/day quota, + # completely separate from the 20 req/day generate_content limit) + # Cosine similarity ≥ 0.85 → treat as same question + # + # Both layers are skipped for image queries (can't embed images). + cache_key = None + cached_answer = None + cache_source = None + if settings.ENABLE_CACHE and not image_data: cache_key = self._make_cache_key(query) - cached = self.cache_manager.get_cached_answer(cache_key) - if cached: - yield {"type": "thought", "content": "Retrieving answer from memory..."} - yield {"type": "answer", "content": cached.get("answer")} - # Persist assistant response + + # Layer 1: exact hash + exact = self.cache_manager.get_cached_answer(cache_key) + if exact: + cached_answer = exact.get("answer") + cache_source = "exact_cache" + logger.info(f"Cache layer 1 HIT (exact) for key={cache_key[:8]}") + + # Layer 2: semantic similarity (only if exact missed) + if not cached_answer and self.semantic_cache: + sem = self.semantic_cache.get(query) + if sem: + cached_answer = sem["answer"] + cache_source = f"semantic_cache (similarity={sem['similarity']}))" + logger.info(f"Cache layer 2 HIT (semantic) similarity={sem['similarity']}") + + if cached_answer: + yield {"type": "thought", "content": f"💾 Retrieving from memory ({cache_source})..."} + yield {"type": "answer", "content": cached_answer} if user_id and session_id: - await self._persist_log(query, {"answer": cached.get("answer"), "metadata": cached.get("metadata")}, user_id, session_id, cache_key) + await self._persist_log( + query, + {"answer": cached_answer, "metadata": {"source": cache_source}}, + user_id, session_id, cache_key, + request_id=request_id, + ) return - else: - cache_key = None - # ── 3. Pre-flight (SymPy) ───────────────────────────────────────── + # ── 3. SymPy Preflight ──────────────────────────────────────────── + # Try to solve symbolically BEFORE calling Gemini. + # Cost: 0 LLM calls. Handles derivatives, integrals, + # equations, limits, arithmetic in milliseconds. + # If SymPy can't solve it → falls through to Gemini. if not image_data: - preflight_result = self._try_sympy(query) - if preflight_result is not None: - yield {"type": "thought", "content": "Calculating result symbolically..."} - yield {"type": "answer", "content": preflight_result} - - result_schema.update({ - "source": "sympy_preflight", - "answer": preflight_result, - "metadata": {"model": "sympy", "tools_used": ["sympy"]} - }) - - await self._persist_log(query, result_schema, user_id, session_id, cache_key, request_id=request_id) - return + math_intent = self.normalizer.normalize(query) + if math_intent: + sympy_result = self.sympy_solver.solve(math_intent) + if sympy_result: + # sympy_solver.solve() returns a plain str, not a dict + answer = sympy_result + yield {"type": "thought", "content": f"⚡ Solving symbolically ({math_intent.intent})..."} + yield {"type": "answer", "content": _normalize_math(answer)} + result_schema["answer"] = answer + result_schema["metadata"]["source"] = "sympy_preflight" + result_schema["metadata"]["intent"] = math_intent.intent + if user_id and session_id: + await self._persist_log( + query, result_schema, + user_id, session_id, cache_key, + request_id=request_id, + ) + return # ── 4. Agentic Streaming Loop ───────────────────────────────────── + # FIX: The original had a dead elif branch. The chain was: + # if "thought" → yield + # elif "answer" → accumulate + yield + # elif ("thought","action","observation") → log + yield ← "thought" ALREADY matched above + # elif "error" → yield + # + # Result: "action" and "observation" were yielded but never logged to + # logic_trace. Rewritten as explicit branches with no dead code. full_answer = "" async for event in self.adk_agent.solve( - problem=query, image_data=image_data, - session_id=session_id, user_id=user_id + problem=query, image_data=image_data, + session_id=session_id, user_id=user_id, ): - if event["type"] == "thought": + ev_type = event.get("type", "") + content = event.get("content", "") + + if ev_type == "answer": + # Normalize LaTeX and SymPy notation before sending to frontend + content = _normalize_math(content) + full_answer += content + yield {**event, "content": content} + + elif ev_type == "thought": + result_schema["metadata"]["logic_trace"].append(content) yield event - elif event["type"] == "answer": - full_answer += event["content"] + + elif ev_type == "action": + result_schema["metadata"]["logic_trace"].append(f"⚙️ {content}") yield event - elif event["type"] in ("thought", "action", "observation"): - label = "" - if event["type"] == "action": label = "⚙️ " - elif event["type"] == "observation": label = "👁️ " - - result_schema["metadata"]["logic_trace"].append(f"{label}{event['content']}") + + elif ev_type == "observation": + result_schema["metadata"]["logic_trace"].append(f"👁️ {content}") yield event - elif event["type"] == "error": + + elif ev_type == "error": yield event + else: - # Fallback for any other content - full_answer += str(event.get("content", "")) - yield {"type": "answer", "content": str(event.get("content", ""))} + # Unexpected event type — treat as answer text so nothing is lost + if content: + full_answer += str(content) + yield {"type": "answer", "content": str(content)} # ── 5. Finalize ─────────────────────────────────────────────────── result_schema["answer"] = full_answer result_schema["metadata"]["latency_ms"] = int((time.time() - start_time) * 1000) - + if full_answer: - # AWAIT the final log instead of fire-and-forget to prevent race conditions with UI reloads. - await self._persist_log(query, result_schema, user_id, session_id, cache_key, request_id=request_id) + await self._persist_log( + query, result_schema, user_id, session_id, cache_key, + request_id=request_id, + ) except Exception as e: - logger.error(f"Orchestrator Error: {e}") + logger.error(f"Orchestrator Error: {e}", exc_info=True) yield {"type": "error", "content": f"Internal Error: {str(e)}"} + async def solve_problem( + self, + query: Optional[str] = None, + image: Optional[str] = None, + request_id: Optional[str] = None, + model_preference: str = "fast", + session_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Non-streaming version of solve_problem. + Executes the full agent loop and returns the final answer object. + """ + full_answer = "" + logic_trace = [] + error = None + + # We wrap the logic here to ensure we get a consistent response + # by consuming the stream which already handles persistence. + async for event in self.solve_problem_stream( + query=query, + image=image, + request_id=request_id, + model_preference=model_preference, + session_id=session_id, + user_id=user_id, + ): + ev_type = event.get("type") + content = event.get("content") + + if ev_type == "answer": + full_answer += content + elif ev_type in ("thought", "action", "observation"): + logic_trace.append(content) + elif ev_type == "error": + error = content + + return { + "request_id": request_id or "unknown", + "status": "error" if error else "success", + "answer": full_answer, + "error": error, + "metadata": { + "logic_trace": logic_trace, + "timestamp": time.time(), + } + } + async def _persist_message(self, user_id, session_id, role, content, **kwargs): try: self.db_manager.create_session(user_id, session_id) @@ -165,78 +338,34 @@ class Orchestrator: logger.error(f"Failed to persist message: {e}") async def _persist_log(self, query, schema, user_id, session_id, cache_key, request_id=None): - """Internal awaitable helper.""" - # Map logic_trace to reasoning for frontend consistency - reasoning = "\n".join(schema["metadata"].get("logic_trace", [])) - + reasoning = "\n".join(schema.get("metadata", {}).get("logic_trace", [])) await self._persist_message( user_id=user_id, session_id=session_id, role="assistant", - content=schema["answer"], reasoning=reasoning, metadata=schema["metadata"], - request_id=request_id + content=schema["answer"], reasoning=reasoning, + metadata=schema.get("metadata", {}), + request_id=request_id, ) if settings.ENABLE_CACHE and cache_key: + # Layer 1: exact hash cache self.cache_manager.set_cached_answer(cache_key, schema) - self.db_manager.save_problem({"content": query}, schema) - - def _try_sympy(self, query: str) -> Optional[str]: - try: - intent: Optional[MathIntent] = self.normalizer.normalize(query) - if intent is None: return None - expr_str = self._prep_expr(intent.expression) - target_var = sympy.Symbol(intent.variable or "x") - if intent.intent == "arithmetic": return self._solve_arithmetic(expr_str) - if intent.intent == "equation": return self._solve_equation(expr_str, target_var) - if intent.intent == "derivative": - expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS) - return f"d/d{target_var}({intent.expression}) = {sympy.latex(sympy.diff(expr, target_var))}" - if intent.intent == "integral": - expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS) - return f"∫({intent.expression}) d{target_var} = {sympy.latex(sympy.integrate(expr, target_var))} + C" - if intent.intent == "limit": return self._solve_limit(intent, query) - if intent.intent == "simplification": - expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS) - return f"Simplified: {sympy.latex(sympy.simplify(expr))}" - except Exception: pass - return None - - def _prep_expr(self, expr: str) -> str: - expr = expr.replace("^", "**") - expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr) - expr = re.sub(r"\)\s*\(", ")*(", expr) - return expr.strip() - - def _solve_arithmetic(self, expr_str: str) -> Optional[str]: - try: - result = sympy.simplify(parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)) - if result.is_number: - numeric = float(result) - return str(int(numeric)) if numeric == int(numeric) else f"{numeric:.6g}" - return sympy.latex(result) - except Exception: return None - - def _solve_equation(self, expr_str: str, var: sympy.Symbol) -> Optional[str]: - try: - parts = expr_str.split("=", 1) - if len(parts) == 2: - lhs = parse_expr(self._prep_expr(parts[0]), transformations=_SYMPY_TRANSFORMATIONS) - rhs = parse_expr(self._prep_expr(parts[1]), transformations=_SYMPY_TRANSFORMATIONS) - solution = sympy.solve(lhs - rhs, var) - else: - solution = sympy.solve(parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS), var) - if not solution: return "No solution found." - if len(solution) == 1: return f"{var} = {sympy.latex(solution[0])}" - return f"{var} ∈ {{{', '.join(sympy.latex(s) for s in solution)}}}" - except Exception: return None - - def _solve_limit(self, intent: MathIntent, original_query: str) -> Optional[str]: - try: - match = re.search(r"limit of\s+(.+?)\s+as\s+(\w+)\s+approaches\s+(.+)", original_query, re.IGNORECASE) - if not match: return None - expr = parse_expr(self._prep_expr(match.group(1)), transformations=_SYMPY_TRANSFORMATIONS) - var = sympy.Symbol(match.group(2).strip()) - point = parse_expr(self._prep_expr(match.group(3).strip()), transformations=_SYMPY_TRANSFORMATIONS) - return f"lim({var}→{point}) {sympy.latex(expr)} = {sympy.latex(sympy.limit(expr, var, point))}" - except Exception: return None + # Layer 2: semantic cache (stores embedding vector alongside answer) + # Only store if we have a real answer — don't cache errors/empty strings + # Wrapped in to_thread: semantic_cache.set() calls the embedding API + # (blocking HTTP). Running it in a thread means the response is already + # returned to the user before the cache write completes. + # Skip semantic cache write for SymPy answers — they are deterministic, + # so caching them via embedding similarity adds no value and wastes + # 1 embedding API call (out of the 1500/day quota). + source = schema.get("metadata", {}).get("source", "") + if self.semantic_cache and schema.get("answer") and source != "sympy_preflight": + await asyncio.to_thread( + self.semantic_cache.set, + query = query, + answer = schema["answer"], + metadata = schema.get("metadata", {}), + ) + # pymongo is sync — run in thread so it doesn't block the event loop + await asyncio.to_thread(self.db_manager.save_problem, {"content": query}, schema) def _make_cache_key(self, query: str) -> str: - return hashlib.sha256(query.strip().lower().encode()).hexdigest() + return hashlib.sha256(query.strip().lower().encode()).hexdigest() \ No newline at end of file diff --git a/app/core/schemas.py b/app/core/schemas.py index 48b4a09de4ca525b269de3cedd5b43aa0528bd09..247ef627fc29841b717753e099ad373b0625cc25 100644 --- a/app/core/schemas.py +++ b/app/core/schemas.py @@ -1,18 +1,36 @@ +""" +schemas.py + +BUG FIX — Message model was missing `request_id: Optional[str]`. + +Why this caused "no answer" on the UI: + The GET /chat/sessions/{id}/messages endpoint uses `response_model=List[Message]`. + FastAPI strips any field NOT declared in the model before sending the response. + So even though `save_chat_message(..., request_id=request_id)` correctly stores + request_id in MongoDB, FastAPI silently dropped it on the way back out. + + The frontend's load_messages() dedup merge keys on (role, request_id): + server_keys = {(m["role"], m["request_id"]) for m in server_msgs if m.get("request_id")} + + With request_id always None from server, server_keys was always empty. + On every load_messages() call, ALL local messages looked unconfirmed, so they + got appended again as duplicates — and on the next rerun the trigger condition + `role=="user" and not sent_to_api` re-fired, sending the question a second time + and overwriting the answer_placeholder before it could be seen. +""" + from datetime import datetime from typing import Any, Dict, Optional, List from pydantic import BaseModel, Field, model_validator + class SolveRequest(BaseModel): - """ - Request model for the /solve endpoint. - Supports text-only, image-only, or multimodal (text + image) input. - """ - text: Optional[str] = Field(None, description="The math problem text or specific question about the image.") - image: Optional[str] = Field(None, description="Base64 encoded image string or Image URL.") - session_id: Optional[str] = Field(None, description="Session ID for maintaining chat context.") - model_preference: Optional[str] = Field("fast", description="Model preference: 'fast' or 'reasoning'.") - request_id: Optional[str] = Field(None, description="Unique ID for deduplication.") - input: Optional[str] = Field(None, description="Legacy field for backward compatibility.", deprecated=True) + text: Optional[str] = Field(None, description="The math problem text.") + image: Optional[str] = Field(None, description="Base64 encoded image string.") + session_id: Optional[str] = Field(None, description="Session ID for chat context.") + model_preference: Optional[str] = Field("fast", description="'fast' or 'reasoning'.") + request_id: Optional[str] = Field(None, description="Unique ID for deduplication.") + input: Optional[str] = Field(None, description="Legacy field.", deprecated=True) @property def effective_text(self) -> Optional[str]: @@ -21,77 +39,73 @@ class SolveRequest(BaseModel): @model_validator(mode='before') @classmethod def check_input_compatibility(cls, values: Any) -> Any: - # Support legacy 'input' field if isinstance(values, dict): if 'input' in values and not values.get('text'): values['text'] = values['input'] return values - + @model_validator(mode='after') def check_at_least_one(self) -> 'SolveRequest': - text = self.text - image = self.image - # We don't check 'input' here because it should have been mapped to 'text' above - if not text and not image: - raise ValueError("At least one of 'text' or 'image' must be provided.") + if not self.text and not self.image: + raise ValueError("At least one of 'text' or 'image' must be provided.") return self + class SolveResponse(BaseModel): - """ - Response model for the /solve endpoint. - """ - request_id: str - status: str = Field(..., description="Status of the request (success/error).") - problem_type: str = "unknown" - source: str = "unknown" - answer: Any = Field(None, description="The structured answer from the AI. Can be str, float, or dict.") - steps: List[str] = Field(default_factory=list, description="A list of steps taken to solve the problem.") - explanation: Optional[str] = Field(None, description="A detailed explanation of the solution.") - confidence: float = Field(0.0, description="Confidence score of the answer.") - cached: bool = Field(False, description="Indicates if the response was served from cache.") - error: Optional[str] = Field(None, description="Error message if status is error.") - error_code: Optional[str] = Field(None, description="Error code if status is error.") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Metadata about the processing.") + request_id: str + status: str = Field(..., description="success/error") + problem_type: str = "unknown" + source: str = "unknown" + answer: Any = Field(None) + steps: List[str] = Field(default_factory=list) + explanation: Optional[str] = None + confidence: float = 0.0 + cached: bool = False + error: Optional[str] = None + error_code: Optional[str] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + class HealthResponse(BaseModel): - """ - Response model for the /health endpoint. - """ - status: str + status: str version: str -# --- Chat History Schemas --- class Message(BaseModel): - role: str - content: str - timestamp: datetime - reasoning: Optional[str] = None - metadata: Dict[str, Any] = {} - steps: List[str] = [] + role: str + content: str + timestamp: datetime + # FIX: This field was missing. FastAPI was stripping it from every response, + # breaking the frontend dedup merge and causing phantom re-triggers. + request_id: Optional[str] = None + reasoning: Optional[str] = None + metadata: Dict[str, Any] = {} + steps: List[str] = [] + class ChatSession(BaseModel): session_id: str - title: str + title: str created_at: datetime - # messages: Optional[List[Message]] = None # Optional for listing + class SessionRename(BaseModel): title: str = Field(..., min_length=1, max_length=100) -# --- Auth Schemas --- class UserSignup(BaseModel): - email: str - password: str = Field(..., min_length=8, max_length=72) + email: str + password: str = Field(..., min_length=8, max_length=72) full_name: Optional[str] = None + class UserLogin(BaseModel): - email: str + email: str password: str = Field(..., max_length=72) + class TokenResponse(BaseModel): access_token: str - token_type: str = "bearer" - user_id: str - email: str + token_type: str = "bearer" + user_id: str + email: str \ No newline at end of file diff --git a/app/core/settings.py b/app/core/settings.py index e74de26a756c4baca92fec04c20228215b5629af..995ac5d6ca296f709c21a7a94b2bd9edf36e7d60 100644 --- a/app/core/settings.py +++ b/app/core/settings.py @@ -37,7 +37,7 @@ class Settings(BaseSettings): ENABLE_LOCAL_MODELS: bool = True ENABLE_CACHE: bool = True ENABLE_AUTH: bool = True - MAX_LLM_CALLS_PER_DAY: int = 18 # Default limit per user per day + MAX_LLM_CALLS_PER_DAY: int = 100 # Default limit per user per day # Integrations FIREBASE_CREDENTIALS_PATH: Optional[str] = None diff --git a/app/core/sympy_solver.py b/app/core/sympy_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..30e733083847701a0ce5293fa908174fa2797b97 --- /dev/null +++ b/app/core/sympy_solver.py @@ -0,0 +1,120 @@ +import logging +import sympy +from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application, convert_xor +from typing import Optional, Any +from app.core.math_normalizer import MathIntent + +logger = logging.getLogger(__name__) + +class SymPySolver: + """ + Attempts to solve mathematical expressions using SymPy. + Used as a pre-flight check to save LLM quota for pure math. + """ + + def solve(self, intent: MathIntent) -> Optional[str]: + """ + Processes a MathIntent and returns a formatted solution string or None. + """ + try: + expr_str = intent.expression + action = intent.intent + var_symbol = sympy.Symbol(intent.variable or 'x') + + if action == "derivative": + return self._solve_derivative(expr_str, var_symbol) + elif action == "integral": + return self._solve_integral(expr_str, var_symbol) + elif action == "equation": + return self._solve_equation(expr_str, var_symbol) + elif action == "arithmetic": + return self._solve_arithmetic(expr_str) + + return None + except Exception as e: + logger.info(f"SymPy could not solve '{intent.expression}': {e}") + return None + + def _parse(self, expr_str: str) -> Any: + transformations = standard_transformations + (implicit_multiplication_application, convert_xor) + return parse_expr(expr_str, transformations=transformations) + + def _solve_derivative(self, expr_str: str, var: sympy.Symbol) -> Optional[str]: + expr = self._parse(expr_str) + result = sympy.diff(expr, var) + return f"The derivative of ${sympy.latex(expr)}$ with respect to ${var}$ is:\n\n$${sympy.latex(result)}$$" + + def _solve_integral(self, expr_str: str, var: sympy.Symbol) -> Optional[str]: + expr = self._parse(expr_str) + result = sympy.integrate(expr, var) + # Check if integral was actually solved (not just returned as an Integral object) + if isinstance(result, sympy.Integral): + return None + return f"The indefinite integral of ${sympy.latex(expr)}$ with respect to ${var}$ is:\n\n$${sympy.latex(result)} + C$$" + + def _solve_equation(self, expr_str: str, var: sympy.Symbol) -> Optional[str]: + # Handle equations like "x^2 - 4 = 0" or "x^2 = 4" + if "=" in expr_str: + lhs_str, rhs_str = expr_str.split("=") + lhs = self._parse(lhs_str.strip()) + rhs = self._parse(rhs_str.strip()) + eq = sympy.Eq(lhs, rhs) + else: + # Assume expression = 0 if no '=' + eq = self._parse(expr_str) + + solutions = sympy.solve(eq, var) + if not solutions: + return "No solutions found." + + sol_str = ", ".join([f"${sympy.latex(s)}$" for s in solutions]) + return f"The solutions for ${sympy.latex(eq if '=' in expr_str else sympy.Eq(self._parse(expr_str), 0))}$ are:\n\n{sol_str}" + + def _solve_arithmetic(self, expr_str: str) -> Optional[str]: + # SAFETY CHECK: reject expressions containing non-math words + # If the expression has alphabetic characters that aren't recognised + # math symbols (e, i, pi, x, y, z, etc.), SymPy silently treats each + # letter as a variable and multiplies them together — producing garbled + # output like "45aeflouv" on the UI. + # Example: "the value of 5*9" → SymPy sees t*h*e*v*a*l*u*e*o*f*5*9 + # + # Rule: if the expression contains any English word characters beyond + # known math constants, return None and let Gemini handle it. + import re + # Strip pure math tokens to see what's left + stripped = re.sub(r'[0-9+\-*/^().\s]', '', expr_str) + # Known single-letter math constants that SymPy handles correctly + safe_single_letters = set('eijxyz') + # Known multi-letter constants/functions + safe_words = {'pi', 'inf', 'oo', 'sin', 'cos', 'tan', 'log', 'exp', + 'sqrt', 'abs', 'floor', 'ceil'} + + # Check for multi-char letter sequences (words) that aren't math + words_in_expr = re.findall(r'[a-zA-Z]+', expr_str) + for word in words_in_expr: + if word.lower() not in safe_words and len(word) > 1: + # Multi-letter word that isn't a math function — natural language crept in + logger.debug(f"SymPy arithmetic rejected: found word '{word}' in '{expr_str}'") + return None + + try: + expr = self._parse(expr_str) + result = expr.evalf() if expr.is_number else sympy.simplify(expr) + + # If result is same as input and not a simple number, let Gemini handle it + if str(result) == expr_str and not result.is_number: + return None + + # Format result cleanly — integer if possible, float otherwise + try: + numeric = float(result) + if numeric == int(numeric): + display = str(int(numeric)) + else: + display = f"{numeric:.4f}".rstrip('0').rstrip('.') + return f"Result: **{display}**\n\n$${ sympy.latex(self._parse(expr_str))} = {sympy.latex(result)}$$" + except Exception: + return f"Result of evaluation:\n\n$${sympy.latex(result)}$$" + except Exception as e: + logger.debug(f"SymPy arithmetic eval failed for '{expr_str}': {e}") + return None \ No newline at end of file diff --git a/app/memory/cache.py b/app/memory/cache.py index 028aebd5ec9204a1d58b73b35187a4e08989d61e..fbd04cd6e2ff449c442e35820501feb66ace2ca7 100644 --- a/app/memory/cache.py +++ b/app/memory/cache.py @@ -1,131 +1,171 @@ import json import logging -import os from typing import Any, Dict, Optional -from app.core.settings import settings import redis from redis.exceptions import RedisError -# Configure logging +from app.core.settings import settings + logger = logging.getLogger(__name__) + class CacheManager: - """ - Manages Redis cache operations for the AI system. - Handles connections, serialization, and failure scenarios gracefully. - """ - - def __init__(self, redis_url: Optional[str] = None, connection_pool: Optional[redis.ConnectionPool] = None): - """ - Initialize the CacheManager. - - Args: - redis_url: Redis connection string (used if pool not provided). - connection_pool: Existing Redis connection pool. - """ + + CACHE_PREFIX = "mathminds:cache:" + MAX_CACHE_SIZE = 50000 + + def __init__( + self, + redis_url: Optional[str] = None, + connection_pool: Optional[redis.ConnectionPool] = None, + ): + self.redis_url = redis_url or settings.REDIS_URL self.redis_client = None - + try: + if connection_pool: - self.redis_client = redis.Redis(connection_pool=connection_pool, decode_responses=True) + self.redis_client = redis.Redis( + connection_pool=connection_pool, + decode_responses=True, + ) else: - # If no pool provided, create standard client (which uses internal pool) - # But typically we want to pass the pool. - self.redis_client = redis.from_url(self.redis_url, decode_responses=True) - - # Fast ping to verify connection + self.redis_client = redis.from_url( + self.redis_url, + decode_responses=True, + socket_timeout=2, + socket_connect_timeout=2, + ) + self.redis_client.ping() - logger.info(f"Successfully connected to Redis at {self.redis_url}") - + + logger.info(f"Connected to Redis at {self.redis_url}") + except RedisError as e: - logger.error(f"Failed to connect to Redis: {e}") + + logger.error(f"Redis connection failed: {e}") self.redis_client = None - # _connect method is removed/merged into __init__ since we prefer injection + def _serialize(self, data: Any) -> str: + return json.dumps(data, default=str) - def get_cached_answer(self, cache_key: str) -> Optional[Dict[str, Any]]: - """ - Retrieve a cached answer by its hash key. + def _prefixed(self, key: str) -> str: + return f"{self.CACHE_PREFIX}{key}" - Args: - cache_key: The unique hash key for the problem. + def get_cached_answer(self, cache_key: str) -> Optional[Dict[str, Any]]: - Returns: - Optional[Dict[str, Any]]: The cached answer info if found and valid, else None. - """ if not self.redis_client: - logger.warning("Redis client is not available. Skipping cache lookup.") return None try: - data = self.redis_client.get(cache_key) + + key = self._prefixed(cache_key) + + data = self.redis_client.get(key) + if data: - logger.info(f"Cache hit for key: {cache_key}") + logger.debug(f"Cache hit: {key}") return json.loads(data) - logger.debug(f"Cache miss for key: {cache_key}") - return None - except RedisError as e: - logger.error(f"Redis error during get operations: {e}") - return None - except json.JSONDecodeError as e: - logger.error(f"Failed to decode cached data for key {cache_key}: {e}") + return None - def set_cached_answer(self, cache_key: str, answer: Dict[str, Any], ttl: int = 86400) -> bool: - """ - Cache an answer with a TTL. + except (RedisError, json.JSONDecodeError) as e: - Args: - cache_key: The unique hash key. - answer: The answer data to cache (will be JSON serialized). - ttl: Time-to-live in seconds. Defaults to 86400 (24 hours). + logger.error(f"Cache read error: {e}") + return None + + def set_cached_answer( + self, + cache_key: str, + answer: Dict[str, Any], + ttl: int = 86400, + ) -> bool: - Returns: - bool: True if successful, False otherwise. - """ if not self.redis_client: - logger.warning("Redis client is not available. Skipping cache write.") return False try: - serialized_data = json.dumps(answer) - self.redis_client.setex(cache_key, ttl, serialized_data) - logger.info(f"Successfully cached answer for key: {cache_key} with TTL {ttl}") + + key = self._prefixed(cache_key) + + serialized_data = self._serialize(answer) + + if len(serialized_data) > self.MAX_CACHE_SIZE: + logger.warning("Cache skipped: payload too large") + return False + + self.redis_client.setex(key, ttl, serialized_data) + return True + except (RedisError, TypeError) as e: - # TypeError catches JSON serialization errors - logger.error(f"Failed to cache answer for key {cache_key}: {e}") + + logger.error(f"Cache write failed: {e}") return False - def set_if_not_exists(self, cache_key: str, answer: Dict[str, Any], ttl: int = 86400) -> bool: - """ - Set cache only if key doesn't exist (atomic operation). - Prevents thundering herd when multiple requests populate cache. - - Args: - cache_key: The unique hash key. - answer: The answer data to cache. - ttl: Time-to-live in seconds. - - Returns: - bool: True if set, False if key already existed or error. - """ + def set_if_not_exists( + self, + cache_key: str, + answer: Dict[str, Any], + ttl: int = 86400, + ) -> bool: + if not self.redis_client: return False - + try: - serialized_data = json.dumps(answer) - # SETNX is atomic - only succeeds if key doesn't exist - # Redis-py set() with nx=True is equivalent to SETNX + EXPIRE + + key = self._prefixed(cache_key) + + serialized_data = self._serialize(answer) + result = self.redis_client.set( - cache_key, - serialized_data, - ex=ttl, - nx=True # Only set if not exists + key, + serialized_data, + ex=ttl, + nx=True, ) + return bool(result) + except Exception as e: - logger.error(f"Failed to set_if_not_exists for {cache_key}: {e}") + + logger.error(f"set_if_not_exists failed: {e}") return False + + def delete(self, cache_key: str) -> bool: + + if not self.redis_client: + return False + + try: + + key = self._prefixed(cache_key) + + return bool(self.redis_client.delete(key)) + + except RedisError: + + return False + + def stats(self) -> Dict[str, Any]: + + if not self.redis_client: + return {} + + try: + + info = self.redis_client.info() + + return { + "used_memory": info.get("used_memory_human"), + "connected_clients": info.get("connected_clients"), + "keyspace_hits": info.get("keyspace_hits"), + "keyspace_misses": info.get("keyspace_misses"), + } + + except Exception: + + return {} \ No newline at end of file diff --git a/app/memory/semantic_cache.py b/app/memory/semantic_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..508a12f347e400ac980d430d4f7117615f9abc53 --- /dev/null +++ b/app/memory/semantic_cache.py @@ -0,0 +1,245 @@ +""" +app/memory/semantic_cache.py — Semantic (meaning-aware) cache for MathMinds AI. + +Architecture +──────────── +Exact hash cache (Redis) ← microseconds, free + ↓ MISS +Semantic vector cache (Redis) ← ~50ms, free (embedding stored in Redis) + ↓ MISS +Gemini API call ← costs 1 quota unit + +Why two layers? + - Exact cache: zero cost, handles identical repeated questions instantly. + - Semantic cache: handles paraphrases. Uses Google's gemini-embedding-001 + to embed both the query and stored questions, then finds nearest neighbour + by cosine similarity. Entirely self-contained in Redis — no Supabase needed. + +Redis key design + semantic:index → Redis Set — all embedding keys + semantic:emb:{hash} → JSON {query, embedding, answer, metadata, timestamp} + +Similarity threshold: 0.85 + - 0.85+ → same mathematical question, different words (safe to return) + - 0.70-0.85 → related topic, probably different question (skip) + - <0.70 → unrelated + +Quota cost of embeddings + gemini-embedding-001 is NOT counted against the generate_content quota. + It has its own free tier: 1500 requests/day — far more than the 20/day + generate limit, so semantic lookup is essentially free to run. +""" + +import json +import logging +import hashlib +import time +import math +from typing import Optional, Dict, Any, List, Tuple + +logger = logging.getLogger(__name__) + +# ── Similarity threshold ─────────────────────────────────────────────────── +# Tested against math paraphrase pairs. Lower = more aggressive matching. +SIMILARITY_THRESHOLD = 0.85 + +# Redis key prefixes +_PREFIX_EMB = "semantic:emb:" # stores embedding + answer +_INDEX_KEY = "semantic:index" # set of all embedding hashes +_TTL_SECONDS = 7 * 24 * 3600 # 7 days + + +def _cosine_similarity(a: List[float], b: List[float]) -> float: + """Pure-Python cosine similarity. No numpy needed.""" + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +def _normalize_query(query: str) -> str: + """ + Light normalization before embedding. + Removes punctuation noise but keeps math symbols — '2+2' and '2 + 2' + should map to the same embedding region. + """ + import re + q = query.lower().strip() + # collapse whitespace + q = re.sub(r"\s+", " ", q) + return q + + +class SemanticCache: + """ + Semantic similarity cache backed by Redis. + + Usage (in orchestrator): + sc = SemanticCache(redis_client, gemini_client) + + # Lookup + result = sc.get(query) + if result: + return result["answer"] + + # Store after getting answer from API + sc.set(query, answer_text, metadata) + """ + + def __init__(self, redis_client, gemini_api_key: str): + self.redis = redis_client + self._api_key = gemini_api_key + self._genai = None # lazy init + + def _get_client(self): + """Lazy-init google.genai client so import errors are surfaced clearly.""" + if self._genai is None: + try: + from google import genai + self._genai = genai.Client(api_key=self._api_key) + except Exception as e: + logger.error(f"SemanticCache: failed to init genai client: {e}") + raise + return self._genai + + def _embed(self, text: str) -> Optional[List[float]]: + """ + Generate embedding vector for text. + Uses gemini-embedding-001 (NOT counted against generate_content quota). + Returns None on failure so cache misses gracefully on API errors. + """ + try: + from google.genai import types + client = self._get_client() + resp = client.models.embed_content( + model="models/gemini-embedding-001", + contents=_normalize_query(text), + config=types.EmbedContentConfig(output_dimensionality=768), + ) + return resp.embeddings[0].values + except Exception as e: + logger.warning(f"SemanticCache: embedding failed: {e}") + return None + + def _query_hash(self, query: str) -> str: + return hashlib.sha256(_normalize_query(query).encode()).hexdigest()[:16] + + # ── Public API ───────────────────────────────────────────────────────── + + def get(self, query: str) -> Optional[Dict[str, Any]]: + """ + Look up a semantically similar cached answer. + + Returns dict with keys: answer, metadata, source, similarity + Returns None on cache miss or any error. + """ + if not self.redis: + return None + + try: + # Get all stored embedding keys + keys = self.redis.smembers(_INDEX_KEY) + if not keys: + return None + + # Embed the incoming query + query_vec = self._embed(query) + if query_vec is None: + return None + + best_score = 0.0 + best_entry = None + + for key in keys: + raw = self.redis.get(f"{_PREFIX_EMB}{key}") + if not raw: + continue + try: + entry = json.loads(raw) + except json.JSONDecodeError: + continue + + stored_vec = entry.get("embedding") + if not stored_vec: + continue + + score = _cosine_similarity(query_vec, stored_vec) + if score > best_score: + best_score = score + best_entry = entry + + if best_score >= SIMILARITY_THRESHOLD and best_entry: + logger.info( + f"SemanticCache HIT | similarity={best_score:.3f} | " + f"query='{query[:60]}' matched '{best_entry.get('query','')[:60]}'" + ) + return { + "answer": best_entry["answer"], + "metadata": best_entry.get("metadata", {}), + "source": "semantic_cache", + "similarity": round(best_score, 3), + } + + logger.debug(f"SemanticCache MISS | best_score={best_score:.3f} | query='{query[:60]}'") + return None + + except Exception as e: + logger.error(f"SemanticCache.get failed: {e}") + return None + + def set(self, query: str, answer: str, metadata: Optional[Dict] = None) -> bool: + """ + Store a query+answer with its embedding vector. + Silent on failure — caching is best-effort. + """ + if not self.redis or not answer: + return False + + try: + embedding = self._embed(query) + if embedding is None: + return False + + key = self._query_hash(query) + entry = { + "query": _normalize_query(query), + "answer": answer, + "metadata": metadata or {}, + "embedding": embedding, + "timestamp": time.time(), + } + self.redis.setex( + f"{_PREFIX_EMB}{key}", + _TTL_SECONDS, + json.dumps(entry), + ) + self.redis.sadd(_INDEX_KEY, key) + self.redis.expire(_INDEX_KEY, _TTL_SECONDS) + + logger.info(f"SemanticCache SET | key={key} | query='{query[:60]}'") + return True + + except Exception as e: + logger.error(f"SemanticCache.set failed: {e}") + return False + + def invalidate(self, query: str) -> bool: + """Remove a specific entry (e.g. if answer was wrong).""" + try: + key = self._query_hash(query) + self.redis.delete(f"{_PREFIX_EMB}{key}") + self.redis.srem(_INDEX_KEY, key) + return True + except Exception as e: + logger.error(f"SemanticCache.invalidate failed: {e}") + return False + + def stats(self) -> Dict[str, Any]: + """How many entries are cached.""" + try: + count = self.redis.scard(_INDEX_KEY) if self.redis else 0 + return {"entries": count, "threshold": SIMILARITY_THRESHOLD} + except Exception: + return {"entries": 0, "threshold": SIMILARITY_THRESHOLD} diff --git a/app/services/automation.py b/app/services/automation.py index a3cbe4d5a65beb12d62071b8e821658b1b78b4bd..4fbe686ba7500c8de01c309dbe7f96ef7ae76bc2 100644 --- a/app/services/automation.py +++ b/app/services/automation.py @@ -1,6 +1,8 @@ import logging import httpx +import asyncio from typing import Dict, Any, Optional +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from app.core.settings import settings logger = logging.getLogger(__name__) @@ -14,6 +16,12 @@ class AutomationService: def __init__(self, webhook_url: Optional[str] = None): self.webhook_url = webhook_url or settings.N8N_WEBHOOK_URL + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + retry=retry_if_exception_type((httpx.HTTPError, asyncio.TimeoutError)), + reraise=True + ) async def trigger(self, event_name: str, payload: Dict[str, Any]) -> Dict[str, Any]: """ Triggers an n8n workflow by sending a POST request to a webhook. @@ -24,9 +32,11 @@ class AutomationService: try: # Add metadata to the payload + # Use datetime directly since settings.datetime might not exist reliably + from datetime import datetime data = { "event": event_name, - "timestamp": settings.datetime.now().isoformat() if hasattr(settings, 'datetime') else None, + "timestamp": datetime.now().isoformat(), "environment": settings.ENV, "data": payload } @@ -43,11 +53,14 @@ class AutomationService: return {"status": "success", "response": response.json() if response.content else "OK"} else: logger.error(f"n8n automation failed with status {response.status_code}: {response.text}") + # We raise here to trigger tenacity retry if it's a 5xx or transient + if 500 <= response.status_code < 600: + raise httpx.HTTPStatusError(f"Server Error {response.status_code}", request=None, response=response) return {"status": "error", "code": response.status_code, "detail": response.text} except Exception as e: logger.error(f"Error triggering n8n automation: {e}") - return {"status": "error", "detail": str(e)} + raise # Re-raise to let tenacity catch it and retry if it matches the types # Singleton instance automation_service = AutomationService() diff --git a/app/tools/symbolic_solver.py b/app/tools/symbolic_solver.py deleted file mode 100644 index 82a44f910c16914246fce7791e9708f23ff2081c..0000000000000000000000000000000000000000 --- a/app/tools/symbolic_solver.py +++ /dev/null @@ -1,162 +0,0 @@ -import logging -import os -import asyncio -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Any, Optional, Union -import sympy -from sympy.parsing.sympy_parser import parse_expr -from app.core.math_normalizer import MathIntent -from app.core.settings import settings - - - -logger = logging.getLogger(__name__) - -class SymbolicSolver: - """ - Tool for solving math problems symbolically. - Prioritizes WolframAlpha (if AppID present), falls back to SymPy. - """ - - def __init__(self, wolfram_app_id: Optional[str] = None): - self.wolfram_app_id = wolfram_app_id or settings.WOLFRAM_APP_ID - logger.info(f"Initializing SymbolicSolver. WolframAppID present: {bool(self.wolfram_app_id)}") - - async def solve(self, query: Union[str, MathIntent]) -> Dict[str, Any]: - """ - Attempts to solve the query symbolically. - Accepts either a raw string (tried via Wolfram) or a structured MathIntent (for SymPy). - """ - # Unwrap intent if passed - intent = None - raw_query = query - if isinstance(query, MathIntent): - intent = query - raw_query = intent.original_query or intent.expression - - logger.info(f"SymbolicSolver triggered for query: {raw_query}") - - # 1. Try WolframAlpha (best for natural language or complex stuff) - if self.wolfram_app_id: - try: - import httpx - import urllib.parse - - # Construct URL manually to avoid library assertion errors - # We request JSON output for easier parsing - encoded_query = urllib.parse.quote(raw_query) - url = f"https://api.wolframalpha.com/v2/query?appid={self.wolfram_app_id}&input={encoded_query}&output=json" - - async with httpx.AsyncClient() as client: - response = await client.get(url, timeout=30.0) - - if response.status_code != 200: - logger.warning(f"WolframAlpha API returned status {response.status_code}") - else: - data = response.json() - query_result = data.get("queryresult", {}) - - success = query_result.get("success") - error = query_result.get("error") - - logger.info(f"Wolfram Response: success={success}, error={error}") - - if not success: - logger.warning(f"Wolfram query returned success=false. Error: {error}") - else: - answer_text = "" - pods = query_result.get("pods", []) - - for pod in pods: - title = pod.get("title", "Result") - for sub in pod.get("subpods", []): - plaintext = sub.get("plaintext") - if plaintext: - answer_text += f"{title}: {plaintext}\n" - - if answer_text: - return { - "source": "wolfram_alpha", - "content": answer_text, - "status": "success" - } - - except Exception as e: - import traceback - logger.warning(f"WolframAlpha query failed: {repr(e)}\n{traceback.format_exc()}") - - # 2. Try SymPy (Local Fallback) - # We need a structured intent for SymPy to work reliably. - # If we just got a string and Wolfram failed, we can't easily use SymPy - # unless it was already normalized. - - if not intent: - return { - "source": "symbolic_solver", - "error": "WolframAlpha failed and no structured MathIntent provided for SymPy.", - "status": "error" - } - - try: - # Pre-processing for SymPy syntax - # handle power operator ^ -> ** - expr_str = intent.expression.replace("^", "**") - - # handle implicit multiplication (simple regex) - import re - expr_str = re.sub(r'(\d)([a-z])', r'\1*\2', expr_str) - expr_str = re.sub(r'\)\(', ')*(', expr_str) - - target_var = sympy.symbols(intent.variable or 'x') - result_latex = "" - - if intent.intent == "derivative": - expr = parse_expr(expr_str) - res = sympy.diff(expr, target_var) - result_latex = sympy.latex(res) - - elif intent.intent == "integral": - expr = parse_expr(expr_str) - res = sympy.integrate(expr, target_var) - result_latex = sympy.latex(res) - - elif intent.intent == "equation": - # Expecting "lhs = rhs" or just expression assumed = 0 - parts = expr_str.split("=") - if len(parts) == 2: - lhs = parse_expr(parts[0]) - rhs = parse_expr(parts[1]) - solution = sympy.solve(lhs - rhs, target_var) - else: - # Assume expr = 0 - expr = parse_expr(expr_str) - solution = sympy.solve(expr, target_var) - - result_latex = sympy.latex(solution) - - elif intent.intent == "limit": - # TODO: Parsing limits needs 'approaches' value, logic not fully here yet - # Fallback implementation - return {"source": "symbolic_solver", "status": "error", "error": "Limit parsing not fully implemented"} - - elif intent.intent == "arithmetic" or intent.intent == "simplification": - expr = parse_expr(expr_str) - res = sympy.simplify(expr) - result_latex = sympy.latex(res) - - else: - return {"source": "symbolic_solver", "status": "error", "error": f"Unknown intent: {intent.intent}"} - - return { - "source": "sympy_local", - "content": result_latex, - "status": "success" - } - - except Exception as e: - logger.warning(f"SymPy execution failed: {e}") - return { - "source": "symbolic_solver", - "error": str(e), - "status": "error" - } diff --git a/check_agent.py b/check_agent.py deleted file mode 100644 index 879fe6b6bd9d065588ccdc894afbb6665c66d343..0000000000000000000000000000000000000000 --- a/check_agent.py +++ /dev/null @@ -1,12 +0,0 @@ - -try: - from google.adk.agents import Agent - print("Agent class found in google.adk.agents") -except ImportError: - print("Agent class NOT found in google.adk.agents") - -try: - from google.adk.agents import LlmAgent - print("LlmAgent class found in google.adk.agents") -except ImportError: - print("LlmAgent class NOT found in google.adk.agents") diff --git a/check_redis.py b/check_redis.py deleted file mode 100644 index 5e6ae4808580c390dd9c2c7eb9fb91ab7ba98852..0000000000000000000000000000000000000000 --- a/check_redis.py +++ /dev/null @@ -1,18 +0,0 @@ -import redis -import os -from dotenv import load_dotenv - -load_dotenv() -redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") - -def check_redis(): - print(f"Checking Redis at: {redis_url}") - try: - r = redis.from_url(redis_url) - r.ping() - print("✅ Redis is UP!") - except Exception as e: - print(f"❌ Redis is DOWN or unreachable: {e}") - -if __name__ == "__main__": - check_redis() diff --git a/db_diag.py b/db_diag.py deleted file mode 100644 index bbf66841d187b2cc3525f42bbd0db7489b7b7f2b..0000000000000000000000000000000000000000 --- a/db_diag.py +++ /dev/null @@ -1,21 +0,0 @@ -import os -from pymongo import MongoClient -from dotenv import load_dotenv - -load_dotenv() - -mongo_uri = os.getenv("MONGO_URI") -client = MongoClient(mongo_uri) -db = client.mathminds_db -# FIXED COLLECTION NAME -sessions = db.chat_sessions - -print("LAST 3 SESSIONS:") -for s in sessions.find().sort("created_at", -1).limit(3): - print(f"Session: {s.get('session_id')} | User: {s.get('user_id')}") - print(f"Title: {s.get('title')}") - msgs = s.get("messages", []) - print(f"Messages Count: {len(msgs)}") - for m in msgs[-10:]: - print(f" [{m.get('role')}] {m.get('content')[:100]} (RID: {m.get('request_id')})") - print("-" * 20) diff --git a/debug_adk.py b/debug_adk.py deleted file mode 100644 index aa9d05d37531cfb161ff36d6c5895cba7edd2ede..0000000000000000000000000000000000000000 --- a/debug_adk.py +++ /dev/null @@ -1,23 +0,0 @@ - -import sys -try: - import google - print("google imported") - print(dir(google)) - - try: - import google.adk - print("google.adk imported") - print(dir(google.adk)) - except ImportError as e: - print(f"Failed to import google.adk: {e}") - - try: - from google import adk - print("from google import adk succeeded") - print(dir(adk)) - except ImportError as e: - print(f"Failed to from google import adk: {e}") - -except ImportError as e: - print(f"Failed to import google: {e}") diff --git a/debug_adk_events.py b/debug_adk_events.py deleted file mode 100644 index d81498dfa6fb6e8eddde9ca1dd6ccdb03939797c..0000000000000000000000000000000000000000 --- a/debug_adk_events.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -Run this script STANDALONE — no FastAPI needed. -It directly invokes the ADK agent and prints every single event it emits, -so we can see exactly what is_final_response() returns and what text we get. - -Usage: - cd E:\madhuri\mathminds - python debug_adk_events.py -""" - -import asyncio -import sys -import os -sys.path.insert(0, os.getcwd()) - -from google.adk.agents import Agent -from google.adk.runners import Runner -from google.adk.sessions.in_memory_session_service import InMemorySessionService -from google.genai import types -from dotenv import load_dotenv -load_dotenv() - -QUESTION = "what is 9 + 8" - -async def main(): - from app.core.settings import settings - - agent = Agent( - name="math_minds_core", - model="gemini-2.5-flash", - tools=[], - instruction="You are a math assistant. Answer concisely." - ) - - session_service = InMemorySessionService() - runner = Runner( - app_name="mathminds_debug", - agent=agent, - session_service=session_service - ) - - await session_service.create_session( - app_name="mathminds_debug", - user_id="debug_user", - session_id="debug_session" - ) - - print(f"\nQuestion: {QUESTION}\n{'='*60}") - - all_text = "" - final_text = "" - event_num = 0 - - async for event in runner.run_async( - user_id="debug_user", - session_id="debug_session", - new_message=types.Content(role="user", parts=[types.Part.from_text(text=QUESTION)]) - ): - event_num += 1 - event_type = type(event).__name__ - author = getattr(event, "author", "N/A") - - # Check is_final_response - has_ifr = hasattr(event, "is_final_response") and callable(event.is_final_response) - is_final = event.is_final_response() if has_ifr else "method missing" - - print(f"\n[Event #{event_num}]") - print(f" type : {event_type}") - print(f" author : {author}") - print(f" is_final_response : {is_final}") - print(f" has content : {bool(event.content)}") - - if event.content and event.content.parts: - for i, part in enumerate(event.content.parts): - print(f" part[{i}].text : {repr(part.text)}") - print(f" part[{i}].function_call : {bool(getattr(part, 'function_call', None))}") - print(f" part[{i}].function_resp : {bool(getattr(part, 'function_response', None))}") - if part.text: - all_text += part.text - if is_final is True: - final_text += part.text - if author == "math_minds_core": - final_text += part.text - - print(f"\n{'='*60}") - print(f"Total events : {event_num}") - print(f"all_text (fallback): {repr(all_text)}") - print(f"final_text : {repr(final_text)}") - print(f"RESULT WOULD BE : {repr((final_text or all_text).strip())}") - -asyncio.run(main()) diff --git a/debug_celery_worker.py b/debug_celery_worker.py deleted file mode 100644 index 50f861db185bf07ccf8d200771d7ebe5aba4a3d3..0000000000000000000000000000000000000000 --- a/debug_celery_worker.py +++ /dev/null @@ -1,49 +0,0 @@ -import asyncio -import logging -import sys -import os - -# Add the current directory to sys.path so we can import 'app' -sys.path.append(os.getcwd()) - -from app.worker.tasks import scrape_task -import time - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -async def debug_scrape(): - print("Triggering Celery Scrape Task...") - query = "gold rate in india today" - - try: - # Dispatch task - result = scrape_task.delay(query) - print(f"Task ID: {result.id}") - - # Wait for result - start_time = time.time() - max_wait = 60 # seconds - - while time.time() - start_time < max_wait: - if result.ready(): - print("Task Ready!") - print("Result Status:", result.status) - # Safely handle potential encoding issues when printing to console - try: - res_content = str(result.result) - print("Result Content (partial):", res_content[:200].encode('ascii', 'ignore').decode('ascii')) - except Exception as e: - print(f"Result received, but print failed: {e}") - return - - print(f"Waiting... (status: {result.status})") - await asyncio.sleep(2) - - print("Task timed out. Is the worker running?") - - except Exception as e: - print(f"Dispatch failed: {e}") - -if __name__ == "__main__": - asyncio.run(debug_scrape()) diff --git a/debug_env.py b/debug_env.py deleted file mode 100644 index a719c95be550f75922deefdfb45ce25ecb9056a2..0000000000000000000000000000000000000000 --- a/debug_env.py +++ /dev/null @@ -1,23 +0,0 @@ - -import sys -import os - -print(f"Python Executable: {sys.executable}") -print(f"Python Version: {sys.version}") -print(f"Sys Path: {sys.path}") - -try: - import langchain - print(f"LangChain Version: {langchain.__version__}") - print(f"LangChain Path: {langchain.__file__}") -except ImportError as e: - print(f"ImportError: {e}") -except Exception as e: - print(f"Error: {e}") - -# Verify Agent Import -try: - from app.agents.langchain_mathminds import MathMindsLangChainAgent - print("✅ MathMindsLangChainAgent imported.") -except ImportError as e: - print(f"Agent Import Failed: {e}") diff --git a/debug_history.py b/debug_history.py deleted file mode 100644 index e6ffcbb2132d25bc230d52500468bf9b36df3f78..0000000000000000000000000000000000000000 --- a/debug_history.py +++ /dev/null @@ -1,15 +0,0 @@ - -import json -try: - with open('chat_history.json') as f: - data = json.load(f) - last_session_id = sorted(data.keys(), key=lambda k: data[k].get('created_at', 0))[-1] - print(f"Session: {last_session_id}") - messages = data[last_session_id]['messages'] - print(f"Total messages: {len(messages)}") - for i, m in enumerate(messages): - print(f"Index {i}: Role: {m['role']}") - print(f" Sent to API: {m.get('sent_to_api')}") - print(f" Content: {repr(m['content'])[:100]}...") -except Exception as e: - print(f"Error: {e}") diff --git a/debug_history_all.py b/debug_history_all.py deleted file mode 100644 index c2203e1d3abccd2009d8c3190b90106b2d034a19..0000000000000000000000000000000000000000 --- a/debug_history_all.py +++ /dev/null @@ -1,13 +0,0 @@ - -import json -try: - with open('chat_history.json') as f: - data = json.load(f) - print(f"Total sessions: {len(data)}") - for sid, sess in data.items(): - print(f"Session {sid}: {sess.get('title', 'Untitled')} ({len(sess['messages'])} msgs)") - for m in sess['messages']: - if m['role'] == 'assistant': - print(f" [FOUND ASSISTANT MSG in {sid}] Content: {repr(m['content'])[:50]}") -except Exception as e: - print(f"Error: {e}") diff --git a/debug_import.py b/debug_import.py deleted file mode 100644 index afef1785f429c9fc16e1a0746f7a404915646d65..0000000000000000000000000000000000000000 --- a/debug_import.py +++ /dev/null @@ -1,21 +0,0 @@ - -print("Start") -try: - import app.tools.web_scraper - print("Imported WebScraper") -except Exception as e: - print(f"Failed WebScraper: {e}") - -try: - import app.tools.vision_analyzer - print("Imported VisionAnalyzer") -except Exception as e: - print(f"Failed VisionAnalyzer: {e}") - -try: - from app.core.orchestrator import Orchestrator - print("Imported Orchestrator") - o = Orchestrator() - print("Instantiated Orchestrator") -except Exception as e: - print(f"Failed Orchestrator: {e}") diff --git a/debug_models.py b/debug_models.py deleted file mode 100644 index e6d0f624cfeb5f4523103148b0256424d04bb6f8..0000000000000000000000000000000000000000 --- a/debug_models.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import asyncio -from google import genai -from dotenv import load_dotenv - -load_dotenv() - -async def list_models(): - api_key = os.getenv("GOOGLE_API_KEY") - if not api_key: - print("Error: GOOGLE_API_KEY not found.") - return - - client = genai.Client(api_key=api_key) - - print("Listing available models...") - try: - # Pager object, need to iterate - pager = client.models.list() - for model in pager: - print(f"Name: {model.name}") - print(f" DisplayName: {model.display_name}") - print(f" Supported Actions: {model.supported_actions}") - print("-" * 20) - - except Exception as e: - print(f"Error listing models: {e}") - -if __name__ == "__main__": - asyncio.run(list_models()) diff --git a/debug_response.py b/debug_response.py deleted file mode 100644 index c162e831fa986a29c2324190176216c06bdb0005..0000000000000000000000000000000000000000 --- a/debug_response.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -Run this script while your backend is running. -It bypasses the frontend completely and shows you EXACTLY what the API returns. - -Usage: - cd E:\madhuri\mathminds - python debug_response.py - -Replace TOKEN and QUESTION below. -""" - -import requests -import json - -# ── CONFIG ──────────────────────────────────────────────────────────────────── -API_URL = "http://localhost:8000/solve" -TOKEN = "PASTE_YOUR_FIREBASE_TOKEN_HERE" # grab from browser devtools -QUESTION = "what is 9 + 8" -# ───────────────────────────────────────────────────────────────────────────── - -headers = {"Authorization": f"Bearer {TOKEN}"} -payload = { - "text": QUESTION, - "model_preference": "agent", - "session_id": "debug-session-001", - "request_id": "debug-req-001", -} - -print(f"\n{'='*60}") -print(f"POST {API_URL}") -print(f"Question: {QUESTION}") -print(f"{'='*60}\n") - -try: - r = requests.post(API_URL, json=payload, headers=headers, timeout=120) - print(f"HTTP Status: {r.status_code}") - print(f"\nFull Response JSON:") - data = r.json() - print(json.dumps(data, indent=2)) - - print(f"\n{'='*60}") - print(f"status : {data.get('status')}") - print(f"answer : {repr(data.get('answer'))}") - print(f"source : {data.get('source')}") - print(f"explain : {repr(data.get('explanation'))}") - print(f"error : {data.get('error')}") - print(f"{'='*60}\n") - -except Exception as e: - print(f"Request failed: {e}") diff --git a/debug_scraper.py b/debug_scraper.py deleted file mode 100644 index 2ecf16837ed13592bd1a0b12e9f5658789613dd4..0000000000000000000000000000000000000000 --- a/debug_scraper.py +++ /dev/null @@ -1,21 +0,0 @@ -from app.tools.web_scraper import run_playwright_sync - -query = "calculate the price of 2 kg gold according to todays gold rate" -print(f"Running scraper for: {query}") - -result = run_playwright_sync(query, headless=True) - -print("\n--- STATUS ---") -print(result.get("status")) - -print("\n--- CONTENT SNIPPET ---") -content = result.get("content", "") -# Print first 5000 chars to be sure we see the body -print(content[:5000]) - -if "unusual traffic" in content.lower() or "captcha" in content.lower(): - print("\n[!] DETECTED CAPTCHA/BLOCKING") -elif "Gold Rate" in content or "Silver Rate" in content: - print("\n[+] SUCCESS: Found Gold Rate related content") -else: - print("\n[?] Content unclear. Check snippet above.") diff --git a/debug_scraper_manual.py b/debug_scraper_manual.py deleted file mode 100644 index 72ed9cd47fc79892376f6f581c8bb8e433cc367e..0000000000000000000000000000000000000000 --- a/debug_scraper_manual.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -import sys -import os - -# Add project root to path -sys.path.append(os.getcwd()) - -from app.tools.web_scraper import WebScraper - -async def main(): - print("Initializing WebScraper...") - scraper = WebScraper(headless=True) - - print("\n--- Test 1: Generic Search (Yahoo Finance via Logic) ---") - # Logic in scraper: if "stock" in query -> yahoo finance - query1 = "stock price of apple" - print(f"Query: {query1}") - result1 = await scraper.scrape(query1) - print(f"Status: {result1.get('status')}") - if result1.get('error'): - print(f"Error: {result1.get('error')}") - else: - content = result1.get('content', '') - print(f"Content Length: {len(content)}") - print(f"Preview: {content[:200]}...") - - print("\n--- Test 2: Gold Rate (Goodreturns via Logic) ---") - # Logic in scraper: if "gold" and "rate" -> goodreturns - query2 = "gold rate today" - print(f"Query: {query2}") - result2 = await scraper.scrape(query2) - print(f"Status: {result2.get('status')}") - if result2.get('error'): - print(f"Error: {result2.get('error')}") - else: - content = result2.get('content', '') - print(f"Content Length: {len(content)}") - print(f"Preview: {content[:200]}...") - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/debug_ui_v2.py b/debug_ui_v2.py deleted file mode 100644 index 5516ce36d2e79a2462970dd50fd7d7b2d9f3578e..0000000000000000000000000000000000000000 --- a/debug_ui_v2.py +++ /dev/null @@ -1,30 +0,0 @@ - -import json -import os - -HISTORY_FILE = 'chat_history.json' - -def debug(): - if not os.path.exists(HISTORY_FILE): - print("File not found") - return - - try: - with open(HISTORY_FILE, "r", encoding="utf-8") as f: - data = json.load(f) - - last_sid = sorted(data.keys(), key=lambda k: data[k].get('created_at', 0))[-1] - sess = data[last_sid] - print(f"Session: {last_sid} (Title: {sess.get('title')})") - print(f"Total messages: {len(sess['messages'])}") - - for i, m in enumerate(sess['messages']): - print(f"[{i}] {m['role'].upper()}: {repr(m['content'])[:80]}...") - if 'metadata' in m: - print(f" Metadata keys: {list(m['metadata'].keys())}") - - except Exception as e: - print(f"CRITICAL ERROR reading history: {e}") - -if __name__ == "__main__": - debug() diff --git a/find_embedding_models.py b/find_embedding_models.py deleted file mode 100644 index a31ab63d0898b683d0f964a882c381f2dd38b2fc..0000000000000000000000000000000000000000 --- a/find_embedding_models.py +++ /dev/null @@ -1,11 +0,0 @@ -import os -from dotenv import load_dotenv -import google.generativeai as genai - -load_dotenv() -genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) - -print("Available models:") -for model in genai.list_models(): - if "gemini" in model.name: - print(f" {model.name:50} {model.display_name}") \ No newline at end of file diff --git a/frontend/app.py b/frontend/app.py index 9b002288b9d858d84dadddf5ee39e53d74c09723..aa0a391d5e80387b8809bb9a6a4d757679dd2e80 100644 --- a/frontend/app.py +++ b/frontend/app.py @@ -1,61 +1,70 @@ +""" +frontend/app.py — MathMinds AI Streamlit frontend. + +STRUCTURE +───────── + 1. CONFIG & CONSTANTS — env vars, page config, CSS + 2. SESSION STATE — defaults, init, clear + 3. API LAYER — all HTTP calls, no st.* calls, returns plain dicts + 4. RENDER HELPERS — pure display functions, never mutate state + 5. SESSION MANAGEMENT — state mutations, never call st.rerun() + 6. CHAT INTERFACE — the 3-state machine (idle / processing / done) + 7. MAIN ENTRY — auth gate + router + +STATE MACHINE (why the previous version had bugs) +────────────────────────────────────────────────── +Every bug in the old version came from collapsing 3 distinct states into one +render pass: add user message + call API + render answer all happened together. + +The fix is a strict 3-state machine: + + IDLE → render history + show input + user submits → _add_message("user") → is_processing=True → rerun() + + PROCESSING → render history (user question VISIBLE above spinner) + make API call → _add_message("assistant") → is_processing=False → rerun() + + IDLE again → render history (assistant answer now visible) + +Rules that make this impossible to break: + - st.rerun() is ALWAYS the last statement — never inside a with-block + - The API call NEVER happens inside with st.chat_message() + - Render helpers never write session_state + - load_messages() only called on session switch, never after an answer +""" + import streamlit as st import requests -import json import base64 -from PIL import Image +import logging import io import os -import uuid import time +import uuid +from PIL import Image from streamlit_drawable_canvas import st_canvas from dotenv import load_dotenv from firebase_utils import sign_in_with_email, sign_up_with_email +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) load_dotenv() -# ── Session state: ALL keys initialized ONCE at the very top ───────────────── -# CRITICAL: These must be the very first st.session_state accesses, before any -# st.* UI calls. Streamlit re-runs the entire script on every interaction. -if "is_processing" not in st.session_state: - st.session_state.is_processing = False -if "user" not in st.session_state: - st.session_state.user = None # None = logged out -if "current_view" not in st.session_state: - st.session_state.current_view = "Chat" - -# MULTIUSER FIX ─ these three keys must be RESET on logout. -# They are initialized here so first-run doesn't KeyError. -if "chat_sessions" not in st.session_state: - st.session_state.chat_sessions = [] -if "active_session_id" not in st.session_state: - st.session_state.active_session_id = None -if "messages" not in st.session_state: - st.session_state.messages = [] - -# MULTIUSER FIX ─ track WHICH user's data is currently loaded. -# If this doesn't match st.session_state.user["uid"], we know we need to reload. -if "loaded_for_user" not in st.session_state: - st.session_state.loaded_for_user = None - -if "renaming_session_id" not in st.session_state: - st.session_state.renaming_session_id = None - -if "canvas_key" not in st.session_state: - st.session_state.canvas_key = "main_canvas" - -# ==================================================== -# Page Config — must come before any st.* calls -# ==================================================== + +# ══════════════════════════════════════════════════════════════════════════════ +# 1. CONFIG & CONSTANTS +# ══════════════════════════════════════════════════════════════════════════════ + +BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000") +ENABLE_AUTH = os.getenv("ENABLE_AUTH", "True").lower() == "true" + st.set_page_config( page_title="MathMinds AI", page_icon="🧠", layout="wide", - initial_sidebar_state="expanded" + initial_sidebar_state="expanded", ) -# ==================================================== -# Premium Global Styling -# ==================================================== st.markdown(""" """, unsafe_allow_html=True) -# ==================================================== -# Config -# ==================================================== -BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000") -API_URL = f"{BACKEND_URL}/solve" +# ══════════════════════════════════════════════════════════════════════════════ +# 2. SESSION STATE +# ══════════════════════════════════════════════════════════════════════════════ + +_DEFAULTS = { + "user": None, + "current_view": "Chat", + "chat_sessions": [], + "active_session_id": None, + "messages": [], + "loaded_for_user": None, + "loaded_for_session": None, + "is_processing": False, + "renaming_session_id": None, + "canvas_key": "main_canvas", +} + +for _k, _v in _DEFAULTS.items(): + if _k not in st.session_state: + st.session_state[_k] = _v + +# Dev mode bypass +if not ENABLE_AUTH and st.session_state.user is None: + st.session_state.user = { + "email": "dev@mathminds.ai", + "token": "mock_dev_token", + "uid": "dev_user_123", + } -# ==================================================== -# MULTIUSER ISOLATION — Core helper -# ==================================================== -def _clear_user_state(): - """ - Wipe ALL per-user data from Streamlit session state. - Called on logout and whenever a different user logs in. +def _reset_state(keep_user=False): + """Clear all state. Optionally preserve the user object.""" + saved_user = st.session_state.user if keep_user else None + for k, v in _DEFAULTS.items(): + st.session_state[k] = v + st.session_state.canvas_key = f"canvas_{uuid.uuid4()}" + if keep_user: + st.session_state.user = saved_user - WHY THIS IS THE MOST IMPORTANT FUNCTION FOR MULTIUSER ISOLATION: - Streamlit's st.session_state is per browser-tab, not per user. If User A - logs in, chats, then User B logs in on the same tab, all of User A's - chat_sessions and messages are still sitting in st.session_state. The - backend correctly refuses to serve User A's data to User B (every DB query - filters by user_id), but the frontend would still DISPLAY User A's messages - briefly until the next API call returns. This function prevents that. - """ - st.session_state.chat_sessions = [] - st.session_state.active_session_id = None - st.session_state.messages = [] - st.session_state.loaded_for_user = None - st.session_state.is_processing = False - st.session_state.current_view = "Chat" - st.session_state.renaming_session_id = None - st.session_state.canvas_key = f"canvas_{uuid.uuid4()}" - # Also clear profile cache if it exists - if "profile_data" in st.session_state: - del st.session_state["profile_data"] - - -# ==================================================== -# Helper Functions -# ==================================================== -def get_auth_headers(): - if st.session_state.user and "token" in st.session_state.user: - return {"Authorization": f"Bearer {st.session_state.user['token']}"} - return {} - - -def load_sessions(): - """Fetch THIS user's chat sessions from the backend and populate state.""" + +def _get_headers() -> dict: + u = st.session_state.user + return {"Authorization": f"Bearer {u['token']}"} if u and "token" in u else {} + + +def _add_message(role: str, content, **kwargs): + msg = {"role": role, "content": content, "timestamp": time.time()} + msg.update(kwargs) + st.session_state.messages.append(msg) + + +def _is_blank_canvas(image_data) -> bool: + if image_data is None: + return True + import numpy as np + return image_data[:, :, 3].max() == 0 + + +# ══════════════════════════════════════════════════════════════════════════════ +# 3. API LAYER +# ───────────────────────────────────────────────────────────────────────────── +# Rules: +# - NO st.* calls anywhere in this section +# - Returns plain dict, always has "ok" key +# - Never raises — exceptions are caught and returned as {"ok": False} +# ══════════════════════════════════════════════════════════════════════════════ + +@st.cache_data(ttl=30) +def api_health() -> dict: + """Cached — only hits the backend once every 30 seconds, not on every rerender.""" try: - headers = get_auth_headers() - try: - response = requests.get(f"{BACKEND_URL}/chat/sessions", headers=headers, timeout=10) - except requests.exceptions.ConnectionError: - st.info("⌛ **MathMinds API is warming up...** Please wait a few seconds.") - st.stop() - - if response.status_code == 200: - st.session_state.chat_sessions = response.json() - # Mark that we've successfully loaded data for this specific user - if st.session_state.user: - st.session_state.loaded_for_user = st.session_state.user["uid"] - # Auto-select first session if none active - if not st.session_state.active_session_id and st.session_state.chat_sessions: - st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"] - load_messages(st.session_state.active_session_id) - elif st.session_state.active_session_id and not any( - s["session_id"] == st.session_state.active_session_id - for s in st.session_state.chat_sessions - ): - # Active session was deleted — pick first or clear - if st.session_state.chat_sessions: - st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"] - load_messages(st.session_state.active_session_id) - else: - st.session_state.active_session_id = None - st.session_state.messages = [] - elif response.status_code == 401: - # JWT expired — force re-login - _clear_user_state() - st.session_state.user = None - st.error("Session expired. Please log in again.") - else: - st.error(f"Failed to load sessions: {response.status_code}") - st.session_state.chat_sessions = [] + r = requests.get(f"{BACKEND_URL}/health", timeout=5) + return {"ok": r.status_code == 200, **r.json()} if r.status_code == 200 else {"ok": False} + except Exception: + return {"ok": False} + + +def api_solve(text: str, image, session_id: str, request_id: str) -> dict: + try: + r = requests.post( + f"{BACKEND_URL}/solve", + json={"text": text, "image": image, "session_id": session_id, "request_id": request_id}, + headers=_get_headers(), + timeout=360, + ) + if r.status_code == 200: + return {"ok": True, **r.json()} + if r.status_code == 401: + return {"ok": False, "error": "AUTH_EXPIRED"} + if r.status_code == 429: + return {"ok": False, "error": "Daily limit reached. Please try again tomorrow."} + return {"ok": False, "error": f"Backend error {r.status_code}"} + except requests.exceptions.ConnectionError: + return {"ok": False, "error": "Cannot reach backend. Is the server running?"} + except requests.exceptions.Timeout: + return {"ok": False, "error": "Request timed out."} except Exception as e: - st.error(f"Error loading sessions: {e}") - st.session_state.chat_sessions = [] + logger.error(f"api_solve: {e}") + return {"ok": False, "error": "Unexpected error. Please try again."} -def load_messages(session_id): - """ - Load messages for a session. - The backend enforces user ownership — it will 404 if session_id - doesn't belong to the authenticated user, so this is safe. - """ +def api_load_messages(session_id: str) -> dict: try: - headers = get_auth_headers() - response = requests.get( + r = requests.get( f"{BACKEND_URL}/chat/sessions/{session_id}/messages", - headers=headers, timeout=30 + headers=_get_headers(), timeout=30, ) - if response.status_code == 200: - server_messages = response.json() - local_messages = st.session_state.get("messages", []) - - # ✅ INDESTRUCTIBLE MERGE LOGIC - # 1. Start with server messages as the definitive baseline. - merged = [] - server_keys = set() - for m in server_messages: - merged.append(m) - rid = m.get("request_id") - role = m.get("role") - if rid and role: - server_keys.add((role, rid)) - - # 2. Append local messages that have NOT yet reached the server. - # This protects local "optimistic" messages from vanishing if DB is slow. - for lm in local_messages: - rid = lm.get("request_id") - role = lm.get("role") - if rid and role: - if (role, rid) not in server_keys: - merged.append(lm) - elif not rid: - # Fallback for messages without IDs (should be rare) - content_prefix = str(lm.get("content", ""))[:50] - if not any(str(sm.get("content", "")).startswith(content_prefix) for sm in server_messages): - merged.append(lm) - - st.session_state.messages = merged - elif response.status_code == 404: - # Session doesn't belong to this user — clear silently - st.session_state.messages = [] - st.session_state.active_session_id = None - st.warning("Session not found.") - else: - st.session_state.messages = [] - st.error(f"Failed to load messages: {response.status_code}") + if r.status_code == 200: return {"ok": True, "messages": r.json()} + if r.status_code == 404: return {"ok": False, "error": "SESSION_NOT_FOUND"} + if r.status_code == 401: return {"ok": False, "error": "AUTH_EXPIRED"} + return {"ok": False, "error": f"Status {r.status_code}"} except Exception as e: - logger.error(f"Error loading messages: {e}") - st.error(f"Error loading messages: {e}") - st.session_state.messages = [] - - -def get_active_session(): - for s in st.session_state.chat_sessions: - if s["session_id"] == st.session_state.active_session_id: - return s - return None - - -def add_message(role, content, sent_to_api=False, request_id=None, **kwargs): - """Optimistic UI update only — persistence happens in the backend via /solve.""" - msg = { - "role": role, - "content": content, - "timestamp": time.time(), - "sent_to_api": sent_to_api, - "request_id": request_id - } - msg.update(kwargs) - st.session_state.messages.append(msg) + return {"ok": False, "error": str(e)} -def new_chat(): +def api_load_sessions() -> dict: try: - headers = get_auth_headers() - response = requests.post(f"{BACKEND_URL}/chat/sessions", headers=headers, timeout=30) - if response.status_code == 200: - new_s = response.json() - st.session_state.active_session_id = new_s["session_id"] - st.session_state.messages = [] - load_sessions() - st.rerun() - else: - st.error("Failed to create new chat") + r = requests.get(f"{BACKEND_URL}/chat/sessions", headers=_get_headers(), timeout=10) + if r.status_code == 200: return {"ok": True, "sessions": r.json()} + if r.status_code == 401: return {"ok": False, "error": "AUTH_EXPIRED"} + return {"ok": False, "error": f"Status {r.status_code}"} + except requests.exceptions.ConnectionError: + return {"ok": False, "error": "BACKEND_OFFLINE"} except Exception as e: - st.error(f"Error: {e}") + return {"ok": False, "error": str(e)} -def delete_chat(sid): +def api_new_session() -> dict: try: - headers = get_auth_headers() - response = requests.delete(f"{BACKEND_URL}/chat/sessions/{sid}", headers=headers, timeout=30) - if response.status_code == 200: - if st.session_state.active_session_id == sid: - st.session_state.active_session_id = None - st.session_state.messages = [] - load_sessions() - st.rerun() - else: - st.error("Failed to delete chat") + r = requests.post(f"{BACKEND_URL}/chat/sessions", headers=_get_headers(), timeout=30) + return {"ok": True, "session": r.json()} if r.status_code == 200 else {"ok": False, "error": f"Status {r.status_code}"} + except Exception as e: + return {"ok": False, "error": str(e)} + + +def api_delete_session(sid: str) -> dict: + try: + r = requests.delete(f"{BACKEND_URL}/chat/sessions/{sid}", headers=_get_headers(), timeout=30) + return {"ok": r.status_code == 200} except Exception as e: - st.error(f"Error: {e}") + return {"ok": False, "error": str(e)} -def rename_chat(sid, new_title): +def api_rename_session(sid: str, title: str) -> dict: try: - headers = get_auth_headers() - response = requests.patch( + r = requests.patch( f"{BACKEND_URL}/chat/sessions/{sid}", - headers=headers, json={"title": new_title}, timeout=30 + json={"title": title}, headers=_get_headers(), timeout=30, ) - if response.status_code == 200: - load_sessions() - st.rerun() - else: - st.error("Failed to rename chat") + return {"ok": r.status_code == 200} except Exception as e: - st.error(f"Error: {e}") + return {"ok": False, "error": str(e)} + + +# ══════════════════════════════════════════════════════════════════════════════ +# 4. RENDER HELPERS +# ───────────────────────────────────────────────────────────────────────────── +# Rules: +# - Only READ session_state, never write +# - Only call st.* display functions +# - Never call st.rerun() or make API calls +# ══════════════════════════════════════════════════════════════════════════════ + +def _badge(source: str, status: str) -> str: + b = "" + if source == "sympy_preflight": + b += '⚡ INSTANT' + elif source in ("cache", "semantic_cache"): + b += '💾 CACHED' + elif source in ("google_adk_agent", "agent"): + b += '🤖 AI' + if status == "error": + b += '🔴 ERROR' + return b + + +def _render_message(msg: dict): + role = msg.get("role", "assistant") + with st.chat_message(role, avatar="👤" if role == "user" else "🤖"): + if role == "user": + if msg.get("image_data"): + try: + st.image(base64.b64decode(msg["image_data"]), width=300) + except Exception: + pass + st.write(msg.get("content", "")) + else: + meta = msg.get("metadata") or {} + status = meta.get("status") or msg.get("status", "success") + badges = _badge(meta.get("source", ""), status) + if badges: + st.markdown(badges, unsafe_allow_html=True) + + logic = meta.get("logic_trace") or msg.get("reasoning") + if logic: + steps = [s for s in (logic if isinstance(logic, list) else logic.split("\n")) if s] + if steps: + with st.expander("💭 Reasoning", expanded=False): + for step in steps: + st.caption(step) + + content = msg.get("content", "") + if status == "error": + st.error(content) + elif isinstance(content, dict) and "final_answer" in content: + st.markdown(content["final_answer"]) + else: + st.markdown(str(content)) -# ==================================================== -# Login Screen -# ==================================================== -def login_screen(): - c1, c2, c3 = st.columns([1, 2, 1]) - with c2: - st.write("") - st.write("") +def _render_login(): + _, c, _ = st.columns([1, 2, 1]) + with c: st.markdown(""" -
Your intelligent quantitative assistant.
+Your intelligent math assistant.
Powered by Gemini & SymPy
", - unsafe_allow_html=True + unsafe_allow_html=True, ) -# ── Auth gate ───────────────────────────────────────────────────────────────── -if not st.session_state.user: - login_screen() - st.stop() - -# ==================================================== -# ✅ MULTIUSER FIX — Per-rerun data isolation check -# ==================================================== -# At this point we know a user IS logged in. -# Check: is the data currently in state actually for THIS user? -# This handles the scenario where User A's browser tab is reused by User B -# (e.g. token swap, shared kiosk, etc.) without a full page reload. -_current_uid = st.session_state.user["uid"] -if st.session_state.loaded_for_user != _current_uid: - # Data in state belongs to a different user (or nobody) — reload for current user - _clear_user_state() - load_sessions() - # loaded_for_user is set inside load_sessions() on success - - -# ==================================================== -# Profile Interface -# ==================================================== -def profile_interface(): +def _render_profile(): st.title("👤 User Profile") - st.markdown("Customize your MathMinds experience.") - headers = get_auth_headers() - if "profile_data" not in st.session_state: try: - r = requests.get(f"{BACKEND_URL}/users/profile", headers=headers, timeout=30) + r = requests.get(f"{BACKEND_URL}/users/profile", headers=_get_headers(), timeout=30) st.session_state.profile_data = r.json() if r.status_code == 200 else {} except Exception: st.session_state.profile_data = {} - data = st.session_state.profile_data - levels = ["High School", "Undergraduate", "Graduate", "Researcher"] + data = st.session_state.profile_data + levels = ["High School", "Undergraduate", "Graduate", "Researcher"] interests_all = ["Algebra", "Calculus", "Geometry", "Statistics", "Physics", "Computer Science", "Finance"] with st.form("profile_form"): display_name = st.text_input("Display Name", value=data.get("display_name", "")) math_level = st.selectbox( - "Math Proficiency Level", levels, - index=levels.index(data.get("math_level", "Undergraduate")) - if data.get("math_level") in levels else 1 + "Math Level", levels, + index=levels.index(data["math_level"]) if data.get("math_level") in levels else 1, ) interests = st.multiselect( - "Areas of Interest", interests_all, - default=[i for i in data.get("interests", []) if i in interests_all] + "Interests", interests_all, + default=[i for i in data.get("interests", []) if i in interests_all], ) - if st.form_submit_button("Save Profile", use_container_width=True, type="primary"): + if st.form_submit_button("Save", use_container_width=True, type="primary"): payload = {"display_name": display_name, "math_level": math_level, "interests": interests} try: - r = requests.post(f"{BACKEND_URL}/users/profile", json=payload, headers=headers) + r = requests.post(f"{BACKEND_URL}/users/profile", json=payload, headers=_get_headers()) if r.status_code == 200: - st.success("Profile updated!") + st.success("Saved!") st.session_state.profile_data = payload - time.sleep(1) - st.rerun() else: - st.error(f"Update failed: {r.text}") + st.error(f"Failed: {r.text}") except Exception as e: - st.error(f"Error saving: {e}") + st.error(str(e)) + + +# ══════════════════════════════════════════════════════════════════════════════ +# 5. SESSION MANAGEMENT +# ───────────────────────────────────────────────────────────────────────────── +# These functions mutate session_state but never call st.rerun(). +# Callers decide when to rerun. +# ══════════════════════════════════════════════════════════════════════════════ + +def _refresh_sessions(): + result = api_load_sessions() + if result["ok"]: + st.session_state.chat_sessions = result["sessions"] + st.session_state.loaded_for_user = st.session_state.user["uid"] + elif result.get("error") == "AUTH_EXPIRED": + _reset_state() + elif result.get("error") == "BACKEND_OFFLINE": + st.info("⌛ Backend warming up...") + st.stop() + + +def _switch_session(sid: str): + """Load messages for a session. Only fetches if session actually changed.""" + if st.session_state.loaded_for_session == sid: + st.session_state.active_session_id = sid + return + + result = api_load_messages(sid) + if result["ok"]: + st.session_state.active_session_id = sid + st.session_state.messages = result["messages"] + st.session_state.loaded_for_session = sid + elif result.get("error") == "SESSION_NOT_FOUND": + st.warning("Session not found.") + st.session_state.active_session_id = None + st.session_state.messages = [] + elif result.get("error") == "AUTH_EXPIRED": + _reset_state() + + +def _ensure_session() -> bool: + """ + Make sure there's an active session. Returns True if ready, False if rerun needed. + """ + if st.session_state.active_session_id: + return True + + sessions = st.session_state.chat_sessions + if sessions: + _switch_session(sessions[0]["session_id"]) + return True + + # Create first session + result = api_new_session() + if result["ok"]: + sess = result["session"] + st.session_state.active_session_id = sess["session_id"] + st.session_state.messages = [] + st.session_state.loaded_for_session = sess["session_id"] + _refresh_sessions() + return False # caller must rerun + return True + + +def _render_sidebar(): + with st.sidebar: + st.markdown("### 🧠 MathMinds") + st.caption(f"👤 {st.session_state.user['email']}") + + view = st.radio("Nav", ["Chat", "Profile"], + index=0 if st.session_state.current_view == "Chat" else 1, + label_visibility="collapsed") + if view != st.session_state.current_view: + st.session_state.current_view = view + st.rerun() + + # Health + st.divider() + health = api_health() + if health.get("ok"): + st.markdown('