Spaces:
Running
Running
Commit ·
c14a92e
1
Parent(s): c00b41f
Offloaded persistence tasks to the background
Browse files- app/agents/adk_mathminds.py +69 -107
- app/api/main.py +152 -54
- app/core/orchestrator.py +117 -177
- app/core/schemas.py +37 -0
- app/core/security.py +17 -27
- app/core/settings.py +5 -0
- app/memory/database.py +128 -48
- frontend/app.py +423 -206
- requirements.txt +3 -0
app/agents/adk_mathminds.py
CHANGED
|
@@ -1,18 +1,8 @@
|
|
| 1 |
-
"""
|
| 2 |
-
adk_mathminds.py — MathMinds ADK Agent
|
| 3 |
-
Key changes vs original:
|
| 4 |
-
1. Semaphore removed. Replaced with Redis-backed daily quota via llm_guard.
|
| 5 |
-
2. Tenacity retries scoped to 429/rate-limit errors ONLY (not all exceptions),
|
| 6 |
-
so a quota block is not retried.
|
| 7 |
-
3. ADK event loop now filters for is_final_response() to avoid
|
| 8 |
-
collecting tool-call intermediate text.
|
| 9 |
-
4. Redis client injected via constructor so it can be shared with CacheManager.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
import logging
|
| 13 |
import asyncio
|
| 14 |
import base64
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
from google.adk.agents import Agent
|
| 18 |
from google.adk.runners import Runner
|
|
@@ -22,12 +12,11 @@ from google.genai.errors import ClientError
|
|
| 22 |
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
| 23 |
|
| 24 |
from app.core.settings import settings
|
| 25 |
-
from app.core.llm_guard import check_and_increment
|
| 26 |
from app.tools.web_scraper import WebScraper
|
| 27 |
from app.tools.symbolic_solver import SymbolicSolver
|
| 28 |
from app.tools.similarity_search import SimilarProblemFinder
|
| 29 |
-
from app.
|
| 30 |
-
from app.tools.vision_analyzer import VisionAnalyzer
|
| 31 |
from app.core.math_normalizer import MathQueryNormalizer
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
|
@@ -36,12 +25,12 @@ logger = logging.getLogger(__name__)
|
|
| 36 |
class MathMindsADKAgent:
|
| 37 |
"""
|
| 38 |
Agent-based architecture using Google ADK.
|
| 39 |
-
|
| 40 |
"""
|
| 41 |
|
| 42 |
-
def __init__(self, model_name: str = "gemini-2.
|
| 43 |
self.api_key = settings.GOOGLE_API_KEY
|
| 44 |
-
self.redis_client = redis_client
|
| 45 |
|
| 46 |
if not self.api_key:
|
| 47 |
logger.warning("No Google API Key found. Agent will fail.")
|
|
@@ -51,8 +40,7 @@ class MathMindsADKAgent:
|
|
| 51 |
self.symbolic_solver = SymbolicSolver()
|
| 52 |
self.normalizer = MathQueryNormalizer()
|
| 53 |
self.similar_finder = SimilarProblemFinder()
|
| 54 |
-
self.
|
| 55 |
-
self.vision_analyzer = VisionAnalyzer()
|
| 56 |
|
| 57 |
# ── Tool definitions ──────────────────────────────────────────────────
|
| 58 |
async def web_search(query: str) -> str:
|
|
@@ -79,6 +67,18 @@ class MathMindsADKAgent:
|
|
| 79 |
return result.get("content", "No solution found.")
|
| 80 |
return f"Error solving math: {result.get('error')}"
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def find_similar_problems(query: str) -> str:
|
| 83 |
"""
|
| 84 |
Find similar solved problems from the database for reference.
|
|
@@ -93,41 +93,18 @@ class MathMindsADKAgent:
|
|
| 93 |
formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
|
| 94 |
return formatted
|
| 95 |
|
| 96 |
-
def read_image(image_data: str) -> str:
|
| 97 |
-
"""
|
| 98 |
-
Extract text/equations from an image using OCR.
|
| 99 |
-
Args:
|
| 100 |
-
image_data: Base64 string of the image.
|
| 101 |
-
"""
|
| 102 |
-
try:
|
| 103 |
-
text = self.ocr.extract_text(image_data=image_data)
|
| 104 |
-
return text if text else "No text found in image."
|
| 105 |
-
except Exception as e:
|
| 106 |
-
return f"Error reading image: {str(e)}"
|
| 107 |
-
|
| 108 |
-
async def analyze_image(image_data: str, focus: str = "") -> str:
|
| 109 |
-
"""
|
| 110 |
-
Analyze an image mathematically: count objects, describe graphs, extract equations.
|
| 111 |
-
Args:
|
| 112 |
-
image_data: Base64 string of the image.
|
| 113 |
-
focus: Optional focus hint (e.g. 'count red balls').
|
| 114 |
-
"""
|
| 115 |
-
try:
|
| 116 |
-
result = self.vision_analyzer.analyze(image_data, focus)
|
| 117 |
-
return str(result)
|
| 118 |
-
except Exception as e:
|
| 119 |
-
return f"Image analysis failed: {str(e)}"
|
| 120 |
-
|
| 121 |
# ── Agent & Runner ─────────────────────────────────────────────────��──
|
| 122 |
self.agent = Agent(
|
| 123 |
name="math_minds_core",
|
| 124 |
model=model_name,
|
| 125 |
-
tools=[web_search, math_solver,
|
| 126 |
instruction=(
|
| 127 |
"You are MathMinds AI, a precise mathematical assistant. "
|
| 128 |
-
"When an image is provided,
|
| 129 |
-
"count objects, or interpret graphs.
|
| 130 |
-
"
|
|
|
|
|
|
|
| 131 |
)
|
| 132 |
)
|
| 133 |
|
|
@@ -138,7 +115,7 @@ class MathMindsADKAgent:
|
|
| 138 |
session_service=self.session_service
|
| 139 |
)
|
| 140 |
|
| 141 |
-
logger.info("MathMindsADKAgent initialized
|
| 142 |
|
| 143 |
async def solve(
|
| 144 |
self,
|
|
@@ -146,22 +123,17 @@ class MathMindsADKAgent:
|
|
| 146 |
image_data: Optional[str] = None,
|
| 147 |
session_id: str = "default_session",
|
| 148 |
user_id: str = "default_user"
|
| 149 |
-
) -> str:
|
| 150 |
"""
|
| 151 |
-
|
| 152 |
-
Returns the agent's answer string, or an error message.
|
| 153 |
"""
|
| 154 |
|
| 155 |
# ── 1. Daily quota check ──────────────────────────────────────────────
|
| 156 |
-
# This is the ONLY gate. One check per user question = one LLM call.
|
| 157 |
if self.redis_client:
|
| 158 |
allowed, used, limit = check_and_increment(self.redis_client, user_id)
|
| 159 |
if not allowed:
|
| 160 |
-
|
| 161 |
-
return
|
| 162 |
-
f"⚠️ Daily limit reached ({limit} questions per day). "
|
| 163 |
-
"Please try again tomorrow."
|
| 164 |
-
)
|
| 165 |
else:
|
| 166 |
logger.warning("Redis unavailable — skipping quota check (failing open).")
|
| 167 |
|
|
@@ -174,66 +146,56 @@ class MathMindsADKAgent:
|
|
| 174 |
await self.session_service.create_session(
|
| 175 |
app_name="mathminds", user_id=user_id, session_id=session_id
|
| 176 |
)
|
| 177 |
-
except Exception:
|
| 178 |
-
|
| 179 |
-
await self.session_service.create_session(
|
| 180 |
-
app_name="mathminds", user_id=user_id, session_id=session_id
|
| 181 |
-
)
|
| 182 |
-
except Exception as e:
|
| 183 |
-
logger.warning(f"Session create warning: {e}")
|
| 184 |
|
| 185 |
# ── 3. Build message parts ────────────────────────────────────────────
|
| 186 |
-
parts = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
if image_data:
|
| 189 |
try:
|
| 190 |
-
if image_data.startswith("/9j/"):
|
| 191 |
-
mime_type = "image/jpeg"
|
| 192 |
-
elif image_data.startswith("iVBORw"):
|
| 193 |
-
mime_type = "image/png"
|
| 194 |
-
elif image_data.startswith("R0lGOD"):
|
| 195 |
-
mime_type = "image/gif"
|
| 196 |
-
elif image_data.startswith("UklGR"):
|
| 197 |
-
mime_type = "image/webp"
|
| 198 |
-
else:
|
| 199 |
-
mime_type = "image/png"
|
| 200 |
-
|
| 201 |
img_bytes = base64.b64decode(image_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
|
| 203 |
-
logger.info("Image attached to agent request.")
|
| 204 |
except Exception as e:
|
| 205 |
-
logger.error(f"
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
@retry(
|
| 210 |
-
stop=stop_after_attempt(2), # max 2 attempts total
|
| 211 |
-
wait=wait_exponential(multiplier=2, min=5, max=30),
|
| 212 |
-
retry=retry_if_exception_type(ClientError), # only retry on API errors
|
| 213 |
-
reraise=True
|
| 214 |
-
)
|
| 215 |
-
async def run_agent_safely() -> str:
|
| 216 |
-
outcome = ""
|
| 217 |
async for event in self.runner.run_async(
|
| 218 |
user_id=user_id,
|
| 219 |
session_id=session_id,
|
| 220 |
new_message=types.Content(role="user", parts=parts)
|
| 221 |
):
|
| 222 |
-
#
|
| 223 |
-
if hasattr(event, "
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
except Exception as e:
|
| 238 |
-
logger.error(f"
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
import asyncio
|
| 3 |
import base64
|
| 4 |
+
import json
|
| 5 |
+
from typing import Optional, AsyncGenerator, Dict, Any
|
| 6 |
|
| 7 |
from google.adk.agents import Agent
|
| 8 |
from google.adk.runners import Runner
|
|
|
|
| 12 |
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
| 13 |
|
| 14 |
from app.core.settings import settings
|
| 15 |
+
from app.core.llm_guard import check_and_increment
|
| 16 |
from app.tools.web_scraper import WebScraper
|
| 17 |
from app.tools.symbolic_solver import SymbolicSolver
|
| 18 |
from app.tools.similarity_search import SimilarProblemFinder
|
| 19 |
+
from app.tools.python_executor import PythonInterpreter
|
|
|
|
| 20 |
from app.core.math_normalizer import MathQueryNormalizer
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
|
|
|
| 25 |
class MathMindsADKAgent:
|
| 26 |
"""
|
| 27 |
Agent-based architecture using Google ADK.
|
| 28 |
+
Supports real-time streaming of reasoning steps and final answers.
|
| 29 |
"""
|
| 30 |
|
| 31 |
+
def __init__(self, model_name: str = "gemini-2.0-flash", redis_client=None):
|
| 32 |
self.api_key = settings.GOOGLE_API_KEY
|
| 33 |
+
self.redis_client = redis_client
|
| 34 |
|
| 35 |
if not self.api_key:
|
| 36 |
logger.warning("No Google API Key found. Agent will fail.")
|
|
|
|
| 40 |
self.symbolic_solver = SymbolicSolver()
|
| 41 |
self.normalizer = MathQueryNormalizer()
|
| 42 |
self.similar_finder = SimilarProblemFinder()
|
| 43 |
+
self.python_executor = PythonInterpreter()
|
|
|
|
| 44 |
|
| 45 |
# ── Tool definitions ──────────────────────────────────────────────────
|
| 46 |
async def web_search(query: str) -> str:
|
|
|
|
| 67 |
return result.get("content", "No solution found.")
|
| 68 |
return f"Error solving math: {result.get('error')}"
|
| 69 |
|
| 70 |
+
async def execute_python(code: str) -> str:
|
| 71 |
+
"""
|
| 72 |
+
Execute arbitrary Python code for simulations, complex logic, or data analysis.
|
| 73 |
+
Use this when SymPy is too restrictive or you need to run a simulation.
|
| 74 |
+
Args:
|
| 75 |
+
code: The Python code to execute.
|
| 76 |
+
"""
|
| 77 |
+
result = await self.python_executor.execute(code)
|
| 78 |
+
if result.get("status") == "success":
|
| 79 |
+
return f"Output:\n{result.get('content')}\nResult: {result.get('result')}"
|
| 80 |
+
return f"Error in Python execution: {result.get('content')}"
|
| 81 |
+
|
| 82 |
def find_similar_problems(query: str) -> str:
|
| 83 |
"""
|
| 84 |
Find similar solved problems from the database for reference.
|
|
|
|
| 93 |
formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
|
| 94 |
return formatted
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# ── Agent & Runner ─────────────────────────────────────────────────��──
|
| 97 |
self.agent = Agent(
|
| 98 |
name="math_minds_core",
|
| 99 |
model=model_name,
|
| 100 |
+
tools=[web_search, math_solver, execute_python, find_similar_problems],
|
| 101 |
instruction=(
|
| 102 |
"You are MathMinds AI, a precise mathematical assistant. "
|
| 103 |
+
"You can see images natively! When an image is provided, examine it "
|
| 104 |
+
"carefully to extract equations, count objects, or interpret graphs. "
|
| 105 |
+
"\n\nCRITICAL: Always start by explaining your step-by-step approach "
|
| 106 |
+
"before using any tools. Your internal monologue should be clear "
|
| 107 |
+
"and explain the reasoning behind your tool choices."
|
| 108 |
)
|
| 109 |
)
|
| 110 |
|
|
|
|
| 115 |
session_service=self.session_service
|
| 116 |
)
|
| 117 |
|
| 118 |
+
logger.info(f"MathMindsADKAgent initialized with model: {model_name}")
|
| 119 |
|
| 120 |
async def solve(
|
| 121 |
self,
|
|
|
|
| 123 |
image_data: Optional[str] = None,
|
| 124 |
session_id: str = "default_session",
|
| 125 |
user_id: str = "default_user"
|
| 126 |
+
) -> AsyncGenerator[Dict[str, Any], None]:
|
| 127 |
"""
|
| 128 |
+
Streaming entry point. Yields events as they occur.
|
|
|
|
| 129 |
"""
|
| 130 |
|
| 131 |
# ── 1. Daily quota check ──────────────────────────────────────────────
|
|
|
|
| 132 |
if self.redis_client:
|
| 133 |
allowed, used, limit = check_and_increment(self.redis_client, user_id)
|
| 134 |
if not allowed:
|
| 135 |
+
yield {"type": "error", "content": f"⚠️ Daily limit reached ({limit} today)."}
|
| 136 |
+
return
|
|
|
|
|
|
|
|
|
|
| 137 |
else:
|
| 138 |
logger.warning("Redis unavailable — skipping quota check (failing open).")
|
| 139 |
|
|
|
|
| 146 |
await self.session_service.create_session(
|
| 147 |
app_name="mathminds", user_id=user_id, session_id=session_id
|
| 148 |
)
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.warning(f"Session setup warning: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# ── 3. Build message parts ────────────────────────────────────────────
|
| 153 |
+
parts = []
|
| 154 |
+
if problem:
|
| 155 |
+
parts.append(types.Part.from_text(text=problem))
|
| 156 |
+
else:
|
| 157 |
+
parts.append(types.Part.from_text(text="Analyze this image."))
|
| 158 |
|
| 159 |
if image_data:
|
| 160 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
img_bytes = base64.b64decode(image_data)
|
| 162 |
+
mime_type = "image/png" # Default
|
| 163 |
+
# Basic sniff
|
| 164 |
+
if image_data.startswith("/9j/"): mime_type = "image/jpeg"
|
| 165 |
+
elif image_data.startswith("iVBORw"): mime_type = "image/png"
|
| 166 |
+
|
| 167 |
parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
|
|
|
|
| 168 |
except Exception as e:
|
| 169 |
+
logger.error(f"Image decode failed: {e}")
|
| 170 |
+
|
| 171 |
+
# ── 4. Run agent (Streaming) ──────────────────────────────────────────
|
| 172 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
async for event in self.runner.run_async(
|
| 174 |
user_id=user_id,
|
| 175 |
session_id=session_id,
|
| 176 |
new_message=types.Content(role="user", parts=parts)
|
| 177 |
):
|
| 178 |
+
# ── Capture Reasoning / Thoughts ──
|
| 179 |
+
if hasattr(event, "content") and event.content:
|
| 180 |
+
for part in event.content.parts:
|
| 181 |
+
if part.text:
|
| 182 |
+
# We treat intermittent text as reasoning/logic
|
| 183 |
+
yield {"type": "thought", "content": part.text}
|
| 184 |
+
|
| 185 |
+
# ── Capture Tool Usage ──
|
| 186 |
+
if hasattr(event, "tool_call") and event.tool_call:
|
| 187 |
+
yield {
|
| 188 |
+
"type": "action",
|
| 189 |
+
"content": f"Using tool: {event.tool_call.function_call.name}"
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
# ── Capture Tool Response ──
|
| 193 |
+
if hasattr(event, "tool_response") and event.tool_response:
|
| 194 |
+
yield {
|
| 195 |
+
"type": "observation",
|
| 196 |
+
"content": f"Obtained result from {event.tool_response.function_response.name}"
|
| 197 |
+
}
|
| 198 |
|
| 199 |
except Exception as e:
|
| 200 |
+
logger.error(f"Streaming execution failed: {e}")
|
| 201 |
+
yield {"type": "error", "content": str(e)}
|
app/api/main.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
import sys
|
| 3 |
import asyncio
|
| 4 |
|
|
@@ -6,18 +8,19 @@ if sys.platform == 'win32':
|
|
| 6 |
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
| 7 |
|
| 8 |
import logging
|
| 9 |
-
import datetime
|
| 10 |
-
from datetime import datetime
|
| 11 |
import uuid
|
| 12 |
import sys
|
|
|
|
| 13 |
|
| 14 |
from fastapi import FastAPI, HTTPException, status, Depends, Request
|
| 15 |
-
from fastapi.responses import JSONResponse
|
| 16 |
from slowapi import _rate_limit_exceeded_handler
|
| 17 |
from slowapi.errors import RateLimitExceeded
|
| 18 |
from app.core.limiter import limiter
|
| 19 |
from app.core.orchestrator import Orchestrator
|
| 20 |
-
from app.core.schemas import SolveRequest, SolveResponse, HealthResponse
|
|
|
|
| 21 |
from app.core.logging_config import configure_logging
|
| 22 |
from app.core.errors import AppError, ErrorCodes, ERROR_MESSAGES
|
| 23 |
from app.core.settings import settings # New settings module
|
|
@@ -200,7 +203,7 @@ async def health_check():
|
|
| 200 |
|
| 201 |
return health_status
|
| 202 |
|
| 203 |
-
@app.post("/solve"
|
| 204 |
@limiter.limit("5/minute")
|
| 205 |
async def solve_problem(
|
| 206 |
request: Request,
|
|
@@ -243,55 +246,89 @@ async def solve_problem(
|
|
| 243 |
logger.warning(f"Redis dedup failed (failing open): {e}")
|
| 244 |
# If Redis fails, we allow the request to proceed (fail open)
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
answer=result.get("answer"),
|
| 267 |
-
steps=result.get("steps", []),
|
| 268 |
-
explanation=result.get("explanation"),
|
| 269 |
-
confidence=result.get("confidence", 0.0),
|
| 270 |
-
cached=result.get("cached", False),
|
| 271 |
-
error=result.get("error_msg"), # Keep for backward compat if any
|
| 272 |
-
error_code=result.get("error_code"),
|
| 273 |
-
metadata=public_metadata
|
| 274 |
-
)
|
| 275 |
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
# --- User Profile Endpoints ---
|
| 297 |
from pydantic import BaseModel
|
|
@@ -348,3 +385,64 @@ async def update_profile(
|
|
| 348 |
if __name__ == "__main__":
|
| 349 |
import uvicorn
|
| 350 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
|
| 3 |
+
from typing import Any, Dict, Optional, List
|
| 4 |
import sys
|
| 5 |
import asyncio
|
| 6 |
|
|
|
|
| 8 |
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
| 9 |
|
| 10 |
import logging
|
| 11 |
+
from datetime import datetime, timezone
|
|
|
|
| 12 |
import uuid
|
| 13 |
import sys
|
| 14 |
+
import json
|
| 15 |
|
| 16 |
from fastapi import FastAPI, HTTPException, status, Depends, Request
|
| 17 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 18 |
from slowapi import _rate_limit_exceeded_handler
|
| 19 |
from slowapi.errors import RateLimitExceeded
|
| 20 |
from app.core.limiter import limiter
|
| 21 |
from app.core.orchestrator import Orchestrator
|
| 22 |
+
from app.core.schemas import SolveRequest, SolveResponse, HealthResponse, Message, ChatSession, SessionRename, UserSignup, UserLogin, TokenResponse
|
| 23 |
+
from app.core.auth_utils import hash_password, verify_password, create_access_token
|
| 24 |
from app.core.logging_config import configure_logging
|
| 25 |
from app.core.errors import AppError, ErrorCodes, ERROR_MESSAGES
|
| 26 |
from app.core.settings import settings # New settings module
|
|
|
|
| 203 |
|
| 204 |
return health_status
|
| 205 |
|
| 206 |
+
@app.post("/solve")
|
| 207 |
@limiter.limit("5/minute")
|
| 208 |
async def solve_problem(
|
| 209 |
request: Request,
|
|
|
|
| 246 |
logger.warning(f"Redis dedup failed (failing open): {e}")
|
| 247 |
# If Redis fails, we allow the request to proceed (fail open)
|
| 248 |
|
| 249 |
+
async def event_generator():
|
| 250 |
+
try:
|
| 251 |
+
async for chunk in orchestrator.process_problem(
|
| 252 |
+
text=solve_req.effective_text,
|
| 253 |
+
image=solve_req.image,
|
| 254 |
+
request_id=final_request_id,
|
| 255 |
+
model_preference=solve_req.model_preference,
|
| 256 |
+
session_id=solve_req.session_id,
|
| 257 |
+
user_id=current_user.get("uid")
|
| 258 |
+
):
|
| 259 |
+
yield json.dumps(chunk) + "\n"
|
| 260 |
+
except Exception as e:
|
| 261 |
+
logger.error(f"Streaming error: {e}")
|
| 262 |
+
yield json.dumps({"type": "error", "content": "Internal processing error"}) + "\n"
|
| 263 |
+
finally:
|
| 264 |
+
if redis_client:
|
| 265 |
+
try:
|
| 266 |
+
redis_client.delete(dedup_key)
|
| 267 |
+
except Exception:
|
| 268 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
+
return StreamingResponse(event_generator(), media_type="application/x-ndjson")
|
| 271 |
+
|
| 272 |
+
# --- Chat History Endpoints ---
|
| 273 |
+
|
| 274 |
+
@app.get("/chat/sessions", response_model=List[ChatSession])
|
| 275 |
+
async def list_chat_sessions(
|
| 276 |
+
current_user: dict = Depends(get_current_user),
|
| 277 |
+
db_manager = Depends(get_db_manager)
|
| 278 |
+
):
|
| 279 |
+
"""List all chat sessions for the current user."""
|
| 280 |
+
return db_manager.list_sessions(current_user["uid"])
|
| 281 |
+
|
| 282 |
+
@app.post("/chat/sessions", response_model=ChatSession)
|
| 283 |
+
async def create_chat_session(
|
| 284 |
+
current_user: dict = Depends(get_current_user),
|
| 285 |
+
db_manager = Depends(get_db_manager)
|
| 286 |
+
):
|
| 287 |
+
"""Create a new chat session."""
|
| 288 |
+
session_id = str(uuid.uuid4())
|
| 289 |
+
title = "New Chat"
|
| 290 |
+
if db_manager.create_session(current_user["uid"], session_id, title):
|
| 291 |
+
return {
|
| 292 |
+
"session_id": session_id,
|
| 293 |
+
"title": title,
|
| 294 |
+
"created_at": datetime.utcnow()
|
| 295 |
+
}
|
| 296 |
+
raise HTTPException(status_code=500, detail="Failed to create session")
|
| 297 |
+
|
| 298 |
+
@app.get("/chat/sessions/{session_id}/messages", response_model=List[Message])
|
| 299 |
+
async def get_session_history(
|
| 300 |
+
session_id: str,
|
| 301 |
+
current_user: dict = Depends(get_current_user),
|
| 302 |
+
db_manager = Depends(get_db_manager)
|
| 303 |
+
):
|
| 304 |
+
"""Get message history for a specific session."""
|
| 305 |
+
history = db_manager.get_chat_history(current_user["uid"], session_id)
|
| 306 |
+
if not history and history != []:
|
| 307 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 308 |
+
return history
|
| 309 |
+
|
| 310 |
+
@app.patch("/chat/sessions/{session_id}")
|
| 311 |
+
async def rename_chat_session(
|
| 312 |
+
session_id: str,
|
| 313 |
+
rename_data: SessionRename,
|
| 314 |
+
current_user: dict = Depends(get_current_user),
|
| 315 |
+
db_manager = Depends(get_db_manager)
|
| 316 |
+
):
|
| 317 |
+
"""Rename a chat session."""
|
| 318 |
+
if db_manager.rename_session(current_user["uid"], session_id, rename_data.title):
|
| 319 |
+
return {"status": "success", "title": rename_data.title}
|
| 320 |
+
raise HTTPException(status_code=404, detail="Session not found or rename failed")
|
| 321 |
+
|
| 322 |
+
@app.delete("/chat/sessions/{session_id}")
|
| 323 |
+
async def delete_chat_session(
|
| 324 |
+
session_id: str,
|
| 325 |
+
current_user: dict = Depends(get_current_user),
|
| 326 |
+
db_manager = Depends(get_db_manager)
|
| 327 |
+
):
|
| 328 |
+
"""Delete a chat session."""
|
| 329 |
+
if db_manager.delete_session(current_user["uid"], session_id):
|
| 330 |
+
return {"status": "success", "message": "Session deleted"}
|
| 331 |
+
raise HTTPException(status_code=404, detail="Session not found or delete failed")
|
| 332 |
|
| 333 |
# --- User Profile Endpoints ---
|
| 334 |
from pydantic import BaseModel
|
|
|
|
| 385 |
if __name__ == "__main__":
|
| 386 |
import uvicorn
|
| 387 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 388 |
+
# ── Auth Endpoints ─────────────────────────────────────────────────────────
|
| 389 |
+
|
| 390 |
+
@app.post("/auth/signup", response_model=TokenResponse)
|
| 391 |
+
async def signup(
|
| 392 |
+
user_in: UserSignup,
|
| 393 |
+
db_manager = Depends(get_db_manager)
|
| 394 |
+
):
|
| 395 |
+
"""Register a new user."""
|
| 396 |
+
# Check if user already exists
|
| 397 |
+
existing_user = db_manager.get_user_by_email(user_in.email)
|
| 398 |
+
if existing_user:
|
| 399 |
+
raise HTTPException(
|
| 400 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 401 |
+
detail="User with this email already exists"
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
user_id = str(uuid.uuid4())
|
| 405 |
+
hashed_pw = hash_password(user_in.password)
|
| 406 |
+
|
| 407 |
+
user_dict = {
|
| 408 |
+
"user_id": user_id,
|
| 409 |
+
"email": user_in.email,
|
| 410 |
+
"hashed_password": hashed_pw,
|
| 411 |
+
"full_name": user_in.full_name,
|
| 412 |
+
"created_at": datetime.now(timezone.utc)
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
if db_manager.create_user(user_dict):
|
| 416 |
+
token = create_access_token(data={"sub": user_id, "email": user_in.email})
|
| 417 |
+
return {
|
| 418 |
+
"access_token": token,
|
| 419 |
+
"token_type": "bearer",
|
| 420 |
+
"user_id": user_id,
|
| 421 |
+
"email": user_in.email
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
raise HTTPException(status_code=500, detail="Failed to create user")
|
| 425 |
+
|
| 426 |
+
@app.post("/auth/login", response_model=TokenResponse)
|
| 427 |
+
async def login(
|
| 428 |
+
login_in: UserLogin,
|
| 429 |
+
db_manager = Depends(get_db_manager)
|
| 430 |
+
):
|
| 431 |
+
"""Login and get access token."""
|
| 432 |
+
user = db_manager.get_user_by_email(login_in.email)
|
| 433 |
+
if not user or not verify_password(login_in.password, user["hashed_password"]):
|
| 434 |
+
raise HTTPException(
|
| 435 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 436 |
+
detail="Incorrect email or password",
|
| 437 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
user_id = user["user_id"]
|
| 441 |
+
token = create_access_token(data={"sub": user_id, "email": user["email"]})
|
| 442 |
+
|
| 443 |
+
return {
|
| 444 |
+
"access_token": token,
|
| 445 |
+
"token_type": "bearer",
|
| 446 |
+
"user_id": user_id,
|
| 447 |
+
"email": user["email"]
|
| 448 |
+
}
|
app/core/orchestrator.py
CHANGED
|
@@ -3,7 +3,8 @@ import time
|
|
| 3 |
import hashlib
|
| 4 |
import json
|
| 5 |
import re
|
| 6 |
-
|
|
|
|
| 7 |
|
| 8 |
import sympy
|
| 9 |
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
|
|
@@ -17,21 +18,12 @@ from app.core.settings import settings
|
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
-
# Transformations that allow "2x" → "2*x" in SymPy
|
| 21 |
_SYMPY_TRANSFORMATIONS = standard_transformations + (implicit_multiplication_application,)
|
| 22 |
|
| 23 |
|
| 24 |
class Orchestrator:
|
| 25 |
"""
|
| 26 |
-
Pipeline: Cache → Pre-flight (SymPy) →
|
| 27 |
-
|
| 28 |
-
Pre-flight layer solves arithmetic, algebra, derivatives, integrals, and
|
| 29 |
-
equations locally with SymPy — ZERO LLM calls. Only queries that cannot
|
| 30 |
-
be resolved symbolically are forwarded to the ADK Agent.
|
| 31 |
-
|
| 32 |
-
With a 20 RPD quota this means:
|
| 33 |
-
- Simple math → instant, free, no quota used
|
| 34 |
-
- Complex math → ADK agent, 1-2 LLM calls as needed
|
| 35 |
"""
|
| 36 |
|
| 37 |
def __init__(
|
|
@@ -51,9 +43,6 @@ class Orchestrator:
|
|
| 51 |
logger.critical(f"Failed to initialize Orchestrator: {e}")
|
| 52 |
raise
|
| 53 |
|
| 54 |
-
# ──────────────────────────────────────────────────────────────────────────
|
| 55 |
-
# Public entry point
|
| 56 |
-
# ──────────────────────────────────────────────────────────────────────────
|
| 57 |
async def process_problem(
|
| 58 |
self,
|
| 59 |
text: Optional[str] = None,
|
|
@@ -62,222 +51,173 @@ class Orchestrator:
|
|
| 62 |
model_preference: str = "fast",
|
| 63 |
session_id: Optional[str] = None,
|
| 64 |
user_id: Optional[str] = None,
|
| 65 |
-
) -> Dict[str, Any]:
|
| 66 |
|
| 67 |
start_time = time.time()
|
| 68 |
request_id = request_id or "unknown"
|
| 69 |
|
| 70 |
result_schema: Dict[str, Any] = {
|
| 71 |
"request_id": request_id,
|
| 72 |
-
"status": "
|
| 73 |
-
"source": "
|
| 74 |
-
"answer":
|
| 75 |
-
"
|
| 76 |
-
"explanation": None,
|
| 77 |
-
"confidence": 0.0,
|
| 78 |
-
"cached": False,
|
| 79 |
-
"metadata": {"latency_ms": 0, "model": "sympy_preflight", "tools_used": []},
|
| 80 |
}
|
| 81 |
|
| 82 |
try:
|
| 83 |
-
# ──
|
| 84 |
processed = self.input_processor.process_compound(text_input=text, image_input=image)
|
| 85 |
if not processed.is_valid:
|
| 86 |
-
|
| 87 |
-
return
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
# ──
|
| 93 |
if settings.ENABLE_CACHE and not image_data:
|
| 94 |
-
cache_key
|
| 95 |
-
cached
|
| 96 |
if cached:
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
else:
|
| 102 |
cache_key = None
|
| 103 |
|
| 104 |
-
# ──
|
| 105 |
-
# Only attempted when there is no image (images need vision model).
|
| 106 |
if not image_data:
|
| 107 |
preflight_result = self._try_sympy(query)
|
| 108 |
if preflight_result is not None:
|
|
|
|
|
|
|
|
|
|
| 109 |
result_schema.update({
|
| 110 |
-
"
|
| 111 |
-
"
|
| 112 |
-
"
|
| 113 |
-
"explanation": "Solved locally by SymPy — no LLM call needed.",
|
| 114 |
-
"confidence": 1.0,
|
| 115 |
-
"metadata": {"latency_ms": 0, "model": "sympy_preflight", "tools_used": ["sympy"]},
|
| 116 |
})
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
logger.error(f"ADK Agent execution failed: {e}")
|
| 145 |
-
result_schema["explanation"] = f"Agent Error: {str(e)}"
|
| 146 |
-
return self._finalize(result_schema, start_time)
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
self.cache_manager.set_cached_answer(cache_key, result_schema)
|
| 152 |
-
self.db_manager.save_problem({"content": query}, result_schema)
|
| 153 |
-
|
| 154 |
-
return self._finalize(result_schema, start_time)
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
except Exception as e:
|
| 157 |
-
logger.error(f"
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
-
# ──────────────────────────────────────────────────────────────────────────
|
| 162 |
-
# Pre-flight: pure SymPy resolution, no LLM
|
| 163 |
-
# ──────────────────────────────────────────────────────────────────────────
|
| 164 |
def _try_sympy(self, query: str) -> Optional[str]:
|
| 165 |
-
"""
|
| 166 |
-
Attempt to solve the query with the MathQueryNormalizer + SymPy.
|
| 167 |
-
Returns a human-readable answer string, or None if it can't be solved
|
| 168 |
-
locally (which means the ADK Agent should handle it).
|
| 169 |
-
"""
|
| 170 |
try:
|
| 171 |
intent: Optional[MathIntent] = self.normalizer.normalize(query)
|
| 172 |
-
if intent is None:
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
if intent.intent == "arithmetic":
|
| 179 |
-
return self._solve_arithmetic(expr_str)
|
| 180 |
-
|
| 181 |
-
if intent.intent == "equation":
|
| 182 |
-
return self._solve_equation(expr_str, target_var)
|
| 183 |
-
|
| 184 |
if intent.intent == "derivative":
|
| 185 |
-
expr
|
| 186 |
-
|
| 187 |
-
return f"d/d{target_var}({intent.expression}) = {sympy.latex(result)}"
|
| 188 |
-
|
| 189 |
if intent.intent == "integral":
|
| 190 |
-
expr
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
if intent.intent == "limit":
|
| 195 |
-
return self._solve_limit(intent, query)
|
| 196 |
-
|
| 197 |
if intent.intent == "simplification":
|
| 198 |
-
expr
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
except Exception as e:
|
| 203 |
-
# Any parse/evaluation error → fall through to agent
|
| 204 |
-
logger.debug(f"Pre-flight SymPy failed for '{query}': {e}")
|
| 205 |
-
|
| 206 |
return None
|
| 207 |
|
| 208 |
def _prep_expr(self, expr: str) -> str:
|
| 209 |
-
|
| 210 |
-
expr =
|
| 211 |
-
expr = re.sub(r"
|
| 212 |
-
expr = re.sub(r"\)\s*\(", ")*(", expr) # )( → )*(
|
| 213 |
return expr.strip()
|
| 214 |
|
| 215 |
def _solve_arithmetic(self, expr_str: str) -> Optional[str]:
|
| 216 |
-
"""Evaluate a pure arithmetic/algebraic expression."""
|
| 217 |
try:
|
| 218 |
-
|
| 219 |
-
result = sympy.simplify(expr)
|
| 220 |
-
# If the result is a number, show it plainly; otherwise use LaTeX
|
| 221 |
if result.is_number:
|
| 222 |
numeric = float(result)
|
| 223 |
-
|
| 224 |
-
if numeric == int(numeric):
|
| 225 |
-
return str(int(numeric))
|
| 226 |
-
return f"{numeric:.6g}"
|
| 227 |
return sympy.latex(result)
|
| 228 |
-
except Exception:
|
| 229 |
-
return None
|
| 230 |
|
| 231 |
def _solve_equation(self, expr_str: str, var: sympy.Symbol) -> Optional[str]:
|
| 232 |
-
"""Solve an equation of the form lhs = rhs, or expr = 0."""
|
| 233 |
try:
|
| 234 |
parts = expr_str.split("=", 1)
|
| 235 |
if len(parts) == 2:
|
| 236 |
-
lhs
|
| 237 |
-
rhs
|
| 238 |
solution = sympy.solve(lhs - rhs, var)
|
| 239 |
else:
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
if len(solution) == 1:
|
| 246 |
-
return f"{var} = {sympy.latex(solution[0])}"
|
| 247 |
-
sols = ", ".join(sympy.latex(s) for s in solution)
|
| 248 |
-
return f"{var} ∈ {{{sols}}}"
|
| 249 |
-
except Exception:
|
| 250 |
-
return None
|
| 251 |
|
| 252 |
def _solve_limit(self, intent: MathIntent, original_query: str) -> Optional[str]:
|
| 253 |
-
"""Parse and evaluate a limit expression."""
|
| 254 |
try:
|
| 255 |
-
|
| 256 |
-
match
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
expr_raw = self._prep_expr(match.group(1))
|
| 263 |
-
var_name = match.group(2).strip()
|
| 264 |
-
point_raw = self._prep_expr(match.group(3).strip())
|
| 265 |
-
|
| 266 |
-
expr = parse_expr(expr_raw, transformations=_SYMPY_TRANSFORMATIONS)
|
| 267 |
-
var = sympy.Symbol(var_name)
|
| 268 |
-
point = parse_expr(point_raw, transformations=_SYMPY_TRANSFORMATIONS)
|
| 269 |
|
| 270 |
-
result = sympy.limit(expr, var, point)
|
| 271 |
-
return f"lim({var}→{point}) {sympy.latex(expr)} = {sympy.latex(result)}"
|
| 272 |
-
except Exception:
|
| 273 |
-
return None
|
| 274 |
-
|
| 275 |
-
# ──────────────────────────────────────────────────────────────────────────
|
| 276 |
-
# Utilities
|
| 277 |
-
# ──────────────────────────────────────────────────────────────────────────
|
| 278 |
def _make_cache_key(self, query: str) -> str:
|
| 279 |
return hashlib.sha256(query.strip().lower().encode()).hexdigest()
|
| 280 |
-
|
| 281 |
-
def _finalize(self, schema: Dict, start_time: float) -> Dict:
|
| 282 |
-
schema["metadata"]["latency_ms"] = int((time.time() - start_time) * 1000)
|
| 283 |
-
return schema
|
|
|
|
| 3 |
import hashlib
|
| 4 |
import json
|
| 5 |
import re
|
| 6 |
+
import asyncio
|
| 7 |
+
from typing import Any, Dict, Optional, AsyncGenerator
|
| 8 |
|
| 9 |
import sympy
|
| 10 |
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
|
|
|
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
|
|
|
| 21 |
_SYMPY_TRANSFORMATIONS = standard_transformations + (implicit_multiplication_application,)
|
| 22 |
|
| 23 |
|
| 24 |
class Orchestrator:
|
| 25 |
"""
|
| 26 |
+
Evolved Pipeline: Cache → Pre-flight (SymPy) → Agentic Streaming Loop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
|
| 29 |
def __init__(
|
|
|
|
| 43 |
logger.critical(f"Failed to initialize Orchestrator: {e}")
|
| 44 |
raise
|
| 45 |
|
|
|
|
|
|
|
|
|
|
| 46 |
async def process_problem(
|
| 47 |
self,
|
| 48 |
text: Optional[str] = None,
|
|
|
|
| 51 |
model_preference: str = "fast",
|
| 52 |
session_id: Optional[str] = None,
|
| 53 |
user_id: Optional[str] = None,
|
| 54 |
+
) -> AsyncGenerator[Dict[str, Any], None]:
|
| 55 |
|
| 56 |
start_time = time.time()
|
| 57 |
request_id = request_id or "unknown"
|
| 58 |
|
| 59 |
result_schema: Dict[str, Any] = {
|
| 60 |
"request_id": request_id,
|
| 61 |
+
"status": "success",
|
| 62 |
+
"source": "agent",
|
| 63 |
+
"answer": "",
|
| 64 |
+
"metadata": {"latency_ms": 0, "model": "gemini-2.0-flash", "tools_used": []},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
}
|
| 66 |
|
| 67 |
try:
|
| 68 |
+
# ── 1. Input processing ───────────────────────────────────────────
|
| 69 |
processed = self.input_processor.process_compound(text_input=text, image_input=image)
|
| 70 |
if not processed.is_valid:
|
| 71 |
+
yield {"type": "error", "content": processed.error_message}
|
| 72 |
+
return
|
| 73 |
+
|
| 74 |
+
query = processed.cleaned_content
|
| 75 |
+
image_data = processed.metadata.get("image_data")
|
| 76 |
|
| 77 |
+
# Background: Persist user message
|
| 78 |
+
if user_id and session_id:
|
| 79 |
+
asyncio.create_task(self._persist_message(
|
| 80 |
+
user_id=user_id, session_id=session_id, role="user",
|
| 81 |
+
content=text or "Uploaded an image", image_data=image_data
|
| 82 |
+
))
|
| 83 |
|
| 84 |
+
# ── 2. Cache lookup ───────────────────────────────────────────────
|
| 85 |
if settings.ENABLE_CACHE and not image_data:
|
| 86 |
+
cache_key = self._make_cache_key(query)
|
| 87 |
+
cached = self.cache_manager.get_cached_answer(cache_key)
|
| 88 |
if cached:
|
| 89 |
+
yield {"type": "thought", "content": "Retrieving answer from memory..."}
|
| 90 |
+
yield {"type": "answer", "content": cached.get("answer")}
|
| 91 |
+
# Background: Persist assistant response
|
| 92 |
+
if user_id and session_id:
|
| 93 |
+
asyncio.create_task(self._persist_message(
|
| 94 |
+
user_id=user_id, session_id=session_id, role="assistant",
|
| 95 |
+
content=cached.get("answer"), metadata=cached.get("metadata")
|
| 96 |
+
))
|
| 97 |
+
return
|
| 98 |
else:
|
| 99 |
cache_key = None
|
| 100 |
|
| 101 |
+
# ── 3. Pre-flight (SymPy) ─────────────────────────────────────────
|
|
|
|
| 102 |
if not image_data:
|
| 103 |
preflight_result = self._try_sympy(query)
|
| 104 |
if preflight_result is not None:
|
| 105 |
+
yield {"type": "thought", "content": "Calculating result symbolically..."}
|
| 106 |
+
yield {"type": "answer", "content": preflight_result}
|
| 107 |
+
|
| 108 |
result_schema.update({
|
| 109 |
+
"source": "sympy_preflight",
|
| 110 |
+
"answer": preflight_result,
|
| 111 |
+
"metadata": {"model": "sympy", "tools_used": ["sympy"]}
|
|
|
|
|
|
|
|
|
|
| 112 |
})
|
| 113 |
+
|
| 114 |
+
self._background_log(query, result_schema, user_id, session_id, cache_key)
|
| 115 |
+
return
|
| 116 |
+
|
| 117 |
+
# ── 4. Agentic Streaming Loop ─────────────────────────────────────
|
| 118 |
+
full_answer = ""
|
| 119 |
+
async for event in self.adk_agent.solve(
|
| 120 |
+
problem=query, image_data=image_data,
|
| 121 |
+
session_id=session_id, user_id=user_id
|
| 122 |
+
):
|
| 123 |
+
if event["type"] == "thought":
|
| 124 |
+
yield event
|
| 125 |
+
elif event["type"] in ("action", "observation"):
|
| 126 |
+
yield event
|
| 127 |
+
elif event["type"] == "error":
|
| 128 |
+
yield event
|
| 129 |
+
else:
|
| 130 |
+
# Treat everything else as part of the answer
|
| 131 |
+
full_answer += event["content"]
|
| 132 |
+
yield {"type": "answer", "content": event["content"]}
|
| 133 |
+
|
| 134 |
+
# ── 5. Finalize ───────────────────────────────────────────────────
|
| 135 |
+
result_schema["answer"] = full_answer
|
| 136 |
+
result_schema["metadata"]["latency_ms"] = int((time.time() - start_time) * 1000)
|
| 137 |
+
|
| 138 |
+
if full_answer:
|
| 139 |
+
self._background_log(query, result_schema, user_id, session_id, cache_key)
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
+
except Exception as e:
|
| 142 |
+
logger.error(f"Orchestrator Error: {e}")
|
| 143 |
+
yield {"type": "error", "content": f"Internal Error: {str(e)}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
+
async def _persist_message(self, user_id, session_id, role, content, **kwargs):
|
| 146 |
+
try:
|
| 147 |
+
self.db_manager.create_session(user_id, session_id)
|
| 148 |
+
self.db_manager.save_chat_message(user_id, session_id, role, content, **kwargs)
|
| 149 |
except Exception as e:
|
| 150 |
+
logger.error(f"Failed to persist message: {e}")
|
| 151 |
+
|
| 152 |
+
def _background_log(self, query, schema, user_id, session_id, cache_key):
|
| 153 |
+
"""Fire and forget persistence tasks."""
|
| 154 |
+
asyncio.create_task(self._persist_message(
|
| 155 |
+
user_id=user_id, session_id=session_id, role="assistant",
|
| 156 |
+
content=schema["answer"], metadata=schema["metadata"]
|
| 157 |
+
))
|
| 158 |
+
if settings.ENABLE_CACHE and cache_key:
|
| 159 |
+
self.cache_manager.set_cached_answer(cache_key, schema)
|
| 160 |
+
self.db_manager.save_problem({"content": query}, schema)
|
| 161 |
|
|
|
|
|
|
|
|
|
|
| 162 |
def _try_sympy(self, query: str) -> Optional[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
try:
|
| 164 |
intent: Optional[MathIntent] = self.normalizer.normalize(query)
|
| 165 |
+
if intent is None: return None
|
| 166 |
+
expr_str = self._prep_expr(intent.expression)
|
| 167 |
+
target_var = sympy.Symbol(intent.variable or "x")
|
| 168 |
+
if intent.intent == "arithmetic": return self._solve_arithmetic(expr_str)
|
| 169 |
+
if intent.intent == "equation": return self._solve_equation(expr_str, target_var)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
if intent.intent == "derivative":
|
| 171 |
+
expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
|
| 172 |
+
return f"d/d{target_var}({intent.expression}) = {sympy.latex(sympy.diff(expr, target_var))}"
|
|
|
|
|
|
|
| 173 |
if intent.intent == "integral":
|
| 174 |
+
expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
|
| 175 |
+
return f"∫({intent.expression}) d{target_var} = {sympy.latex(sympy.integrate(expr, target_var))} + C"
|
| 176 |
+
if intent.intent == "limit": return self._solve_limit(intent, query)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
if intent.intent == "simplification":
|
| 178 |
+
expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
|
| 179 |
+
return f"Simplified: {sympy.latex(sympy.simplify(expr))}"
|
| 180 |
+
except Exception: pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
return None
|
| 182 |
|
| 183 |
def _prep_expr(self, expr: str) -> str:
|
| 184 |
+
expr = expr.replace("^", "**")
|
| 185 |
+
expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
|
| 186 |
+
expr = re.sub(r"\)\s*\(", ")*(", expr)
|
|
|
|
| 187 |
return expr.strip()
|
| 188 |
|
| 189 |
def _solve_arithmetic(self, expr_str: str) -> Optional[str]:
|
|
|
|
| 190 |
try:
|
| 191 |
+
result = sympy.simplify(parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS))
|
|
|
|
|
|
|
| 192 |
if result.is_number:
|
| 193 |
numeric = float(result)
|
| 194 |
+
return str(int(numeric)) if numeric == int(numeric) else f"{numeric:.6g}"
|
|
|
|
|
|
|
|
|
|
| 195 |
return sympy.latex(result)
|
| 196 |
+
except Exception: return None
|
|
|
|
| 197 |
|
| 198 |
def _solve_equation(self, expr_str: str, var: sympy.Symbol) -> Optional[str]:
|
|
|
|
| 199 |
try:
|
| 200 |
parts = expr_str.split("=", 1)
|
| 201 |
if len(parts) == 2:
|
| 202 |
+
lhs = parse_expr(self._prep_expr(parts[0]), transformations=_SYMPY_TRANSFORMATIONS)
|
| 203 |
+
rhs = parse_expr(self._prep_expr(parts[1]), transformations=_SYMPY_TRANSFORMATIONS)
|
| 204 |
solution = sympy.solve(lhs - rhs, var)
|
| 205 |
else:
|
| 206 |
+
solution = sympy.solve(parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS), var)
|
| 207 |
+
if not solution: return "No solution found."
|
| 208 |
+
if len(solution) == 1: return f"{var} = {sympy.latex(solution[0])}"
|
| 209 |
+
return f"{var} ∈ {{{', '.join(sympy.latex(s) for s in solution)}}}"
|
| 210 |
+
except Exception: return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
def _solve_limit(self, intent: MathIntent, original_query: str) -> Optional[str]:
|
|
|
|
| 213 |
try:
|
| 214 |
+
match = re.search(r"limit of\s+(.+?)\s+as\s+(\w+)\s+approaches\s+(.+)", original_query, re.IGNORECASE)
|
| 215 |
+
if not match: return None
|
| 216 |
+
expr = parse_expr(self._prep_expr(match.group(1)), transformations=_SYMPY_TRANSFORMATIONS)
|
| 217 |
+
var = sympy.Symbol(match.group(2).strip())
|
| 218 |
+
point = parse_expr(self._prep_expr(match.group(3).strip()), transformations=_SYMPY_TRANSFORMATIONS)
|
| 219 |
+
return f"lim({var}→{point}) {sympy.latex(expr)} = {sympy.latex(sympy.limit(expr, var, point))}"
|
| 220 |
+
except Exception: return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
def _make_cache_key(self, query: str) -> str:
|
| 223 |
return hashlib.sha256(query.strip().lower().encode()).hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
app/core/schemas.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from typing import Any, Dict, Optional, List
|
| 2 |
from pydantic import BaseModel, Field, model_validator
|
| 3 |
|
|
@@ -58,3 +59,39 @@ class HealthResponse(BaseModel):
|
|
| 58 |
"""
|
| 59 |
status: str
|
| 60 |
version: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
from typing import Any, Dict, Optional, List
|
| 3 |
from pydantic import BaseModel, Field, model_validator
|
| 4 |
|
|
|
|
| 59 |
"""
|
| 60 |
status: str
|
| 61 |
version: str
|
| 62 |
+
|
| 63 |
+
# --- Chat History Schemas ---
|
| 64 |
+
|
| 65 |
+
class Message(BaseModel):
|
| 66 |
+
role: str
|
| 67 |
+
content: str
|
| 68 |
+
timestamp: datetime
|
| 69 |
+
reasoning: Optional[str] = None
|
| 70 |
+
metadata: Dict[str, Any] = {}
|
| 71 |
+
steps: List[str] = []
|
| 72 |
+
|
| 73 |
+
class ChatSession(BaseModel):
|
| 74 |
+
session_id: str
|
| 75 |
+
title: str
|
| 76 |
+
created_at: datetime
|
| 77 |
+
# messages: Optional[List[Message]] = None # Optional for listing
|
| 78 |
+
|
| 79 |
+
class SessionRename(BaseModel):
|
| 80 |
+
title: str = Field(..., min_length=1, max_length=100)
|
| 81 |
+
|
| 82 |
+
# --- Auth Schemas ---
|
| 83 |
+
|
| 84 |
+
class UserSignup(BaseModel):
|
| 85 |
+
email: str
|
| 86 |
+
password: str = Field(..., min_length=8, max_length=72)
|
| 87 |
+
full_name: Optional[str] = None
|
| 88 |
+
|
| 89 |
+
class UserLogin(BaseModel):
|
| 90 |
+
email: str
|
| 91 |
+
password: str = Field(..., max_length=72)
|
| 92 |
+
|
| 93 |
+
class TokenResponse(BaseModel):
|
| 94 |
+
access_token: str
|
| 95 |
+
token_type: str = "bearer"
|
| 96 |
+
user_id: str
|
| 97 |
+
email: str
|
app/core/security.py
CHANGED
|
@@ -1,31 +1,16 @@
|
|
| 1 |
import logging
|
| 2 |
-
import firebase_admin
|
| 3 |
-
from firebase_admin import credentials, auth
|
| 4 |
from fastapi import HTTPException, status, Security
|
| 5 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 6 |
from app.core.settings import settings
|
|
|
|
| 7 |
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
-
# Initialize Firebase Admin SDK
|
| 11 |
-
try:
|
| 12 |
-
if settings.FIREBASE_CREDENTIALS_PATH:
|
| 13 |
-
if not firebase_admin._apps:
|
| 14 |
-
cred = credentials.Certificate(settings.FIREBASE_CREDENTIALS_PATH)
|
| 15 |
-
firebase_admin.initialize_app(cred)
|
| 16 |
-
logger.info("Firebase Admin SDK initialized successfully.")
|
| 17 |
-
else:
|
| 18 |
-
logger.info("Firebase Admin SDK already initialized.")
|
| 19 |
-
else:
|
| 20 |
-
logger.warning("FIREBASE_CREDENTIALS_PATH not set. Auth will fail if enabled.")
|
| 21 |
-
except Exception as e:
|
| 22 |
-
logger.error(f"Failed to initialize Firebase: {e}")
|
| 23 |
-
|
| 24 |
security = HTTPBearer()
|
| 25 |
|
| 26 |
def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
|
| 27 |
"""
|
| 28 |
-
Verifies the
|
| 29 |
Returns the decoded token dict if valid.
|
| 30 |
"""
|
| 31 |
token = credentials.credentials
|
|
@@ -39,16 +24,21 @@ def verify_token(credentials: HTTPAuthorizationCredentials = Security(security))
|
|
| 39 |
logger.info(f"Using MOCK AUTH for token: {token}")
|
| 40 |
return {"uid": "dev_user_123", "email": "dev@mathminds.ai"}
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
def get_current_user(token: dict = Security(verify_token)):
|
| 54 |
"""
|
|
|
|
| 1 |
import logging
|
|
|
|
|
|
|
| 2 |
from fastapi import HTTPException, status, Security
|
| 3 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 4 |
from app.core.settings import settings
|
| 5 |
+
from app.core.auth_utils import decode_access_token
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
security = HTTPBearer()
|
| 10 |
|
| 11 |
def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
|
| 12 |
"""
|
| 13 |
+
Verifies the Local JWT access token.
|
| 14 |
Returns the decoded token dict if valid.
|
| 15 |
"""
|
| 16 |
token = credentials.credentials
|
|
|
|
| 24 |
logger.info(f"Using MOCK AUTH for token: {token}")
|
| 25 |
return {"uid": "dev_user_123", "email": "dev@mathminds.ai"}
|
| 26 |
|
| 27 |
+
# Use local JWT verification
|
| 28 |
+
payload = decode_access_token(token)
|
| 29 |
+
if payload:
|
| 30 |
+
# Map 'sub' from JWT to 'uid' to maintain compatibility with existing code
|
| 31 |
+
return {
|
| 32 |
+
"uid": payload.get("sub"),
|
| 33 |
+
"email": payload.get("email")
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
logger.warning(f"Invalid or expired token provided.")
|
| 37 |
+
raise HTTPException(
|
| 38 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 39 |
+
detail="Invalid or expired authentication credentials",
|
| 40 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 41 |
+
)
|
| 42 |
|
| 43 |
def get_current_user(token: dict = Security(verify_token)):
|
| 44 |
"""
|
app/core/settings.py
CHANGED
|
@@ -40,6 +40,11 @@ class Settings(BaseSettings):
|
|
| 40 |
SUPABASE_KEY: Optional[str] = None
|
| 41 |
WOLFRAM_APP_ID: Optional[str] = None
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
model_config = {
|
| 44 |
"env_file": ".env",
|
| 45 |
"case_sensitive": True,
|
|
|
|
| 40 |
SUPABASE_KEY: Optional[str] = None
|
| 41 |
WOLFRAM_APP_ID: Optional[str] = None
|
| 42 |
|
| 43 |
+
# Security
|
| 44 |
+
JWT_SECRET_KEY: str = "super_secret_key_change_me"
|
| 45 |
+
JWT_ALGORITHM: str = "HS256"
|
| 46 |
+
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7 # 1 Week
|
| 47 |
+
|
| 48 |
model_config = {
|
| 49 |
"env_file": ".env",
|
| 50 |
"case_sensitive": True,
|
app/memory/database.py
CHANGED
|
@@ -55,10 +55,15 @@ class DatabaseManager:
|
|
| 55 |
|
| 56 |
self.db = self.client[db_name]
|
| 57 |
self.collection = self.db["solved_problems"]
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
# Ensure
|
| 60 |
-
|
| 61 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
logger.info(f"Successfully connected to MongoDB at {self.mongo_uri} (DB: {db_name})")
|
| 64 |
|
|
@@ -134,18 +139,20 @@ class DatabaseManager:
|
|
| 134 |
logger.error(f"Failed to save problem: {e}")
|
| 135 |
return False
|
| 136 |
|
| 137 |
-
def create_session(self, session_id: str, title: str = "New Chat") -> bool:
|
| 138 |
"""
|
| 139 |
-
|
|
|
|
| 140 |
"""
|
| 141 |
-
if self.
|
| 142 |
return False
|
| 143 |
try:
|
| 144 |
-
self.
|
| 145 |
{"session_id": session_id},
|
| 146 |
{
|
| 147 |
"$setOnInsert": {
|
| 148 |
"session_id": session_id,
|
|
|
|
| 149 |
"title": title,
|
| 150 |
"created_at": datetime.now(timezone.utc),
|
| 151 |
"messages": []
|
|
@@ -154,78 +161,151 @@ class DatabaseManager:
|
|
| 154 |
upsert=True
|
| 155 |
)
|
| 156 |
return True
|
| 157 |
-
except
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
return False
|
| 160 |
|
| 161 |
-
def
|
| 162 |
"""
|
| 163 |
-
Retrieve
|
| 164 |
"""
|
| 165 |
-
if self.
|
| 166 |
return []
|
| 167 |
try:
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
{"session_id":
|
| 171 |
-
|
| 172 |
-
)
|
| 173 |
-
if doc and "messages" in doc:
|
| 174 |
-
return doc["messages"]
|
| 175 |
-
return []
|
| 176 |
except PyMongoError as e:
|
| 177 |
-
logger.error(f"Failed to
|
| 178 |
return []
|
| 179 |
|
| 180 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
"""
|
| 182 |
Append a message to the session history.
|
| 183 |
-
|
| 184 |
"""
|
| 185 |
-
if self.
|
| 186 |
return False
|
| 187 |
try:
|
| 188 |
# logic to update title if it's currently "New Chat" and this is a user message
|
| 189 |
if role == "user":
|
| 190 |
-
session = self.
|
| 191 |
-
if session and session.get("title") == "New Chat":
|
| 192 |
-
# Generate title from content (truncate)
|
| 193 |
new_title = content[:50] + "..." if len(content) > 50 else content
|
| 194 |
-
self.
|
| 195 |
-
{"session_id": session_id},
|
| 196 |
{"$set": {"title": new_title}}
|
| 197 |
)
|
| 198 |
|
| 199 |
# Push the new message
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
},
|
| 211 |
-
upsert=True
|
| 212 |
)
|
| 213 |
-
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
except PyMongoError as e:
|
| 215 |
-
logger.error(f"Failed to
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
return False
|
| 217 |
|
| 218 |
# -------------------------------------------------------------------------
|
| 219 |
# User Profile Management
|
| 220 |
# -------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
|
| 222 |
"""
|
| 223 |
-
Retrieve user profile by ID (Firebase UID).
|
| 224 |
"""
|
| 225 |
-
if self.
|
| 226 |
return None
|
| 227 |
try:
|
| 228 |
-
return self.
|
| 229 |
except PyMongoError as e:
|
| 230 |
logger.error(f"Failed to get profile for {user_id}: {e}")
|
| 231 |
return None
|
|
@@ -234,10 +314,10 @@ class DatabaseManager:
|
|
| 234 |
"""
|
| 235 |
Update or create user profile.
|
| 236 |
"""
|
| 237 |
-
if self.
|
| 238 |
return False
|
| 239 |
try:
|
| 240 |
-
self.
|
| 241 |
{"user_id": user_id},
|
| 242 |
{"$set": {**data, "updated_at": datetime.now(timezone.utc)}},
|
| 243 |
upsert=True
|
|
|
|
| 55 |
|
| 56 |
self.db = self.client[db_name]
|
| 57 |
self.collection = self.db["solved_problems"]
|
| 58 |
+
self.sessions_collection = self.db["chat_sessions"]
|
| 59 |
+
self.users_collection = self.db["users"]
|
| 60 |
|
| 61 |
+
# Ensure indexes
|
| 62 |
+
self.collection.create_index([("hash", ASCENDING)], name="hash_index")
|
| 63 |
+
self.sessions_collection.create_index([("user_id", ASCENDING)], name="user_id_index")
|
| 64 |
+
self.sessions_collection.create_index([("session_id", ASCENDING)], name="session_id_index", unique=True)
|
| 65 |
+
self.users_collection.create_index([("email", ASCENDING)], name="email_index", unique=True)
|
| 66 |
+
self.users_collection.create_index([("user_id", ASCENDING)], name="user_id_index", unique=True)
|
| 67 |
|
| 68 |
logger.info(f"Successfully connected to MongoDB at {self.mongo_uri} (DB: {db_name})")
|
| 69 |
|
|
|
|
| 139 |
logger.error(f"Failed to save problem: {e}")
|
| 140 |
return False
|
| 141 |
|
| 142 |
+
def create_session(self, user_id: str, session_id: str, title: str = "New Chat") -> bool:
|
| 143 |
"""
|
| 144 |
+
Create a new chat session for a user.
|
| 145 |
+
Uses upsert with $setOnInsert to be idempotent.
|
| 146 |
"""
|
| 147 |
+
if self.sessions_collection is None:
|
| 148 |
return False
|
| 149 |
try:
|
| 150 |
+
self.sessions_collection.update_one(
|
| 151 |
{"session_id": session_id},
|
| 152 |
{
|
| 153 |
"$setOnInsert": {
|
| 154 |
"session_id": session_id,
|
| 155 |
+
"user_id": user_id,
|
| 156 |
"title": title,
|
| 157 |
"created_at": datetime.now(timezone.utc),
|
| 158 |
"messages": []
|
|
|
|
| 161 |
upsert=True
|
| 162 |
)
|
| 163 |
return True
|
| 164 |
+
except Exception as e:
|
| 165 |
+
# If it's a duplicate key error, it means another thread just inserted it.
|
| 166 |
+
# That's fine, we consider the session "created" or at least existing.
|
| 167 |
+
if "E11000" in str(e) or "duplicate key" in str(e).lower():
|
| 168 |
+
return True
|
| 169 |
+
logger.error(f"Failed to create session {session_id} for user {user_id}: {e}")
|
| 170 |
return False
|
| 171 |
|
| 172 |
+
def list_sessions(self, user_id: str) -> List[Dict[str, Any]]:
|
| 173 |
"""
|
| 174 |
+
Retrieve all sessions for a specific user.
|
| 175 |
"""
|
| 176 |
+
if self.sessions_collection is None:
|
| 177 |
return []
|
| 178 |
try:
|
| 179 |
+
cursor = self.sessions_collection.find(
|
| 180 |
+
{"user_id": user_id},
|
| 181 |
+
{"session_id": 1, "title": 1, "created_at": 1, "_id": 0}
|
| 182 |
+
).sort("created_at", -1)
|
| 183 |
+
return list(cursor)
|
|
|
|
|
|
|
|
|
|
| 184 |
except PyMongoError as e:
|
| 185 |
+
logger.error(f"Failed to list sessions for user {user_id}: {e}")
|
| 186 |
return []
|
| 187 |
|
| 188 |
+
def get_chat_history(self, user_id: str, session_id: str, limit: int = 50) -> Optional[List[Dict[str, Any]]]:
|
| 189 |
+
"""
|
| 190 |
+
Retrieve recent messages for a session, ensuring it belongs to the user.
|
| 191 |
+
Returns None if session not found or not owned by user.
|
| 192 |
+
"""
|
| 193 |
+
if self.sessions_collection is None:
|
| 194 |
+
return None
|
| 195 |
+
try:
|
| 196 |
+
doc = self.sessions_collection.find_one(
|
| 197 |
+
{"session_id": session_id, "user_id": user_id},
|
| 198 |
+
{"messages": {"$slice": -limit}, "_id": 0}
|
| 199 |
+
)
|
| 200 |
+
if doc is not None:
|
| 201 |
+
return doc.get("messages", [])
|
| 202 |
+
return None
|
| 203 |
+
except PyMongoError as e:
|
| 204 |
+
logger.error(f"Failed to get history for {session_id} (user: {user_id}): {e}")
|
| 205 |
+
return None
|
| 206 |
+
|
| 207 |
+
def save_chat_message(self, user_id: str, session_id: str, role: str, content: str, **kwargs) -> bool:
|
| 208 |
"""
|
| 209 |
Append a message to the session history.
|
| 210 |
+
Only succeeds if the session belongs to the user_id.
|
| 211 |
"""
|
| 212 |
+
if self.sessions_collection is None:
|
| 213 |
return False
|
| 214 |
try:
|
| 215 |
# logic to update title if it's currently "New Chat" and this is a user message
|
| 216 |
if role == "user":
|
| 217 |
+
session = self.sessions_collection.find_one({"session_id": session_id, "user_id": user_id})
|
| 218 |
+
if session and (session.get("title") == "New Chat" or session.get("title") == "New Session" or session.get("title") == "Untitled"):
|
|
|
|
| 219 |
new_title = content[:50] + "..." if len(content) > 50 else content
|
| 220 |
+
self.sessions_collection.update_one(
|
| 221 |
+
{"session_id": session_id, "user_id": user_id},
|
| 222 |
{"$set": {"title": new_title}}
|
| 223 |
)
|
| 224 |
|
| 225 |
# Push the new message
|
| 226 |
+
msg = {
|
| 227 |
+
"role": role,
|
| 228 |
+
"content": content,
|
| 229 |
+
"timestamp": datetime.now(timezone.utc)
|
| 230 |
+
}
|
| 231 |
+
msg.update(kwargs)
|
| 232 |
+
|
| 233 |
+
result = self.sessions_collection.update_one(
|
| 234 |
+
{"session_id": session_id, "user_id": user_id},
|
| 235 |
+
{"$push": {"messages": msg}}
|
|
|
|
|
|
|
| 236 |
)
|
| 237 |
+
# return True if we found the document to update
|
| 238 |
+
return result.matched_count > 0
|
| 239 |
+
except PyMongoError as e:
|
| 240 |
+
logger.error(f"Failed to save message to {session_id} for user {user_id}: {e}")
|
| 241 |
+
return False
|
| 242 |
+
|
| 243 |
+
def delete_session(self, user_id: str, session_id: str) -> bool:
|
| 244 |
+
"""
|
| 245 |
+
Delete a session belonging to a user.
|
| 246 |
+
"""
|
| 247 |
+
if self.sessions_collection is None:
|
| 248 |
+
return False
|
| 249 |
+
try:
|
| 250 |
+
result = self.sessions_collection.delete_one({"session_id": session_id, "user_id": user_id})
|
| 251 |
+
return result.deleted_count > 0
|
| 252 |
except PyMongoError as e:
|
| 253 |
+
logger.error(f"Failed to delete session {session_id} for user {user_id}: {e}")
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
def rename_session(self, user_id: str, session_id: str, new_title: str) -> bool:
|
| 257 |
+
"""
|
| 258 |
+
Rename a session belonging to a user.
|
| 259 |
+
"""
|
| 260 |
+
if self.sessions_collection is None:
|
| 261 |
+
return False
|
| 262 |
+
try:
|
| 263 |
+
result = self.sessions_collection.update_one(
|
| 264 |
+
{"session_id": session_id, "user_id": user_id},
|
| 265 |
+
{"$set": {"title": new_title}}
|
| 266 |
+
)
|
| 267 |
+
# return True if we found the document to update (even if title was same)
|
| 268 |
+
return result.matched_count > 0
|
| 269 |
+
except PyMongoError as e:
|
| 270 |
+
logger.error(f"Failed to rename session {session_id} for user {user_id}: {e}")
|
| 271 |
return False
|
| 272 |
|
| 273 |
# -------------------------------------------------------------------------
|
| 274 |
# User Profile Management
|
| 275 |
# -------------------------------------------------------------------------
|
| 276 |
+
def create_user(self, user_data: Dict[str, Any]) -> bool:
|
| 277 |
+
"""
|
| 278 |
+
Create a new user in the database.
|
| 279 |
+
"""
|
| 280 |
+
if self.users_collection is None:
|
| 281 |
+
return False
|
| 282 |
+
try:
|
| 283 |
+
self.users_collection.insert_one(user_data)
|
| 284 |
+
return True
|
| 285 |
+
except PyMongoError as e:
|
| 286 |
+
logger.error(f"Failed to create user: {e}")
|
| 287 |
+
return False
|
| 288 |
+
|
| 289 |
+
def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
|
| 290 |
+
"""
|
| 291 |
+
Retrieve a user by email.
|
| 292 |
+
"""
|
| 293 |
+
if self.users_collection is None:
|
| 294 |
+
return None
|
| 295 |
+
try:
|
| 296 |
+
return self.users_collection.find_one({"email": email})
|
| 297 |
+
except PyMongoError as e:
|
| 298 |
+
logger.error(f"Failed to get user by email {email}: {e}")
|
| 299 |
+
return None
|
| 300 |
+
|
| 301 |
def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
|
| 302 |
"""
|
| 303 |
+
Retrieve user profile by ID (Firebase UID or local UID).
|
| 304 |
"""
|
| 305 |
+
if self.users_collection is None:
|
| 306 |
return None
|
| 307 |
try:
|
| 308 |
+
return self.users_collection.find_one({"user_id": user_id})
|
| 309 |
except PyMongoError as e:
|
| 310 |
logger.error(f"Failed to get profile for {user_id}: {e}")
|
| 311 |
return None
|
|
|
|
| 314 |
"""
|
| 315 |
Update or create user profile.
|
| 316 |
"""
|
| 317 |
+
if self.users_collection is None:
|
| 318 |
return False
|
| 319 |
try:
|
| 320 |
+
self.users_collection.update_one(
|
| 321 |
{"user_id": user_id},
|
| 322 |
{"$set": {**data, "updated_at": datetime.now(timezone.utc)}},
|
| 323 |
upsert=True
|
frontend/app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import requests
|
| 3 |
-
import json
|
| 4 |
import base64
|
| 5 |
from PIL import Image
|
| 6 |
import io
|
|
@@ -13,21 +12,35 @@ from dotenv import load_dotenv
|
|
| 13 |
load_dotenv()
|
| 14 |
|
| 15 |
# ── Session state: ALL keys initialized ONCE at the very top ─────────────────
|
| 16 |
-
#
|
| 17 |
-
#
|
| 18 |
-
# Streamlit re-runs the whole script top-to-bottom on every rerun. On the rerun
|
| 19 |
-
# triggered after login, the second block executed and found "user" ALREADY in
|
| 20 |
-
# session_state (because we just set it during login), so it was a no-op —
|
| 21 |
-
# BUT on a hard browser refresh the two blocks ran in the same execution pass
|
| 22 |
-
# and the second one re-initialized user=None, wiping the login state.
|
| 23 |
-
# FIX: One single initialization block here, never again below.
|
| 24 |
if "is_processing" not in st.session_state:
|
| 25 |
st.session_state.is_processing = False
|
| 26 |
if "user" not in st.session_state:
|
| 27 |
-
st.session_state.user = None
|
| 28 |
if "current_view" not in st.session_state:
|
| 29 |
st.session_state.current_view = "Chat"
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# ====================================================
|
| 32 |
# Page Config — must come before any st.* calls
|
| 33 |
# ====================================================
|
|
@@ -91,59 +104,177 @@ st.markdown("""
|
|
| 91 |
""", unsafe_allow_html=True)
|
| 92 |
|
| 93 |
# ====================================================
|
| 94 |
-
# Config
|
| 95 |
# ====================================================
|
| 96 |
-
|
| 97 |
-
|
| 98 |
|
| 99 |
-
if "chat_sessions" not in st.session_state:
|
| 100 |
-
if os.path.exists(HISTORY_FILE):
|
| 101 |
-
with open(HISTORY_FILE, "r") as f:
|
| 102 |
-
st.session_state.chat_sessions = json.load(f)
|
| 103 |
-
else:
|
| 104 |
-
st.session_state.chat_sessions = {}
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
# ── IMPORTANT: No second "user" init block here. See top of file. ─────────────
|
| 112 |
|
| 113 |
# ====================================================
|
| 114 |
# Helper Functions
|
| 115 |
# ====================================================
|
| 116 |
-
def
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def get_active_session():
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
def add_message(role, content, sent_to_api=False, **kwargs):
|
| 124 |
-
|
| 125 |
msg = {"role": role, "content": content, "timestamp": time.time(), "sent_to_api": sent_to_api}
|
| 126 |
msg.update(kwargs)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
|
| 130 |
def new_chat():
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
def delete_chat(sid):
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# ====================================================
|
| 149 |
# Login Screen
|
|
@@ -168,23 +299,29 @@ def login_screen():
|
|
| 168 |
password = st.text_input("Password", type="password")
|
| 169 |
if st.form_submit_button("Sign In", use_container_width=True):
|
| 170 |
if email and password:
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
else:
|
| 189 |
st.error("Please enter email and password.")
|
| 190 |
|
|
@@ -193,49 +330,78 @@ def login_screen():
|
|
| 193 |
new_email = st.text_input("New Email", placeholder="new@student.edu")
|
| 194 |
new_password = st.text_input("New Password", type="password")
|
| 195 |
confirm_password = st.text_input("Confirm Password", type="password")
|
|
|
|
| 196 |
if st.form_submit_button("Create Account", use_container_width=True):
|
| 197 |
if new_email and new_password:
|
| 198 |
if new_password != confirm_password:
|
| 199 |
st.error("Passwords do not match!")
|
| 200 |
else:
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
else:
|
| 219 |
st.error("Please fill all fields.")
|
| 220 |
|
| 221 |
-
st.markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
# ── Auth gate ─────────────────────────────────────────────────────────────────
|
| 224 |
if not st.session_state.user:
|
| 225 |
login_screen()
|
| 226 |
st.stop()
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
# ====================================================
|
| 229 |
# Profile Interface
|
| 230 |
# ====================================================
|
| 231 |
def profile_interface():
|
| 232 |
st.title("👤 User Profile")
|
| 233 |
st.markdown("Customize your MathMinds experience.")
|
| 234 |
-
headers =
|
| 235 |
|
| 236 |
if "profile_data" not in st.session_state:
|
| 237 |
try:
|
| 238 |
-
r = requests.get(f"{
|
| 239 |
st.session_state.profile_data = r.json() if r.status_code == 200 else {}
|
| 240 |
except Exception:
|
| 241 |
st.session_state.profile_data = {}
|
|
@@ -246,43 +412,57 @@ def profile_interface():
|
|
| 246 |
|
| 247 |
with st.form("profile_form"):
|
| 248 |
display_name = st.text_input("Display Name", value=data.get("display_name", ""))
|
| 249 |
-
math_level = st.selectbox(
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
if st.form_submit_button("Save Profile", use_container_width=True, type="primary"):
|
| 255 |
payload = {"display_name": display_name, "math_level": math_level, "interests": interests}
|
| 256 |
try:
|
| 257 |
-
r = requests.post(f"{
|
| 258 |
if r.status_code == 200:
|
| 259 |
st.success("Profile updated!")
|
| 260 |
st.session_state.profile_data = payload
|
| 261 |
-
time.sleep(1)
|
|
|
|
| 262 |
else:
|
| 263 |
st.error(f"Update failed: {r.text}")
|
| 264 |
except Exception as e:
|
| 265 |
st.error(f"Error saving: {e}")
|
| 266 |
|
|
|
|
| 267 |
# ====================================================
|
| 268 |
# Chat Interface
|
| 269 |
# ====================================================
|
| 270 |
def chat_interface():
|
| 271 |
-
if
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
|
| 277 |
# ── 1. Render history ─────────────────────────────────────────────────────
|
| 278 |
-
for msg in
|
| 279 |
-
|
| 280 |
-
|
|
|
|
| 281 |
if msg.get("image_data"):
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
| 283 |
st.write(msg["content"])
|
| 284 |
-
|
| 285 |
-
with st.chat_message("assistant", avatar="🤖"):
|
| 286 |
meta = msg.get("metadata", {})
|
| 287 |
if meta:
|
| 288 |
badges = ""
|
|
@@ -293,16 +473,17 @@ def chat_interface():
|
|
| 293 |
badges += '<span class="badge badge-blue">💾 CACHED</span>'
|
| 294 |
elif src in ("google_adk_agent", "agent"):
|
| 295 |
badges += '<span class="badge badge-purple">🤖 AGENT</span>'
|
| 296 |
-
model = meta.get("model_used")
|
| 297 |
if model:
|
| 298 |
badges += f'<span class="badge" style="background:rgba(255,255,255,0.1);">{model}</span>'
|
| 299 |
if badges:
|
| 300 |
st.markdown(badges, unsafe_allow_html=True)
|
| 301 |
|
| 302 |
-
|
| 303 |
-
if msg.get("reasoning"):
|
| 304 |
with st.expander("Show Reasoning Steps"):
|
| 305 |
-
st.markdown(msg
|
|
|
|
|
|
|
| 306 |
if isinstance(content, dict) and "final_answer" in content:
|
| 307 |
st.markdown(f"**Answer:**\n\n> {content['final_answer']}")
|
| 308 |
else:
|
|
@@ -316,24 +497,29 @@ def chat_interface():
|
|
| 316 |
is_processing = st.session_state.get("is_processing", False)
|
| 317 |
|
| 318 |
with tab_text:
|
| 319 |
-
|
|
|
|
|
|
|
| 320 |
|
| 321 |
with tab_draw:
|
| 322 |
col_canvas, col_controls = st.columns([3, 1])
|
| 323 |
with col_canvas:
|
| 324 |
-
if "canvas_key" not in st.session_state:
|
| 325 |
-
st.session_state.canvas_key = "main_canvas"
|
| 326 |
canvas_result = st_canvas(
|
| 327 |
stroke_width=3, stroke_color="#FFFFFF", background_color="#000000",
|
| 328 |
-
height=300, width=600, drawing_mode="freedraw",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
)
|
| 330 |
-
draw_prompt = st.text_input("Question about drawing (optional)", placeholder="Solve this handwritten problem...")
|
| 331 |
with col_controls:
|
| 332 |
st.caption("Controls")
|
| 333 |
-
if st.button("Clear
|
| 334 |
st.session_state.canvas_key = f"canvas_{uuid.uuid4()}"
|
| 335 |
st.rerun()
|
| 336 |
-
if st.button("Solve
|
| 337 |
if canvas_result.image_data is not None:
|
| 338 |
img = Image.fromarray(canvas_result.image_data.astype("uint8"), "RGBA")
|
| 339 |
bg = Image.new("RGB", img.size, (0, 0, 0))
|
|
@@ -341,102 +527,99 @@ def chat_interface():
|
|
| 341 |
buf = io.BytesIO()
|
| 342 |
bg.save(buf, format="PNG")
|
| 343 |
image_b64 = base64.b64encode(buf.getvalue()).decode()
|
| 344 |
-
prompt =
|
| 345 |
|
| 346 |
with tab_upload:
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
if
|
| 350 |
-
image_b64 = base64.b64encode(
|
| 351 |
-
prompt =
|
| 352 |
|
| 353 |
-
# ── 3. New user message → optimistic
|
| 354 |
if prompt:
|
| 355 |
req_id = str(uuid.uuid4())
|
| 356 |
add_message("user", prompt, image_data=image_b64, request_id=req_id, sent_to_api=False)
|
| 357 |
st.session_state.is_processing = True
|
| 358 |
st.rerun()
|
| 359 |
|
| 360 |
-
# ── 4.
|
| 361 |
-
if session["messages"] and session["messages"][-1]["role"] == "user":
|
| 362 |
-
last = session["messages"][-1]
|
| 363 |
-
if last.get("sent_to_api") and not st.session_state.is_processing:
|
| 364 |
-
last["sent_to_api"] = False
|
| 365 |
-
save_history()
|
| 366 |
-
|
| 367 |
-
# ── 5. Fire API call if last message is unsent user message ───────────────
|
| 368 |
if (
|
| 369 |
-
|
| 370 |
-
and
|
| 371 |
-
and not
|
| 372 |
):
|
| 373 |
-
last =
|
| 374 |
current_request_id = last.get("request_id") or str(uuid.uuid4())
|
| 375 |
-
last["request_id"]
|
| 376 |
|
| 377 |
with st.chat_message("assistant", avatar="🤖"):
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
if
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
add_message(
|
| 402 |
-
"assistant",
|
| 403 |
-
|
| 404 |
-
reasoning=
|
| 405 |
-
metadata={
|
| 406 |
-
"source": data.get("source", "agent"),
|
| 407 |
-
"model_used": meta.get("model", "gemini-2.5-flash"),
|
| 408 |
-
"latency": f"{meta.get('latency_ms',0)/1000:.2f}s",
|
| 409 |
-
"tools": meta.get("tools_used", []),
|
| 410 |
-
},
|
| 411 |
-
steps=data.get("steps", [])
|
| 412 |
)
|
| 413 |
-
|
| 414 |
-
st.rerun() # ← BUG 2 FIX: missing rerun caused blank UI
|
| 415 |
-
|
| 416 |
-
else:
|
| 417 |
-
error_msg = data.get("error", "Unknown error")
|
| 418 |
-
add_message("assistant", f"⚠️ Error: {error_msg}")
|
| 419 |
-
st.session_state.is_processing = False
|
| 420 |
st.rerun()
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
st.
|
| 424 |
-
st.
|
| 425 |
st.rerun()
|
| 426 |
-
|
| 427 |
else:
|
| 428 |
-
|
| 429 |
-
err_msg = response.json().get("error", f"HTTP {response.status_code}")
|
| 430 |
-
except Exception:
|
| 431 |
-
err_msg = f"HTTP {response.status_code}"
|
| 432 |
-
add_message("assistant", f"❌ Server Error: {err_msg}")
|
| 433 |
st.session_state.is_processing = False
|
| 434 |
-
st.rerun()
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
|
|
|
| 440 |
|
| 441 |
# ====================================================
|
| 442 |
# Sidebar
|
|
@@ -445,16 +628,31 @@ with st.sidebar:
|
|
| 445 |
st.markdown("### 🧠 MathMinds")
|
| 446 |
st.write(f"Logged in as **{st.session_state.user['email']}**")
|
| 447 |
|
| 448 |
-
view = st.radio(
|
| 449 |
-
|
|
|
|
|
|
|
| 450 |
if view != st.session_state.current_view:
|
| 451 |
st.session_state.current_view = view
|
| 452 |
st.rerun()
|
| 453 |
|
| 454 |
if st.button("Sign Out", type="secondary"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
st.session_state.user = None
|
| 456 |
st.rerun()
|
| 457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
st.divider()
|
| 459 |
|
| 460 |
if st.session_state.current_view == "Chat":
|
|
@@ -462,26 +660,45 @@ with st.sidebar:
|
|
| 462 |
new_chat()
|
| 463 |
|
| 464 |
st.markdown("#### History")
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
with col_nav:
|
| 476 |
-
if st.button(f"{'📍 ' if isActive else ''}{title}", key=sid, use_container_width=True):
|
| 477 |
st.session_state.active_session_id = sid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
st.rerun()
|
| 479 |
-
with
|
| 480 |
-
if
|
| 481 |
delete_chat(sid)
|
| 482 |
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import requests
|
|
|
|
| 3 |
import base64
|
| 4 |
from PIL import Image
|
| 5 |
import io
|
|
|
|
| 12 |
load_dotenv()
|
| 13 |
|
| 14 |
# ── Session state: ALL keys initialized ONCE at the very top ─────────────────
|
| 15 |
+
# CRITICAL: These must be the very first st.session_state accesses, before any
|
| 16 |
+
# st.* UI calls. Streamlit re-runs the entire script on every interaction.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
if "is_processing" not in st.session_state:
|
| 18 |
st.session_state.is_processing = False
|
| 19 |
if "user" not in st.session_state:
|
| 20 |
+
st.session_state.user = None # None = logged out
|
| 21 |
if "current_view" not in st.session_state:
|
| 22 |
st.session_state.current_view = "Chat"
|
| 23 |
|
| 24 |
+
# MULTIUSER FIX ─ these three keys must be RESET on logout.
|
| 25 |
+
# They are initialized here so first-run doesn't KeyError.
|
| 26 |
+
if "chat_sessions" not in st.session_state:
|
| 27 |
+
st.session_state.chat_sessions = []
|
| 28 |
+
if "active_session_id" not in st.session_state:
|
| 29 |
+
st.session_state.active_session_id = None
|
| 30 |
+
if "messages" not in st.session_state:
|
| 31 |
+
st.session_state.messages = []
|
| 32 |
+
|
| 33 |
+
# MULTIUSER FIX ─ track WHICH user's data is currently loaded.
|
| 34 |
+
# If this doesn't match st.session_state.user["uid"], we know we need to reload.
|
| 35 |
+
if "loaded_for_user" not in st.session_state:
|
| 36 |
+
st.session_state.loaded_for_user = None
|
| 37 |
+
|
| 38 |
+
if "renaming_session_id" not in st.session_state:
|
| 39 |
+
st.session_state.renaming_session_id = None
|
| 40 |
+
|
| 41 |
+
if "canvas_key" not in st.session_state:
|
| 42 |
+
st.session_state.canvas_key = "main_canvas"
|
| 43 |
+
|
| 44 |
# ====================================================
|
| 45 |
# Page Config — must come before any st.* calls
|
| 46 |
# ====================================================
|
|
|
|
| 104 |
""", unsafe_allow_html=True)
|
| 105 |
|
| 106 |
# ====================================================
|
| 107 |
+
# Config
|
| 108 |
# ====================================================
|
| 109 |
+
BASE_API_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
|
| 110 |
+
API_URL = f"{BASE_API_URL}/solve"
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
+
# ====================================================
|
| 114 |
+
# MULTIUSER ISOLATION — Core helper
|
| 115 |
+
# ====================================================
|
| 116 |
+
def _clear_user_state():
|
| 117 |
+
"""
|
| 118 |
+
Wipe ALL per-user data from Streamlit session state.
|
| 119 |
+
|
| 120 |
+
Called on logout and whenever a different user logs in.
|
| 121 |
+
|
| 122 |
+
WHY THIS IS THE MOST IMPORTANT FUNCTION FOR MULTIUSER ISOLATION:
|
| 123 |
+
Streamlit's st.session_state is per browser-tab, not per user. If User A
|
| 124 |
+
logs in, chats, then User B logs in on the same tab, all of User A's
|
| 125 |
+
chat_sessions and messages are still sitting in st.session_state. The
|
| 126 |
+
backend correctly refuses to serve User A's data to User B (every DB query
|
| 127 |
+
filters by user_id), but the frontend would still DISPLAY User A's messages
|
| 128 |
+
briefly until the next API call returns. This function prevents that.
|
| 129 |
+
"""
|
| 130 |
+
st.session_state.chat_sessions = []
|
| 131 |
+
st.session_state.active_session_id = None
|
| 132 |
+
st.session_state.messages = []
|
| 133 |
+
st.session_state.loaded_for_user = None
|
| 134 |
+
st.session_state.is_processing = False
|
| 135 |
+
st.session_state.current_view = "Chat"
|
| 136 |
+
st.session_state.renaming_session_id = None
|
| 137 |
+
st.session_state.canvas_key = f"canvas_{uuid.uuid4()}"
|
| 138 |
+
# Also clear profile cache if it exists
|
| 139 |
+
if "profile_data" in st.session_state:
|
| 140 |
+
del st.session_state["profile_data"]
|
| 141 |
|
|
|
|
| 142 |
|
| 143 |
# ====================================================
|
| 144 |
# Helper Functions
|
| 145 |
# ====================================================
|
| 146 |
+
def get_auth_headers():
|
| 147 |
+
if st.session_state.user and "token" in st.session_state.user:
|
| 148 |
+
return {"Authorization": f"Bearer {st.session_state.user['token']}"}
|
| 149 |
+
return {}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def load_sessions():
|
| 153 |
+
"""Fetch THIS user's chat sessions from the backend and populate state."""
|
| 154 |
+
try:
|
| 155 |
+
headers = get_auth_headers()
|
| 156 |
+
response = requests.get(f"{BASE_API_URL}/chat/sessions", headers=headers, timeout=30)
|
| 157 |
+
if response.status_code == 200:
|
| 158 |
+
st.session_state.chat_sessions = response.json()
|
| 159 |
+
# Mark that we've successfully loaded data for this specific user
|
| 160 |
+
if st.session_state.user:
|
| 161 |
+
st.session_state.loaded_for_user = st.session_state.user["uid"]
|
| 162 |
+
# Auto-select first session if none active
|
| 163 |
+
if not st.session_state.active_session_id and st.session_state.chat_sessions:
|
| 164 |
+
st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"]
|
| 165 |
+
load_messages(st.session_state.active_session_id)
|
| 166 |
+
elif st.session_state.active_session_id and not any(
|
| 167 |
+
s["session_id"] == st.session_state.active_session_id
|
| 168 |
+
for s in st.session_state.chat_sessions
|
| 169 |
+
):
|
| 170 |
+
# Active session was deleted — pick first or clear
|
| 171 |
+
if st.session_state.chat_sessions:
|
| 172 |
+
st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"]
|
| 173 |
+
load_messages(st.session_state.active_session_id)
|
| 174 |
+
else:
|
| 175 |
+
st.session_state.active_session_id = None
|
| 176 |
+
st.session_state.messages = []
|
| 177 |
+
elif response.status_code == 401:
|
| 178 |
+
# JWT expired — force re-login
|
| 179 |
+
_clear_user_state()
|
| 180 |
+
st.session_state.user = None
|
| 181 |
+
st.error("Session expired. Please log in again.")
|
| 182 |
+
else:
|
| 183 |
+
st.error(f"Failed to load sessions: {response.status_code}")
|
| 184 |
+
st.session_state.chat_sessions = []
|
| 185 |
+
except Exception as e:
|
| 186 |
+
st.error(f"Error loading sessions: {e}")
|
| 187 |
+
st.session_state.chat_sessions = []
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def load_messages(session_id):
|
| 191 |
+
"""
|
| 192 |
+
Load messages for a session.
|
| 193 |
+
The backend enforces user ownership — it will 404 if session_id
|
| 194 |
+
doesn't belong to the authenticated user, so this is safe.
|
| 195 |
+
"""
|
| 196 |
+
try:
|
| 197 |
+
headers = get_auth_headers()
|
| 198 |
+
response = requests.get(
|
| 199 |
+
f"{BASE_API_URL}/chat/sessions/{session_id}/messages",
|
| 200 |
+
headers=headers, timeout=30
|
| 201 |
+
)
|
| 202 |
+
if response.status_code == 200:
|
| 203 |
+
st.session_state.messages = response.json()
|
| 204 |
+
elif response.status_code == 404:
|
| 205 |
+
# Session doesn't belong to this user — clear silently
|
| 206 |
+
st.session_state.messages = []
|
| 207 |
+
st.session_state.active_session_id = None
|
| 208 |
+
st.warning("Session not found.")
|
| 209 |
+
else:
|
| 210 |
+
st.session_state.messages = []
|
| 211 |
+
st.error(f"Failed to load messages: {response.status_code}")
|
| 212 |
+
except Exception as e:
|
| 213 |
+
st.error(f"Error loading messages: {e}")
|
| 214 |
+
st.session_state.messages = []
|
| 215 |
+
|
| 216 |
|
| 217 |
def get_active_session():
|
| 218 |
+
for s in st.session_state.chat_sessions:
|
| 219 |
+
if s["session_id"] == st.session_state.active_session_id:
|
| 220 |
+
return s
|
| 221 |
+
return None
|
| 222 |
+
|
| 223 |
|
| 224 |
def add_message(role, content, sent_to_api=False, **kwargs):
|
| 225 |
+
"""Optimistic UI update only — persistence happens in the backend via /solve."""
|
| 226 |
msg = {"role": role, "content": content, "timestamp": time.time(), "sent_to_api": sent_to_api}
|
| 227 |
msg.update(kwargs)
|
| 228 |
+
st.session_state.messages.append(msg)
|
| 229 |
+
|
| 230 |
|
| 231 |
def new_chat():
|
| 232 |
+
try:
|
| 233 |
+
headers = get_auth_headers()
|
| 234 |
+
response = requests.post(f"{BASE_API_URL}/chat/sessions", headers=headers, timeout=30)
|
| 235 |
+
if response.status_code == 200:
|
| 236 |
+
new_s = response.json()
|
| 237 |
+
st.session_state.active_session_id = new_s["session_id"]
|
| 238 |
+
st.session_state.messages = []
|
| 239 |
+
load_sessions()
|
| 240 |
+
st.rerun()
|
| 241 |
+
else:
|
| 242 |
+
st.error("Failed to create new chat")
|
| 243 |
+
except Exception as e:
|
| 244 |
+
st.error(f"Error: {e}")
|
| 245 |
+
|
| 246 |
|
| 247 |
def delete_chat(sid):
|
| 248 |
+
try:
|
| 249 |
+
headers = get_auth_headers()
|
| 250 |
+
response = requests.delete(f"{BASE_API_URL}/chat/sessions/{sid}", headers=headers, timeout=30)
|
| 251 |
+
if response.status_code == 200:
|
| 252 |
+
if st.session_state.active_session_id == sid:
|
| 253 |
+
st.session_state.active_session_id = None
|
| 254 |
+
st.session_state.messages = []
|
| 255 |
+
load_sessions()
|
| 256 |
+
st.rerun()
|
| 257 |
+
else:
|
| 258 |
+
st.error("Failed to delete chat")
|
| 259 |
+
except Exception as e:
|
| 260 |
+
st.error(f"Error: {e}")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def rename_chat(sid, new_title):
|
| 264 |
+
try:
|
| 265 |
+
headers = get_auth_headers()
|
| 266 |
+
response = requests.patch(
|
| 267 |
+
f"{BASE_API_URL}/chat/sessions/{sid}",
|
| 268 |
+
headers=headers, json={"title": new_title}, timeout=30
|
| 269 |
+
)
|
| 270 |
+
if response.status_code == 200:
|
| 271 |
+
load_sessions()
|
| 272 |
+
st.rerun()
|
| 273 |
+
else:
|
| 274 |
+
st.error("Failed to rename chat")
|
| 275 |
+
except Exception as e:
|
| 276 |
+
st.error(f"Error: {e}")
|
| 277 |
+
|
| 278 |
|
| 279 |
# ====================================================
|
| 280 |
# Login Screen
|
|
|
|
| 299 |
password = st.text_input("Password", type="password")
|
| 300 |
if st.form_submit_button("Sign In", use_container_width=True):
|
| 301 |
if email and password:
|
| 302 |
+
try:
|
| 303 |
+
r = requests.post(
|
| 304 |
+
f"{BASE_API_URL}/auth/login",
|
| 305 |
+
json={"email": email, "password": password},
|
| 306 |
+
timeout=30
|
| 307 |
+
)
|
| 308 |
+
if r.status_code == 200:
|
| 309 |
+
d = r.json()
|
| 310 |
+
# ✅ MULTIUSER FIX: Clear ALL previous user data
|
| 311 |
+
# BEFORE setting the new user identity.
|
| 312 |
+
_clear_user_state()
|
| 313 |
+
st.session_state.user = {
|
| 314 |
+
"email": d["email"],
|
| 315 |
+
"token": d["access_token"],
|
| 316 |
+
"uid": d["user_id"]
|
| 317 |
+
}
|
| 318 |
+
st.success(f"Welcome back, {d['email']}!")
|
| 319 |
+
time.sleep(0.5)
|
| 320 |
+
st.rerun()
|
| 321 |
+
else:
|
| 322 |
+
st.error(f"Login Failed: {r.json().get('detail', 'Unknown error')}")
|
| 323 |
+
except Exception as e:
|
| 324 |
+
st.error(f"Connection Error: {e}")
|
| 325 |
else:
|
| 326 |
st.error("Please enter email and password.")
|
| 327 |
|
|
|
|
| 330 |
new_email = st.text_input("New Email", placeholder="new@student.edu")
|
| 331 |
new_password = st.text_input("New Password", type="password")
|
| 332 |
confirm_password = st.text_input("Confirm Password", type="password")
|
| 333 |
+
full_name = st.text_input("Full Name", placeholder="Optional")
|
| 334 |
if st.form_submit_button("Create Account", use_container_width=True):
|
| 335 |
if new_email and new_password:
|
| 336 |
if new_password != confirm_password:
|
| 337 |
st.error("Passwords do not match!")
|
| 338 |
else:
|
| 339 |
+
try:
|
| 340 |
+
r = requests.post(
|
| 341 |
+
f"{BASE_API_URL}/auth/signup",
|
| 342 |
+
json={
|
| 343 |
+
"email": new_email,
|
| 344 |
+
"password": new_password,
|
| 345 |
+
"full_name": full_name
|
| 346 |
+
},
|
| 347 |
+
timeout=30
|
| 348 |
+
)
|
| 349 |
+
if r.status_code == 200:
|
| 350 |
+
d = r.json()
|
| 351 |
+
# ✅ MULTIUSER FIX: Same as login — clear first
|
| 352 |
+
_clear_user_state()
|
| 353 |
+
st.session_state.user = {
|
| 354 |
+
"email": d["email"],
|
| 355 |
+
"token": d["access_token"],
|
| 356 |
+
"uid": d["user_id"]
|
| 357 |
+
}
|
| 358 |
+
st.success(f"Account Created! Welcome, {d['email']}!")
|
| 359 |
+
time.sleep(0.5)
|
| 360 |
+
st.rerun()
|
| 361 |
+
else:
|
| 362 |
+
st.error(f"Sign Up Failed: {r.json().get('detail', 'Unknown error')}")
|
| 363 |
+
except Exception as e:
|
| 364 |
+
st.error(f"Connection Error: {e}")
|
| 365 |
else:
|
| 366 |
st.error("Please fill all fields.")
|
| 367 |
|
| 368 |
+
st.markdown(
|
| 369 |
+
"<p style='text-align:center;font-size:0.8rem;color:#6b7280;'>Powered by Gemini & SymPy</p>",
|
| 370 |
+
unsafe_allow_html=True
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
|
| 374 |
# ── Auth gate ─────────────────────────────────────────────────────────────────
|
| 375 |
if not st.session_state.user:
|
| 376 |
login_screen()
|
| 377 |
st.stop()
|
| 378 |
|
| 379 |
+
# ====================================================
|
| 380 |
+
# ✅ MULTIUSER FIX — Per-rerun data isolation check
|
| 381 |
+
# ====================================================
|
| 382 |
+
# At this point we know a user IS logged in.
|
| 383 |
+
# Check: is the data currently in state actually for THIS user?
|
| 384 |
+
# This handles the scenario where User A's browser tab is reused by User B
|
| 385 |
+
# (e.g. token swap, shared kiosk, etc.) without a full page reload.
|
| 386 |
+
_current_uid = st.session_state.user["uid"]
|
| 387 |
+
if st.session_state.loaded_for_user != _current_uid:
|
| 388 |
+
# Data in state belongs to a different user (or nobody) — reload for current user
|
| 389 |
+
_clear_user_state()
|
| 390 |
+
load_sessions()
|
| 391 |
+
# loaded_for_user is set inside load_sessions() on success
|
| 392 |
+
|
| 393 |
+
|
| 394 |
# ====================================================
|
| 395 |
# Profile Interface
|
| 396 |
# ====================================================
|
| 397 |
def profile_interface():
|
| 398 |
st.title("👤 User Profile")
|
| 399 |
st.markdown("Customize your MathMinds experience.")
|
| 400 |
+
headers = get_auth_headers()
|
| 401 |
|
| 402 |
if "profile_data" not in st.session_state:
|
| 403 |
try:
|
| 404 |
+
r = requests.get(f"{BASE_API_URL}/users/profile", headers=headers, timeout=30)
|
| 405 |
st.session_state.profile_data = r.json() if r.status_code == 200 else {}
|
| 406 |
except Exception:
|
| 407 |
st.session_state.profile_data = {}
|
|
|
|
| 412 |
|
| 413 |
with st.form("profile_form"):
|
| 414 |
display_name = st.text_input("Display Name", value=data.get("display_name", ""))
|
| 415 |
+
math_level = st.selectbox(
|
| 416 |
+
"Math Proficiency Level", levels,
|
| 417 |
+
index=levels.index(data.get("math_level", "Undergraduate"))
|
| 418 |
+
if data.get("math_level") in levels else 1
|
| 419 |
+
)
|
| 420 |
+
interests = st.multiselect(
|
| 421 |
+
"Areas of Interest", interests_all,
|
| 422 |
+
default=[i for i in data.get("interests", []) if i in interests_all]
|
| 423 |
+
)
|
| 424 |
if st.form_submit_button("Save Profile", use_container_width=True, type="primary"):
|
| 425 |
payload = {"display_name": display_name, "math_level": math_level, "interests": interests}
|
| 426 |
try:
|
| 427 |
+
r = requests.post(f"{BASE_API_URL}/users/profile", json=payload, headers=headers)
|
| 428 |
if r.status_code == 200:
|
| 429 |
st.success("Profile updated!")
|
| 430 |
st.session_state.profile_data = payload
|
| 431 |
+
time.sleep(1)
|
| 432 |
+
st.rerun()
|
| 433 |
else:
|
| 434 |
st.error(f"Update failed: {r.text}")
|
| 435 |
except Exception as e:
|
| 436 |
st.error(f"Error saving: {e}")
|
| 437 |
|
| 438 |
+
|
| 439 |
# ====================================================
|
| 440 |
# Chat Interface
|
| 441 |
# ====================================================
|
| 442 |
def chat_interface():
|
| 443 |
+
if not st.session_state.active_session_id:
|
| 444 |
+
if st.session_state.chat_sessions:
|
| 445 |
+
st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"]
|
| 446 |
+
load_messages(st.session_state.active_session_id)
|
| 447 |
+
else:
|
| 448 |
+
new_chat()
|
| 449 |
+
return
|
| 450 |
|
| 451 |
+
active_sess = get_active_session()
|
| 452 |
+
st.title(active_sess["title"] if active_sess else "Chat")
|
| 453 |
|
| 454 |
# ── 1. Render history ─────────────────────────────────────────────────────
|
| 455 |
+
for msg in st.session_state.messages:
|
| 456 |
+
role = msg["role"]
|
| 457 |
+
with st.chat_message(role, avatar="👤" if role == "user" else "🤖"):
|
| 458 |
+
if role == "user":
|
| 459 |
if msg.get("image_data"):
|
| 460 |
+
try:
|
| 461 |
+
st.image(base64.b64decode(msg["image_data"]), width=300)
|
| 462 |
+
except Exception:
|
| 463 |
+
pass
|
| 464 |
st.write(msg["content"])
|
| 465 |
+
else:
|
|
|
|
| 466 |
meta = msg.get("metadata", {})
|
| 467 |
if meta:
|
| 468 |
badges = ""
|
|
|
|
| 473 |
badges += '<span class="badge badge-blue">💾 CACHED</span>'
|
| 474 |
elif src in ("google_adk_agent", "agent"):
|
| 475 |
badges += '<span class="badge badge-purple">🤖 AGENT</span>'
|
| 476 |
+
model = meta.get("model_used") or meta.get("model")
|
| 477 |
if model:
|
| 478 |
badges += f'<span class="badge" style="background:rgba(255,255,255,0.1);">{model}</span>'
|
| 479 |
if badges:
|
| 480 |
st.markdown(badges, unsafe_allow_html=True)
|
| 481 |
|
| 482 |
+
if msg.get("reasoning") or msg.get("explanation"):
|
|
|
|
| 483 |
with st.expander("Show Reasoning Steps"):
|
| 484 |
+
st.markdown(msg.get("reasoning") or msg.get("explanation"))
|
| 485 |
+
|
| 486 |
+
content = msg["content"]
|
| 487 |
if isinstance(content, dict) and "final_answer" in content:
|
| 488 |
st.markdown(f"**Answer:**\n\n> {content['final_answer']}")
|
| 489 |
else:
|
|
|
|
| 497 |
is_processing = st.session_state.get("is_processing", False)
|
| 498 |
|
| 499 |
with tab_text:
|
| 500 |
+
text_prompt = st.chat_input("Ask a math question...", disabled=is_processing)
|
| 501 |
+
if text_prompt:
|
| 502 |
+
prompt = text_prompt
|
| 503 |
|
| 504 |
with tab_draw:
|
| 505 |
col_canvas, col_controls = st.columns([3, 1])
|
| 506 |
with col_canvas:
|
|
|
|
|
|
|
| 507 |
canvas_result = st_canvas(
|
| 508 |
stroke_width=3, stroke_color="#FFFFFF", background_color="#000000",
|
| 509 |
+
height=300, width=600, drawing_mode="freedraw",
|
| 510 |
+
key=st.session_state.canvas_key,
|
| 511 |
+
)
|
| 512 |
+
draw_prompt_input = st.text_input(
|
| 513 |
+
"Question about drawing (optional)",
|
| 514 |
+
placeholder="Solve this handwritten problem...",
|
| 515 |
+
key="draw_prompt_input"
|
| 516 |
)
|
|
|
|
| 517 |
with col_controls:
|
| 518 |
st.caption("Controls")
|
| 519 |
+
if st.button("Clear"):
|
| 520 |
st.session_state.canvas_key = f"canvas_{uuid.uuid4()}"
|
| 521 |
st.rerun()
|
| 522 |
+
if st.button("Solve", type="primary", disabled=is_processing):
|
| 523 |
if canvas_result.image_data is not None:
|
| 524 |
img = Image.fromarray(canvas_result.image_data.astype("uint8"), "RGBA")
|
| 525 |
bg = Image.new("RGB", img.size, (0, 0, 0))
|
|
|
|
| 527 |
buf = io.BytesIO()
|
| 528 |
bg.save(buf, format="PNG")
|
| 529 |
image_b64 = base64.b64encode(buf.getvalue()).decode()
|
| 530 |
+
prompt = draw_prompt_input or "Solve this handwritten math problem."
|
| 531 |
|
| 532 |
with tab_upload:
|
| 533 |
+
uploaded_file = st.file_uploader("Upload", type=["png", "jpg"], disabled=is_processing)
|
| 534 |
+
upload_prompt_input = st.text_input("Question", placeholder="Analyze...", disabled=is_processing, key="upload_prompt_input")
|
| 535 |
+
if uploaded_file and st.button("Analyze", disabled=is_processing):
|
| 536 |
+
image_b64 = base64.b64encode(uploaded_file.getvalue()).decode()
|
| 537 |
+
prompt = upload_prompt_input or "Analyze this image."
|
| 538 |
|
| 539 |
+
# ── 3. New user message → optimistic UI update + rerun ────────────────────
|
| 540 |
if prompt:
|
| 541 |
req_id = str(uuid.uuid4())
|
| 542 |
add_message("user", prompt, image_data=image_b64, request_id=req_id, sent_to_api=False)
|
| 543 |
st.session_state.is_processing = True
|
| 544 |
st.rerun()
|
| 545 |
|
| 546 |
+
# ── 4. Fire API call if last message is an unsent user message ────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
if (
|
| 548 |
+
st.session_state.messages
|
| 549 |
+
and st.session_state.messages[-1]["role"] == "user"
|
| 550 |
+
and not st.session_state.messages[-1].get("sent_to_api", False)
|
| 551 |
):
|
| 552 |
+
last = st.session_state.messages[-1]
|
| 553 |
current_request_id = last.get("request_id") or str(uuid.uuid4())
|
| 554 |
+
last["request_id"] = current_request_id
|
| 555 |
|
| 556 |
with st.chat_message("assistant", avatar="🤖"):
|
| 557 |
+
status_msg = st.status("Thinking...", expanded=True)
|
| 558 |
+
logic_placeholder = status_msg.empty()
|
| 559 |
+
answer_placeholder = st.empty()
|
| 560 |
+
|
| 561 |
+
full_answer = ""
|
| 562 |
+
logic_trace = []
|
| 563 |
+
|
| 564 |
+
try:
|
| 565 |
+
last["sent_to_api"] = True
|
| 566 |
+
payload = {
|
| 567 |
+
"text": last["content"],
|
| 568 |
+
"image": last.get("image_data"),
|
| 569 |
+
"session_id": st.session_state.active_session_id,
|
| 570 |
+
"request_id": current_request_id,
|
| 571 |
+
}
|
| 572 |
+
headers = get_auth_headers()
|
| 573 |
+
|
| 574 |
+
with requests.post(API_URL, json=payload, headers=headers, stream=True, timeout=360) as r:
|
| 575 |
+
if r.status_code == 200:
|
| 576 |
+
for line in r.iter_lines():
|
| 577 |
+
if line:
|
| 578 |
+
try:
|
| 579 |
+
data = json.loads(line)
|
| 580 |
+
if data["type"] == "thought":
|
| 581 |
+
logic_trace.append(data["content"])
|
| 582 |
+
elif data["type"] == "action":
|
| 583 |
+
logic_trace.append(f"⚙️ {data['content']}")
|
| 584 |
+
elif data["type"] == "observation":
|
| 585 |
+
logic_trace.append(f"👁️ {data['content']}")
|
| 586 |
+
elif data["type"] == "answer":
|
| 587 |
+
full_answer += data["content"]
|
| 588 |
+
answer_placeholder.markdown(full_answer)
|
| 589 |
+
elif data["type"] == "error":
|
| 590 |
+
st.error(data["content"])
|
| 591 |
+
|
| 592 |
+
# Update logic trace UI
|
| 593 |
+
logic_placeholder.markdown("\n".join(logic_trace))
|
| 594 |
+
except Exception:
|
| 595 |
+
continue
|
| 596 |
+
|
| 597 |
+
status_msg.update(label="Solved!", state="complete", expanded=False)
|
| 598 |
+
st.session_state.is_processing = False
|
| 599 |
+
|
| 600 |
+
if full_answer:
|
| 601 |
add_message(
|
| 602 |
+
"assistant",
|
| 603 |
+
full_answer,
|
| 604 |
+
reasoning="\n".join(logic_trace),
|
| 605 |
+
metadata={"source": "agent"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
)
|
| 607 |
+
load_sessions() # Update titles
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
st.rerun()
|
| 609 |
+
elif r.status_code == 401:
|
| 610 |
+
_clear_user_state()
|
| 611 |
+
st.session_state.user = None
|
| 612 |
+
st.error("Session expired. Please log in again.")
|
| 613 |
st.rerun()
|
|
|
|
| 614 |
else:
|
| 615 |
+
st.error(f"Error: {r.status_code}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
st.session_state.is_processing = False
|
|
|
|
| 617 |
|
| 618 |
+
except Exception as e:
|
| 619 |
+
st.error(f"Connection error: {e}")
|
| 620 |
+
st.session_state.is_processing = False
|
| 621 |
+
st.rerun()
|
| 622 |
+
|
| 623 |
|
| 624 |
# ====================================================
|
| 625 |
# Sidebar
|
|
|
|
| 628 |
st.markdown("### 🧠 MathMinds")
|
| 629 |
st.write(f"Logged in as **{st.session_state.user['email']}**")
|
| 630 |
|
| 631 |
+
view = st.radio(
|
| 632 |
+
"Navigation", ["Chat", "Profile"],
|
| 633 |
+
index=0 if st.session_state.current_view == "Chat" else 1
|
| 634 |
+
)
|
| 635 |
if view != st.session_state.current_view:
|
| 636 |
st.session_state.current_view = view
|
| 637 |
st.rerun()
|
| 638 |
|
| 639 |
if st.button("Sign Out", type="secondary"):
|
| 640 |
+
# ✅ MULTIUSER FIX: Wipe ALL user-specific state first, THEN clear identity.
|
| 641 |
+
# Without _clear_user_state() here, the next user to log in on the same
|
| 642 |
+
# browser tab would see User A's chat history briefly before load_sessions
|
| 643 |
+
# returns, because st.session_state persists across logins within a tab.
|
| 644 |
+
_clear_user_state()
|
| 645 |
st.session_state.user = None
|
| 646 |
st.rerun()
|
| 647 |
|
| 648 |
+
if st.session_state.is_processing:
|
| 649 |
+
if st.button(
|
| 650 |
+
"🔓 Reset Processing Lock", type="primary",
|
| 651 |
+
help="Use if UI is stuck despite answer finishing."
|
| 652 |
+
):
|
| 653 |
+
st.session_state.is_processing = False
|
| 654 |
+
st.rerun()
|
| 655 |
+
|
| 656 |
st.divider()
|
| 657 |
|
| 658 |
if st.session_state.current_view == "Chat":
|
|
|
|
| 660 |
new_chat()
|
| 661 |
|
| 662 |
st.markdown("#### History")
|
| 663 |
+
|
| 664 |
+
for session in st.session_state.chat_sessions:
|
| 665 |
+
sid = session["session_id"]
|
| 666 |
+
title = session["title"]
|
| 667 |
+
|
| 668 |
+
cols = st.columns([0.8, 0.1, 0.1])
|
| 669 |
+
with cols[0]:
|
| 670 |
+
is_active = (st.session_state.active_session_id == sid)
|
| 671 |
+
btn_type = "primary" if is_active else "secondary"
|
| 672 |
+
if st.button(title, key=f"sel_{sid}", use_container_width=True, type=btn_type):
|
|
|
|
|
|
|
| 673 |
st.session_state.active_session_id = sid
|
| 674 |
+
load_messages(sid)
|
| 675 |
+
st.rerun()
|
| 676 |
+
with cols[1]:
|
| 677 |
+
if st.button("🖊️", key=f"ren_{sid}", help="Rename"):
|
| 678 |
+
st.session_state.renaming_session_id = (
|
| 679 |
+
sid if st.session_state.renaming_session_id != sid else None
|
| 680 |
+
)
|
| 681 |
st.rerun()
|
| 682 |
+
with cols[2]:
|
| 683 |
+
if st.button("🗑️", key=f"del_{sid}", help="Delete"):
|
| 684 |
delete_chat(sid)
|
| 685 |
|
| 686 |
+
if st.session_state.renaming_session_id == sid:
|
| 687 |
+
with st.container():
|
| 688 |
+
new_title = st.text_input(
|
| 689 |
+
"New title", value=title,
|
| 690 |
+
key=f"in_{sid}", label_visibility="collapsed"
|
| 691 |
+
)
|
| 692 |
+
if st.button("Save", key=f"save_{sid}", use_container_width=True):
|
| 693 |
+
rename_chat(sid, new_title)
|
| 694 |
+
st.session_state.renaming_session_id = None
|
| 695 |
+
st.rerun()
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
# ====================================================
|
| 699 |
+
# Main Content Area
|
| 700 |
+
# ====================================================
|
| 701 |
+
if st.session_state.current_view == "Chat":
|
| 702 |
+
chat_interface()
|
| 703 |
+
elif st.session_state.current_view == "Profile":
|
| 704 |
+
profile_interface()
|
requirements.txt
CHANGED
|
@@ -33,6 +33,9 @@ pandas
|
|
| 33 |
gradio
|
| 34 |
transformers
|
| 35 |
torch
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
selenium
|
| 38 |
webdriver-manager
|
|
|
|
| 33 |
gradio
|
| 34 |
transformers
|
| 35 |
torch
|
| 36 |
+
passlib[bcrypt]
|
| 37 |
+
bcrypt
|
| 38 |
+
PyJWT
|
| 39 |
|
| 40 |
selenium
|
| 41 |
webdriver-manager
|