ghadgemadhuri92 commited on
Commit
c14a92e
·
1 Parent(s): c00b41f

Offloaded persistence tasks to the background

Browse files
app/agents/adk_mathminds.py CHANGED
@@ -1,18 +1,8 @@
1
- """
2
- adk_mathminds.py — MathMinds ADK Agent
3
- Key changes vs original:
4
- 1. Semaphore removed. Replaced with Redis-backed daily quota via llm_guard.
5
- 2. Tenacity retries scoped to 429/rate-limit errors ONLY (not all exceptions),
6
- so a quota block is not retried.
7
- 3. ADK event loop now filters for is_final_response() to avoid
8
- collecting tool-call intermediate text.
9
- 4. Redis client injected via constructor so it can be shared with CacheManager.
10
- """
11
-
12
  import logging
13
  import asyncio
14
  import base64
15
- from typing import Optional
 
16
 
17
  from google.adk.agents import Agent
18
  from google.adk.runners import Runner
@@ -22,12 +12,11 @@ from google.genai.errors import ClientError
22
  from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
23
 
24
  from app.core.settings import settings
25
- from app.core.llm_guard import check_and_increment # ← new import
26
  from app.tools.web_scraper import WebScraper
27
  from app.tools.symbolic_solver import SymbolicSolver
28
  from app.tools.similarity_search import SimilarProblemFinder
29
- from app.core.ocr import OCRProcessor
30
- from app.tools.vision_analyzer import VisionAnalyzer
31
  from app.core.math_normalizer import MathQueryNormalizer
32
 
33
  logger = logging.getLogger(__name__)
@@ -36,12 +25,12 @@ logger = logging.getLogger(__name__)
36
  class MathMindsADKAgent:
37
  """
38
  Agent-based architecture using Google ADK.
39
- Uses a Redis-backed daily quota (not a semaphore) to stay within 20 RPD.
40
  """
41
 
42
- def __init__(self, model_name: str = "gemini-2.5-flash", redis_client=None):
43
  self.api_key = settings.GOOGLE_API_KEY
44
- self.redis_client = redis_client # injected — shared with CacheManager
45
 
46
  if not self.api_key:
47
  logger.warning("No Google API Key found. Agent will fail.")
@@ -51,8 +40,7 @@ class MathMindsADKAgent:
51
  self.symbolic_solver = SymbolicSolver()
52
  self.normalizer = MathQueryNormalizer()
53
  self.similar_finder = SimilarProblemFinder()
54
- self.ocr = OCRProcessor()
55
- self.vision_analyzer = VisionAnalyzer()
56
 
57
  # ── Tool definitions ──────────────────────────────────────────────────
58
  async def web_search(query: str) -> str:
@@ -79,6 +67,18 @@ class MathMindsADKAgent:
79
  return result.get("content", "No solution found.")
80
  return f"Error solving math: {result.get('error')}"
81
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def find_similar_problems(query: str) -> str:
83
  """
84
  Find similar solved problems from the database for reference.
@@ -93,41 +93,18 @@ class MathMindsADKAgent:
93
  formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
94
  return formatted
95
 
96
- def read_image(image_data: str) -> str:
97
- """
98
- Extract text/equations from an image using OCR.
99
- Args:
100
- image_data: Base64 string of the image.
101
- """
102
- try:
103
- text = self.ocr.extract_text(image_data=image_data)
104
- return text if text else "No text found in image."
105
- except Exception as e:
106
- return f"Error reading image: {str(e)}"
107
-
108
- async def analyze_image(image_data: str, focus: str = "") -> str:
109
- """
110
- Analyze an image mathematically: count objects, describe graphs, extract equations.
111
- Args:
112
- image_data: Base64 string of the image.
113
- focus: Optional focus hint (e.g. 'count red balls').
114
- """
115
- try:
116
- result = self.vision_analyzer.analyze(image_data, focus)
117
- return str(result)
118
- except Exception as e:
119
- return f"Image analysis failed: {str(e)}"
120
-
121
  # ── Agent & Runner ─────────────────────────────────────────────────��──
122
  self.agent = Agent(
123
  name="math_minds_core",
124
  model=model_name,
125
- tools=[web_search, math_solver, find_similar_problems, read_image, analyze_image],
126
  instruction=(
127
  "You are MathMinds AI, a precise mathematical assistant. "
128
- "When an image is provided, analyze it first — extract equations, "
129
- "count objects, or interpret graphs. Then combine image analysis with "
130
- "the text prompt. Use tools only when needed. Show your reasoning clearly."
 
 
131
  )
132
  )
133
 
@@ -138,7 +115,7 @@ class MathMindsADKAgent:
138
  session_service=self.session_service
139
  )
140
 
141
- logger.info("MathMindsADKAgent initialized.")
142
 
143
  async def solve(
144
  self,
@@ -146,22 +123,17 @@ class MathMindsADKAgent:
146
  image_data: Optional[str] = None,
147
  session_id: str = "default_session",
148
  user_id: str = "default_user"
149
- ) -> str:
150
  """
151
- Main entry point. Enforces daily quota before calling the LLM.
152
- Returns the agent's answer string, or an error message.
153
  """
154
 
155
  # ── 1. Daily quota check ──────────────────────────────────────────────
156
- # This is the ONLY gate. One check per user question = one LLM call.
157
  if self.redis_client:
158
  allowed, used, limit = check_and_increment(self.redis_client, user_id)
159
  if not allowed:
160
- logger.warning(f"Quota blocked for user={user_id} ({used}/{limit} today)")
161
- return (
162
- f"⚠️ Daily limit reached ({limit} questions per day). "
163
- "Please try again tomorrow."
164
- )
165
  else:
166
  logger.warning("Redis unavailable — skipping quota check (failing open).")
167
 
@@ -174,66 +146,56 @@ class MathMindsADKAgent:
174
  await self.session_service.create_session(
175
  app_name="mathminds", user_id=user_id, session_id=session_id
176
  )
177
- except Exception:
178
- try:
179
- await self.session_service.create_session(
180
- app_name="mathminds", user_id=user_id, session_id=session_id
181
- )
182
- except Exception as e:
183
- logger.warning(f"Session create warning: {e}")
184
 
185
  # ── 3. Build message parts ────────────────────────────────────────────
186
- parts = [types.Part.from_text(text=problem)]
 
 
 
 
187
 
188
  if image_data:
189
  try:
190
- if image_data.startswith("/9j/"):
191
- mime_type = "image/jpeg"
192
- elif image_data.startswith("iVBORw"):
193
- mime_type = "image/png"
194
- elif image_data.startswith("R0lGOD"):
195
- mime_type = "image/gif"
196
- elif image_data.startswith("UklGR"):
197
- mime_type = "image/webp"
198
- else:
199
- mime_type = "image/png"
200
-
201
  img_bytes = base64.b64decode(image_data)
 
 
 
 
 
202
  parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
203
- logger.info("Image attached to agent request.")
204
  except Exception as e:
205
- logger.error(f"Failed to process image: {e}")
206
- parts.append(types.Part.from_text(text="[Error: image could not be processed]"))
207
-
208
- # ── 4. Run agent (retry on 429 only, not all exceptions) ─────────────
209
- @retry(
210
- stop=stop_after_attempt(2), # max 2 attempts total
211
- wait=wait_exponential(multiplier=2, min=5, max=30),
212
- retry=retry_if_exception_type(ClientError), # only retry on API errors
213
- reraise=True
214
- )
215
- async def run_agent_safely() -> str:
216
- outcome = ""
217
  async for event in self.runner.run_async(
218
  user_id=user_id,
219
  session_id=session_id,
220
  new_message=types.Content(role="user", parts=parts)
221
  ):
222
- # Only collect the final response, not tool-call intermediates
223
- if hasattr(event, "is_final_response") and event.is_final_response():
224
- if event.content and event.content.parts:
225
- for part in event.content.parts:
226
- if part.text:
227
- outcome += part.text
228
- return outcome
229
-
230
- try:
231
- response_text = await run_agent_safely()
232
- if not response_text:
233
- logger.warning("Agent returned empty response.")
234
- return "The agent completed but returned no text. Please rephrase your question."
235
- return response_text
 
 
 
 
 
 
236
 
237
  except Exception as e:
238
- logger.error(f"ADK Agent execution failed: {e}")
239
- return f"Error processing request: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
  import asyncio
3
  import base64
4
+ import json
5
+ from typing import Optional, AsyncGenerator, Dict, Any
6
 
7
  from google.adk.agents import Agent
8
  from google.adk.runners import Runner
 
12
  from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
13
 
14
  from app.core.settings import settings
15
+ from app.core.llm_guard import check_and_increment
16
  from app.tools.web_scraper import WebScraper
17
  from app.tools.symbolic_solver import SymbolicSolver
18
  from app.tools.similarity_search import SimilarProblemFinder
19
+ from app.tools.python_executor import PythonInterpreter
 
20
  from app.core.math_normalizer import MathQueryNormalizer
21
 
22
  logger = logging.getLogger(__name__)
 
25
  class MathMindsADKAgent:
26
  """
27
  Agent-based architecture using Google ADK.
28
+ Supports real-time streaming of reasoning steps and final answers.
29
  """
30
 
31
+ def __init__(self, model_name: str = "gemini-2.0-flash", redis_client=None):
32
  self.api_key = settings.GOOGLE_API_KEY
33
+ self.redis_client = redis_client
34
 
35
  if not self.api_key:
36
  logger.warning("No Google API Key found. Agent will fail.")
 
40
  self.symbolic_solver = SymbolicSolver()
41
  self.normalizer = MathQueryNormalizer()
42
  self.similar_finder = SimilarProblemFinder()
43
+ self.python_executor = PythonInterpreter()
 
44
 
45
  # ── Tool definitions ──────────────────────────────────────────────────
46
  async def web_search(query: str) -> str:
 
67
  return result.get("content", "No solution found.")
68
  return f"Error solving math: {result.get('error')}"
69
 
70
+ async def execute_python(code: str) -> str:
71
+ """
72
+ Execute arbitrary Python code for simulations, complex logic, or data analysis.
73
+ Use this when SymPy is too restrictive or you need to run a simulation.
74
+ Args:
75
+ code: The Python code to execute.
76
+ """
77
+ result = await self.python_executor.execute(code)
78
+ if result.get("status") == "success":
79
+ return f"Output:\n{result.get('content')}\nResult: {result.get('result')}"
80
+ return f"Error in Python execution: {result.get('content')}"
81
+
82
  def find_similar_problems(query: str) -> str:
83
  """
84
  Find similar solved problems from the database for reference.
 
93
  formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
94
  return formatted
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  # ── Agent & Runner ─────────────────────────────────────────────────��──
97
  self.agent = Agent(
98
  name="math_minds_core",
99
  model=model_name,
100
+ tools=[web_search, math_solver, execute_python, find_similar_problems],
101
  instruction=(
102
  "You are MathMinds AI, a precise mathematical assistant. "
103
+ "You can see images natively! When an image is provided, examine it "
104
+ "carefully to extract equations, count objects, or interpret graphs. "
105
+ "\n\nCRITICAL: Always start by explaining your step-by-step approach "
106
+ "before using any tools. Your internal monologue should be clear "
107
+ "and explain the reasoning behind your tool choices."
108
  )
109
  )
110
 
 
115
  session_service=self.session_service
116
  )
117
 
118
+ logger.info(f"MathMindsADKAgent initialized with model: {model_name}")
119
 
120
  async def solve(
121
  self,
 
123
  image_data: Optional[str] = None,
124
  session_id: str = "default_session",
125
  user_id: str = "default_user"
126
+ ) -> AsyncGenerator[Dict[str, Any], None]:
127
  """
128
+ Streaming entry point. Yields events as they occur.
 
129
  """
130
 
131
  # ── 1. Daily quota check ──────────────────────────────────────────────
 
132
  if self.redis_client:
133
  allowed, used, limit = check_and_increment(self.redis_client, user_id)
134
  if not allowed:
135
+ yield {"type": "error", "content": f"⚠️ Daily limit reached ({limit} today)."}
136
+ return
 
 
 
137
  else:
138
  logger.warning("Redis unavailable — skipping quota check (failing open).")
139
 
 
146
  await self.session_service.create_session(
147
  app_name="mathminds", user_id=user_id, session_id=session_id
148
  )
149
+ except Exception as e:
150
+ logger.warning(f"Session setup warning: {e}")
 
 
 
 
 
151
 
152
  # ── 3. Build message parts ────────────────────────────────────────────
153
+ parts = []
154
+ if problem:
155
+ parts.append(types.Part.from_text(text=problem))
156
+ else:
157
+ parts.append(types.Part.from_text(text="Analyze this image."))
158
 
159
  if image_data:
160
  try:
 
 
 
 
 
 
 
 
 
 
 
161
  img_bytes = base64.b64decode(image_data)
162
+ mime_type = "image/png" # Default
163
+ # Basic sniff
164
+ if image_data.startswith("/9j/"): mime_type = "image/jpeg"
165
+ elif image_data.startswith("iVBORw"): mime_type = "image/png"
166
+
167
  parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
 
168
  except Exception as e:
169
+ logger.error(f"Image decode failed: {e}")
170
+
171
+ # ── 4. Run agent (Streaming) ──────────────────────────────────────────
172
+ try:
 
 
 
 
 
 
 
 
173
  async for event in self.runner.run_async(
174
  user_id=user_id,
175
  session_id=session_id,
176
  new_message=types.Content(role="user", parts=parts)
177
  ):
178
+ # ── Capture Reasoning / Thoughts ──
179
+ if hasattr(event, "content") and event.content:
180
+ for part in event.content.parts:
181
+ if part.text:
182
+ # We treat intermittent text as reasoning/logic
183
+ yield {"type": "thought", "content": part.text}
184
+
185
+ # ── Capture Tool Usage ──
186
+ if hasattr(event, "tool_call") and event.tool_call:
187
+ yield {
188
+ "type": "action",
189
+ "content": f"Using tool: {event.tool_call.function_call.name}"
190
+ }
191
+
192
+ # ── Capture Tool Response ──
193
+ if hasattr(event, "tool_response") and event.tool_response:
194
+ yield {
195
+ "type": "observation",
196
+ "content": f"Obtained result from {event.tool_response.function_response.name}"
197
+ }
198
 
199
  except Exception as e:
200
+ logger.error(f"Streaming execution failed: {e}")
201
+ yield {"type": "error", "content": str(e)}
app/api/main.py CHANGED
@@ -1,4 +1,6 @@
1
-
 
 
2
  import sys
3
  import asyncio
4
 
@@ -6,18 +8,19 @@ if sys.platform == 'win32':
6
  asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
7
 
8
  import logging
9
- import datetime
10
- from datetime import datetime
11
  import uuid
12
  import sys
 
13
 
14
  from fastapi import FastAPI, HTTPException, status, Depends, Request
15
- from fastapi.responses import JSONResponse
16
  from slowapi import _rate_limit_exceeded_handler
17
  from slowapi.errors import RateLimitExceeded
18
  from app.core.limiter import limiter
19
  from app.core.orchestrator import Orchestrator
20
- from app.core.schemas import SolveRequest, SolveResponse, HealthResponse
 
21
  from app.core.logging_config import configure_logging
22
  from app.core.errors import AppError, ErrorCodes, ERROR_MESSAGES
23
  from app.core.settings import settings # New settings module
@@ -200,7 +203,7 @@ async def health_check():
200
 
201
  return health_status
202
 
203
- @app.post("/solve", response_model=SolveResponse)
204
  @limiter.limit("5/minute")
205
  async def solve_problem(
206
  request: Request,
@@ -243,55 +246,89 @@ async def solve_problem(
243
  logger.warning(f"Redis dedup failed (failing open): {e}")
244
  # If Redis fails, we allow the request to proceed (fail open)
245
 
246
- try:
247
- result = await orchestrator.process_problem(
248
- text=solve_req.effective_text,
249
- image=solve_req.image,
250
- request_id=final_request_id,
251
- model_preference=solve_req.model_preference,
252
- session_id=solve_req.session_id,
253
- user_id=current_user.get("uid")
254
- )
255
-
256
- # Sanitize metadata for public response
257
- public_metadata = result["metadata"].copy()
258
- public_metadata.pop("_internal_debug", None)
259
-
260
- # Map internal result to schema
261
- return SolveResponse(
262
- request_id=result.get("request_id", final_request_id),
263
- status=result["status"],
264
- problem_type=result.get("problem_type", "unknown"),
265
- source=result.get("source", "unknown"),
266
- answer=result.get("answer"),
267
- steps=result.get("steps", []),
268
- explanation=result.get("explanation"),
269
- confidence=result.get("confidence", 0.0),
270
- cached=result.get("cached", False),
271
- error=result.get("error_msg"), # Keep for backward compat if any
272
- error_code=result.get("error_code"),
273
- metadata=public_metadata
274
- )
275
 
276
- except Exception as e:
277
- logger.error(f"[{req_id}] Unhandled error in /solve: {e}")
278
- # Return generic error
279
- return JSONResponse(
280
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
281
- content={
282
- "status": "error",
283
- "error": ERROR_MESSAGES[ErrorCodes.INTERNAL_ERROR],
284
- "error_code": ErrorCodes.INTERNAL_ERROR,
285
- "metadata": {"request_id": req_id}
286
- }
287
- )
288
- finally:
289
- # Cleanup deduplication key
290
- if redis_client:
291
- try:
292
- redis_client.delete(dedup_key)
293
- except Exception as e:
294
- logger.warning(f"Failed to clear dedup key: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  # --- User Profile Endpoints ---
297
  from pydantic import BaseModel
@@ -348,3 +385,64 @@ async def update_profile(
348
  if __name__ == "__main__":
349
  import uvicorn
350
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["DISABLE_MODEL_SOURCE_CHECK"] = "True"
3
+ from typing import Any, Dict, Optional, List
4
  import sys
5
  import asyncio
6
 
 
8
  asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
9
 
10
  import logging
11
+ from datetime import datetime, timezone
 
12
  import uuid
13
  import sys
14
+ import json
15
 
16
  from fastapi import FastAPI, HTTPException, status, Depends, Request
17
+ from fastapi.responses import JSONResponse, StreamingResponse
18
  from slowapi import _rate_limit_exceeded_handler
19
  from slowapi.errors import RateLimitExceeded
20
  from app.core.limiter import limiter
21
  from app.core.orchestrator import Orchestrator
22
+ from app.core.schemas import SolveRequest, SolveResponse, HealthResponse, Message, ChatSession, SessionRename, UserSignup, UserLogin, TokenResponse
23
+ from app.core.auth_utils import hash_password, verify_password, create_access_token
24
  from app.core.logging_config import configure_logging
25
  from app.core.errors import AppError, ErrorCodes, ERROR_MESSAGES
26
  from app.core.settings import settings # New settings module
 
203
 
204
  return health_status
205
 
206
+ @app.post("/solve")
207
  @limiter.limit("5/minute")
208
  async def solve_problem(
209
  request: Request,
 
246
  logger.warning(f"Redis dedup failed (failing open): {e}")
247
  # If Redis fails, we allow the request to proceed (fail open)
248
 
249
+ async def event_generator():
250
+ try:
251
+ async for chunk in orchestrator.process_problem(
252
+ text=solve_req.effective_text,
253
+ image=solve_req.image,
254
+ request_id=final_request_id,
255
+ model_preference=solve_req.model_preference,
256
+ session_id=solve_req.session_id,
257
+ user_id=current_user.get("uid")
258
+ ):
259
+ yield json.dumps(chunk) + "\n"
260
+ except Exception as e:
261
+ logger.error(f"Streaming error: {e}")
262
+ yield json.dumps({"type": "error", "content": "Internal processing error"}) + "\n"
263
+ finally:
264
+ if redis_client:
265
+ try:
266
+ redis_client.delete(dedup_key)
267
+ except Exception:
268
+ pass
 
 
 
 
 
 
 
 
 
269
 
270
+ return StreamingResponse(event_generator(), media_type="application/x-ndjson")
271
+
272
+ # --- Chat History Endpoints ---
273
+
274
+ @app.get("/chat/sessions", response_model=List[ChatSession])
275
+ async def list_chat_sessions(
276
+ current_user: dict = Depends(get_current_user),
277
+ db_manager = Depends(get_db_manager)
278
+ ):
279
+ """List all chat sessions for the current user."""
280
+ return db_manager.list_sessions(current_user["uid"])
281
+
282
+ @app.post("/chat/sessions", response_model=ChatSession)
283
+ async def create_chat_session(
284
+ current_user: dict = Depends(get_current_user),
285
+ db_manager = Depends(get_db_manager)
286
+ ):
287
+ """Create a new chat session."""
288
+ session_id = str(uuid.uuid4())
289
+ title = "New Chat"
290
+ if db_manager.create_session(current_user["uid"], session_id, title):
291
+ return {
292
+ "session_id": session_id,
293
+ "title": title,
294
+ "created_at": datetime.utcnow()
295
+ }
296
+ raise HTTPException(status_code=500, detail="Failed to create session")
297
+
298
+ @app.get("/chat/sessions/{session_id}/messages", response_model=List[Message])
299
+ async def get_session_history(
300
+ session_id: str,
301
+ current_user: dict = Depends(get_current_user),
302
+ db_manager = Depends(get_db_manager)
303
+ ):
304
+ """Get message history for a specific session."""
305
+ history = db_manager.get_chat_history(current_user["uid"], session_id)
306
+ if not history and history != []:
307
+ raise HTTPException(status_code=404, detail="Session not found")
308
+ return history
309
+
310
+ @app.patch("/chat/sessions/{session_id}")
311
+ async def rename_chat_session(
312
+ session_id: str,
313
+ rename_data: SessionRename,
314
+ current_user: dict = Depends(get_current_user),
315
+ db_manager = Depends(get_db_manager)
316
+ ):
317
+ """Rename a chat session."""
318
+ if db_manager.rename_session(current_user["uid"], session_id, rename_data.title):
319
+ return {"status": "success", "title": rename_data.title}
320
+ raise HTTPException(status_code=404, detail="Session not found or rename failed")
321
+
322
+ @app.delete("/chat/sessions/{session_id}")
323
+ async def delete_chat_session(
324
+ session_id: str,
325
+ current_user: dict = Depends(get_current_user),
326
+ db_manager = Depends(get_db_manager)
327
+ ):
328
+ """Delete a chat session."""
329
+ if db_manager.delete_session(current_user["uid"], session_id):
330
+ return {"status": "success", "message": "Session deleted"}
331
+ raise HTTPException(status_code=404, detail="Session not found or delete failed")
332
 
333
  # --- User Profile Endpoints ---
334
  from pydantic import BaseModel
 
385
  if __name__ == "__main__":
386
  import uvicorn
387
  uvicorn.run(app, host="0.0.0.0", port=8000)
388
+ # ── Auth Endpoints ─────────────────────────────────────────────────────────
389
+
390
+ @app.post("/auth/signup", response_model=TokenResponse)
391
+ async def signup(
392
+ user_in: UserSignup,
393
+ db_manager = Depends(get_db_manager)
394
+ ):
395
+ """Register a new user."""
396
+ # Check if user already exists
397
+ existing_user = db_manager.get_user_by_email(user_in.email)
398
+ if existing_user:
399
+ raise HTTPException(
400
+ status_code=status.HTTP_400_BAD_REQUEST,
401
+ detail="User with this email already exists"
402
+ )
403
+
404
+ user_id = str(uuid.uuid4())
405
+ hashed_pw = hash_password(user_in.password)
406
+
407
+ user_dict = {
408
+ "user_id": user_id,
409
+ "email": user_in.email,
410
+ "hashed_password": hashed_pw,
411
+ "full_name": user_in.full_name,
412
+ "created_at": datetime.now(timezone.utc)
413
+ }
414
+
415
+ if db_manager.create_user(user_dict):
416
+ token = create_access_token(data={"sub": user_id, "email": user_in.email})
417
+ return {
418
+ "access_token": token,
419
+ "token_type": "bearer",
420
+ "user_id": user_id,
421
+ "email": user_in.email
422
+ }
423
+
424
+ raise HTTPException(status_code=500, detail="Failed to create user")
425
+
426
+ @app.post("/auth/login", response_model=TokenResponse)
427
+ async def login(
428
+ login_in: UserLogin,
429
+ db_manager = Depends(get_db_manager)
430
+ ):
431
+ """Login and get access token."""
432
+ user = db_manager.get_user_by_email(login_in.email)
433
+ if not user or not verify_password(login_in.password, user["hashed_password"]):
434
+ raise HTTPException(
435
+ status_code=status.HTTP_401_UNAUTHORIZED,
436
+ detail="Incorrect email or password",
437
+ headers={"WWW-Authenticate": "Bearer"},
438
+ )
439
+
440
+ user_id = user["user_id"]
441
+ token = create_access_token(data={"sub": user_id, "email": user["email"]})
442
+
443
+ return {
444
+ "access_token": token,
445
+ "token_type": "bearer",
446
+ "user_id": user_id,
447
+ "email": user["email"]
448
+ }
app/core/orchestrator.py CHANGED
@@ -3,7 +3,8 @@ import time
3
  import hashlib
4
  import json
5
  import re
6
- from typing import Any, Dict, Optional
 
7
 
8
  import sympy
9
  from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
@@ -17,21 +18,12 @@ from app.core.settings import settings
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
- # Transformations that allow "2x" → "2*x" in SymPy
21
  _SYMPY_TRANSFORMATIONS = standard_transformations + (implicit_multiplication_application,)
22
 
23
 
24
  class Orchestrator:
25
  """
26
- Pipeline: Cache → Pre-flight (SymPy) → ADK Agent
27
-
28
- Pre-flight layer solves arithmetic, algebra, derivatives, integrals, and
29
- equations locally with SymPy — ZERO LLM calls. Only queries that cannot
30
- be resolved symbolically are forwarded to the ADK Agent.
31
-
32
- With a 20 RPD quota this means:
33
- - Simple math → instant, free, no quota used
34
- - Complex math → ADK agent, 1-2 LLM calls as needed
35
  """
36
 
37
  def __init__(
@@ -51,9 +43,6 @@ class Orchestrator:
51
  logger.critical(f"Failed to initialize Orchestrator: {e}")
52
  raise
53
 
54
- # ──────────────────────────────────────────────────────────────────────────
55
- # Public entry point
56
- # ──────────────────────────────────────────────────────────────────────────
57
  async def process_problem(
58
  self,
59
  text: Optional[str] = None,
@@ -62,222 +51,173 @@ class Orchestrator:
62
  model_preference: str = "fast",
63
  session_id: Optional[str] = None,
64
  user_id: Optional[str] = None,
65
- ) -> Dict[str, Any]:
66
 
67
  start_time = time.time()
68
  request_id = request_id or "unknown"
69
 
70
  result_schema: Dict[str, Any] = {
71
  "request_id": request_id,
72
- "status": "error",
73
- "source": "google_adk_agent",
74
- "answer": None,
75
- "steps": [],
76
- "explanation": None,
77
- "confidence": 0.0,
78
- "cached": False,
79
- "metadata": {"latency_ms": 0, "model": "sympy_preflight", "tools_used": []},
80
  }
81
 
82
  try:
83
- # ── Step 1: Input processing ───────────────────────────────────────
84
  processed = self.input_processor.process_compound(text_input=text, image_input=image)
85
  if not processed.is_valid:
86
- result_schema["explanation"] = processed.error_message
87
- return self._finalize(result_schema, start_time)
 
 
 
88
 
89
- query = processed.cleaned_content
90
- image_data = processed.metadata.get("image_data")
 
 
 
 
91
 
92
- # ── Step 2: Cache lookup ───────────────────────────────────────────
93
  if settings.ENABLE_CACHE and not image_data:
94
- cache_key = self._make_cache_key(query)
95
- cached = self.cache_manager.get_cached_answer(cache_key)
96
  if cached:
97
- logger.info(f"Cache hit for query: {query[:60]}")
98
- cached["cached"] = True
99
- cached["request_id"] = request_id
100
- return self._finalize(cached, start_time)
 
 
 
 
 
101
  else:
102
  cache_key = None
103
 
104
- # ── Step 3: Pre-flight — try SymPy before touching the LLM ────────
105
- # Only attempted when there is no image (images need vision model).
106
  if not image_data:
107
  preflight_result = self._try_sympy(query)
108
  if preflight_result is not None:
 
 
 
109
  result_schema.update({
110
- "status": "success",
111
- "source": "sympy_preflight",
112
- "answer": preflight_result,
113
- "explanation": "Solved locally by SymPy — no LLM call needed.",
114
- "confidence": 1.0,
115
- "metadata": {"latency_ms": 0, "model": "sympy_preflight", "tools_used": ["sympy"]},
116
  })
117
- logger.info(f"Pre-flight solved: {query[:60]} → {preflight_result[:80]}")
118
-
119
- # Cache and persist
120
- if settings.ENABLE_CACHE and cache_key:
121
- self.cache_manager.set_cached_answer(cache_key, result_schema)
122
- self.db_manager.save_problem({"content": query}, result_schema)
123
- return self._finalize(result_schema, start_time)
124
-
125
- # ── Step 4: ADK Agent (LLM) ────────────────────────────────────────
126
- logger.info("Pre-flight could not solve — routing to ADK Agent")
127
- result_schema["metadata"]["model"] = "gemini-flash-adk"
128
-
129
- try:
130
- agent_response = await self.adk_agent.solve(
131
- problem=query,
132
- image_data=image_data,
133
- session_id=session_id or "default_session",
134
- user_id=user_id,
135
- )
136
- result_schema.update({
137
- "status": "success",
138
- "source": "google_adk_agent",
139
- "answer": agent_response,
140
- "explanation": "Processed by MathMinds ADK Agent.",
141
- "confidence": 1.0,
142
- })
143
- except Exception as e:
144
- logger.error(f"ADK Agent execution failed: {e}")
145
- result_schema["explanation"] = f"Agent Error: {str(e)}"
146
- return self._finalize(result_schema, start_time)
147
 
148
- # ── Step 5: Persist ───────────────────────────────���────────────────
149
- if result_schema["status"] == "success":
150
- if settings.ENABLE_CACHE and cache_key:
151
- self.cache_manager.set_cached_answer(cache_key, result_schema)
152
- self.db_manager.save_problem({"content": query}, result_schema)
153
-
154
- return self._finalize(result_schema, start_time)
155
 
 
 
 
 
156
  except Exception as e:
157
- logger.error(f"Orchestrator Critical Error: {e}")
158
- result_schema["explanation"] = f"Internal Error: {str(e)}"
159
- return self._finalize(result_schema, start_time)
 
 
 
 
 
 
 
 
160
 
161
- # ──────────────────────────────────────────────────────────────────────────
162
- # Pre-flight: pure SymPy resolution, no LLM
163
- # ──────────────────────────────────────────────────────────────────────────
164
  def _try_sympy(self, query: str) -> Optional[str]:
165
- """
166
- Attempt to solve the query with the MathQueryNormalizer + SymPy.
167
- Returns a human-readable answer string, or None if it can't be solved
168
- locally (which means the ADK Agent should handle it).
169
- """
170
  try:
171
  intent: Optional[MathIntent] = self.normalizer.normalize(query)
172
- if intent is None:
173
- return None
174
-
175
- expr_str = self._prep_expr(intent.expression)
176
- target_var = sympy.Symbol(intent.variable or "x")
177
-
178
- if intent.intent == "arithmetic":
179
- return self._solve_arithmetic(expr_str)
180
-
181
- if intent.intent == "equation":
182
- return self._solve_equation(expr_str, target_var)
183
-
184
  if intent.intent == "derivative":
185
- expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
186
- result = sympy.diff(expr, target_var)
187
- return f"d/d{target_var}({intent.expression}) = {sympy.latex(result)}"
188
-
189
  if intent.intent == "integral":
190
- expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
191
- result = sympy.integrate(expr, target_var)
192
- return f"∫({intent.expression}) d{target_var} = {sympy.latex(result)} + C"
193
-
194
- if intent.intent == "limit":
195
- return self._solve_limit(intent, query)
196
-
197
  if intent.intent == "simplification":
198
- expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
199
- result = sympy.simplify(expr)
200
- return f"Simplified: {sympy.latex(result)}"
201
-
202
- except Exception as e:
203
- # Any parse/evaluation error → fall through to agent
204
- logger.debug(f"Pre-flight SymPy failed for '{query}': {e}")
205
-
206
  return None
207
 
208
  def _prep_expr(self, expr: str) -> str:
209
- """Normalise expression string for SymPy."""
210
- expr = expr.replace("^", "**") # ^ → **
211
- expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr) # 2x → 2*x
212
- expr = re.sub(r"\)\s*\(", ")*(", expr) # )( → )*(
213
  return expr.strip()
214
 
215
  def _solve_arithmetic(self, expr_str: str) -> Optional[str]:
216
- """Evaluate a pure arithmetic/algebraic expression."""
217
  try:
218
- expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
219
- result = sympy.simplify(expr)
220
- # If the result is a number, show it plainly; otherwise use LaTeX
221
  if result.is_number:
222
  numeric = float(result)
223
- # Show as integer if it is one
224
- if numeric == int(numeric):
225
- return str(int(numeric))
226
- return f"{numeric:.6g}"
227
  return sympy.latex(result)
228
- except Exception:
229
- return None
230
 
231
  def _solve_equation(self, expr_str: str, var: sympy.Symbol) -> Optional[str]:
232
- """Solve an equation of the form lhs = rhs, or expr = 0."""
233
  try:
234
  parts = expr_str.split("=", 1)
235
  if len(parts) == 2:
236
- lhs = parse_expr(self._prep_expr(parts[0]), transformations=_SYMPY_TRANSFORMATIONS)
237
- rhs = parse_expr(self._prep_expr(parts[1]), transformations=_SYMPY_TRANSFORMATIONS)
238
  solution = sympy.solve(lhs - rhs, var)
239
  else:
240
- expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
241
- solution = sympy.solve(expr, var)
242
-
243
- if not solution:
244
- return "No solution found."
245
- if len(solution) == 1:
246
- return f"{var} = {sympy.latex(solution[0])}"
247
- sols = ", ".join(sympy.latex(s) for s in solution)
248
- return f"{var} ∈ {{{sols}}}"
249
- except Exception:
250
- return None
251
 
252
  def _solve_limit(self, intent: MathIntent, original_query: str) -> Optional[str]:
253
- """Parse and evaluate a limit expression."""
254
  try:
255
- # Limit pattern: "limit of X as Y approaches Z"
256
- match = re.search(
257
- r"limit of\s+(.+?)\s+as\s+(\w+)\s+approaches\s+(.+)",
258
- original_query, re.IGNORECASE
259
- )
260
- if not match:
261
- return None
262
- expr_raw = self._prep_expr(match.group(1))
263
- var_name = match.group(2).strip()
264
- point_raw = self._prep_expr(match.group(3).strip())
265
-
266
- expr = parse_expr(expr_raw, transformations=_SYMPY_TRANSFORMATIONS)
267
- var = sympy.Symbol(var_name)
268
- point = parse_expr(point_raw, transformations=_SYMPY_TRANSFORMATIONS)
269
 
270
- result = sympy.limit(expr, var, point)
271
- return f"lim({var}→{point}) {sympy.latex(expr)} = {sympy.latex(result)}"
272
- except Exception:
273
- return None
274
-
275
- # ──────────────────────────────────────────────────────────────────────────
276
- # Utilities
277
- # ──────────────────────────────────────────────────────────────────────────
278
  def _make_cache_key(self, query: str) -> str:
279
  return hashlib.sha256(query.strip().lower().encode()).hexdigest()
280
-
281
- def _finalize(self, schema: Dict, start_time: float) -> Dict:
282
- schema["metadata"]["latency_ms"] = int((time.time() - start_time) * 1000)
283
- return schema
 
3
  import hashlib
4
  import json
5
  import re
6
+ import asyncio
7
+ from typing import Any, Dict, Optional, AsyncGenerator
8
 
9
  import sympy
10
  from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
21
  _SYMPY_TRANSFORMATIONS = standard_transformations + (implicit_multiplication_application,)
22
 
23
 
24
  class Orchestrator:
25
  """
26
+ Evolved Pipeline: Cache → Pre-flight (SymPy) → Agentic Streaming Loop
 
 
 
 
 
 
 
 
27
  """
28
 
29
  def __init__(
 
43
  logger.critical(f"Failed to initialize Orchestrator: {e}")
44
  raise
45
 
 
 
 
46
  async def process_problem(
47
  self,
48
  text: Optional[str] = None,
 
51
  model_preference: str = "fast",
52
  session_id: Optional[str] = None,
53
  user_id: Optional[str] = None,
54
+ ) -> AsyncGenerator[Dict[str, Any], None]:
55
 
56
  start_time = time.time()
57
  request_id = request_id or "unknown"
58
 
59
  result_schema: Dict[str, Any] = {
60
  "request_id": request_id,
61
+ "status": "success",
62
+ "source": "agent",
63
+ "answer": "",
64
+ "metadata": {"latency_ms": 0, "model": "gemini-2.0-flash", "tools_used": []},
 
 
 
 
65
  }
66
 
67
  try:
68
+ # ── 1. Input processing ───────────────────────────────────────────
69
  processed = self.input_processor.process_compound(text_input=text, image_input=image)
70
  if not processed.is_valid:
71
+ yield {"type": "error", "content": processed.error_message}
72
+ return
73
+
74
+ query = processed.cleaned_content
75
+ image_data = processed.metadata.get("image_data")
76
 
77
+ # Background: Persist user message
78
+ if user_id and session_id:
79
+ asyncio.create_task(self._persist_message(
80
+ user_id=user_id, session_id=session_id, role="user",
81
+ content=text or "Uploaded an image", image_data=image_data
82
+ ))
83
 
84
+ # ── 2. Cache lookup ───────────────────────────────────────────────
85
  if settings.ENABLE_CACHE and not image_data:
86
+ cache_key = self._make_cache_key(query)
87
+ cached = self.cache_manager.get_cached_answer(cache_key)
88
  if cached:
89
+ yield {"type": "thought", "content": "Retrieving answer from memory..."}
90
+ yield {"type": "answer", "content": cached.get("answer")}
91
+ # Background: Persist assistant response
92
+ if user_id and session_id:
93
+ asyncio.create_task(self._persist_message(
94
+ user_id=user_id, session_id=session_id, role="assistant",
95
+ content=cached.get("answer"), metadata=cached.get("metadata")
96
+ ))
97
+ return
98
  else:
99
  cache_key = None
100
 
101
+ # ── 3. Pre-flight (SymPy) ─────────────────────────────────────────
 
102
  if not image_data:
103
  preflight_result = self._try_sympy(query)
104
  if preflight_result is not None:
105
+ yield {"type": "thought", "content": "Calculating result symbolically..."}
106
+ yield {"type": "answer", "content": preflight_result}
107
+
108
  result_schema.update({
109
+ "source": "sympy_preflight",
110
+ "answer": preflight_result,
111
+ "metadata": {"model": "sympy", "tools_used": ["sympy"]}
 
 
 
112
  })
113
+
114
+ self._background_log(query, result_schema, user_id, session_id, cache_key)
115
+ return
116
+
117
+ # ── 4. Agentic Streaming Loop ─────────────────────────────────────
118
+ full_answer = ""
119
+ async for event in self.adk_agent.solve(
120
+ problem=query, image_data=image_data,
121
+ session_id=session_id, user_id=user_id
122
+ ):
123
+ if event["type"] == "thought":
124
+ yield event
125
+ elif event["type"] in ("action", "observation"):
126
+ yield event
127
+ elif event["type"] == "error":
128
+ yield event
129
+ else:
130
+ # Treat everything else as part of the answer
131
+ full_answer += event["content"]
132
+ yield {"type": "answer", "content": event["content"]}
133
+
134
+ # ── 5. Finalize ───────────────────────────────────────────────────
135
+ result_schema["answer"] = full_answer
136
+ result_schema["metadata"]["latency_ms"] = int((time.time() - start_time) * 1000)
137
+
138
+ if full_answer:
139
+ self._background_log(query, result_schema, user_id, session_id, cache_key)
 
 
 
140
 
141
+ except Exception as e:
142
+ logger.error(f"Orchestrator Error: {e}")
143
+ yield {"type": "error", "content": f"Internal Error: {str(e)}"}
 
 
 
 
144
 
145
+ async def _persist_message(self, user_id, session_id, role, content, **kwargs):
146
+ try:
147
+ self.db_manager.create_session(user_id, session_id)
148
+ self.db_manager.save_chat_message(user_id, session_id, role, content, **kwargs)
149
  except Exception as e:
150
+ logger.error(f"Failed to persist message: {e}")
151
+
152
+ def _background_log(self, query, schema, user_id, session_id, cache_key):
153
+ """Fire and forget persistence tasks."""
154
+ asyncio.create_task(self._persist_message(
155
+ user_id=user_id, session_id=session_id, role="assistant",
156
+ content=schema["answer"], metadata=schema["metadata"]
157
+ ))
158
+ if settings.ENABLE_CACHE and cache_key:
159
+ self.cache_manager.set_cached_answer(cache_key, schema)
160
+ self.db_manager.save_problem({"content": query}, schema)
161
 
 
 
 
162
  def _try_sympy(self, query: str) -> Optional[str]:
 
 
 
 
 
163
  try:
164
  intent: Optional[MathIntent] = self.normalizer.normalize(query)
165
+ if intent is None: return None
166
+ expr_str = self._prep_expr(intent.expression)
167
+ target_var = sympy.Symbol(intent.variable or "x")
168
+ if intent.intent == "arithmetic": return self._solve_arithmetic(expr_str)
169
+ if intent.intent == "equation": return self._solve_equation(expr_str, target_var)
 
 
 
 
 
 
 
170
  if intent.intent == "derivative":
171
+ expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
172
+ return f"d/d{target_var}({intent.expression}) = {sympy.latex(sympy.diff(expr, target_var))}"
 
 
173
  if intent.intent == "integral":
174
+ expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
175
+ return f"∫({intent.expression}) d{target_var} = {sympy.latex(sympy.integrate(expr, target_var))} + C"
176
+ if intent.intent == "limit": return self._solve_limit(intent, query)
 
 
 
 
177
  if intent.intent == "simplification":
178
+ expr = parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS)
179
+ return f"Simplified: {sympy.latex(sympy.simplify(expr))}"
180
+ except Exception: pass
 
 
 
 
 
181
  return None
182
 
183
  def _prep_expr(self, expr: str) -> str:
184
+ expr = expr.replace("^", "**")
185
+ expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
186
+ expr = re.sub(r"\)\s*\(", ")*(", expr)
 
187
  return expr.strip()
188
 
189
  def _solve_arithmetic(self, expr_str: str) -> Optional[str]:
 
190
  try:
191
+ result = sympy.simplify(parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS))
 
 
192
  if result.is_number:
193
  numeric = float(result)
194
+ return str(int(numeric)) if numeric == int(numeric) else f"{numeric:.6g}"
 
 
 
195
  return sympy.latex(result)
196
+ except Exception: return None
 
197
 
198
  def _solve_equation(self, expr_str: str, var: sympy.Symbol) -> Optional[str]:
 
199
  try:
200
  parts = expr_str.split("=", 1)
201
  if len(parts) == 2:
202
+ lhs = parse_expr(self._prep_expr(parts[0]), transformations=_SYMPY_TRANSFORMATIONS)
203
+ rhs = parse_expr(self._prep_expr(parts[1]), transformations=_SYMPY_TRANSFORMATIONS)
204
  solution = sympy.solve(lhs - rhs, var)
205
  else:
206
+ solution = sympy.solve(parse_expr(expr_str, transformations=_SYMPY_TRANSFORMATIONS), var)
207
+ if not solution: return "No solution found."
208
+ if len(solution) == 1: return f"{var} = {sympy.latex(solution[0])}"
209
+ return f"{var} ∈ {{{', '.join(sympy.latex(s) for s in solution)}}}"
210
+ except Exception: return None
 
 
 
 
 
 
211
 
212
  def _solve_limit(self, intent: MathIntent, original_query: str) -> Optional[str]:
 
213
  try:
214
+ match = re.search(r"limit of\s+(.+?)\s+as\s+(\w+)\s+approaches\s+(.+)", original_query, re.IGNORECASE)
215
+ if not match: return None
216
+ expr = parse_expr(self._prep_expr(match.group(1)), transformations=_SYMPY_TRANSFORMATIONS)
217
+ var = sympy.Symbol(match.group(2).strip())
218
+ point = parse_expr(self._prep_expr(match.group(3).strip()), transformations=_SYMPY_TRANSFORMATIONS)
219
+ return f"lim({var}→{point}) {sympy.latex(expr)} = {sympy.latex(sympy.limit(expr, var, point))}"
220
+ except Exception: return None
 
 
 
 
 
 
 
221
 
 
 
 
 
 
 
 
 
222
  def _make_cache_key(self, query: str) -> str:
223
  return hashlib.sha256(query.strip().lower().encode()).hexdigest()
 
 
 
 
app/core/schemas.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Any, Dict, Optional, List
2
  from pydantic import BaseModel, Field, model_validator
3
 
@@ -58,3 +59,39 @@ class HealthResponse(BaseModel):
58
  """
59
  status: str
60
  version: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
  from typing import Any, Dict, Optional, List
3
  from pydantic import BaseModel, Field, model_validator
4
 
 
59
  """
60
  status: str
61
  version: str
62
+
63
+ # --- Chat History Schemas ---
64
+
65
+ class Message(BaseModel):
66
+ role: str
67
+ content: str
68
+ timestamp: datetime
69
+ reasoning: Optional[str] = None
70
+ metadata: Dict[str, Any] = {}
71
+ steps: List[str] = []
72
+
73
+ class ChatSession(BaseModel):
74
+ session_id: str
75
+ title: str
76
+ created_at: datetime
77
+ # messages: Optional[List[Message]] = None # Optional for listing
78
+
79
+ class SessionRename(BaseModel):
80
+ title: str = Field(..., min_length=1, max_length=100)
81
+
82
+ # --- Auth Schemas ---
83
+
84
+ class UserSignup(BaseModel):
85
+ email: str
86
+ password: str = Field(..., min_length=8, max_length=72)
87
+ full_name: Optional[str] = None
88
+
89
+ class UserLogin(BaseModel):
90
+ email: str
91
+ password: str = Field(..., max_length=72)
92
+
93
+ class TokenResponse(BaseModel):
94
+ access_token: str
95
+ token_type: str = "bearer"
96
+ user_id: str
97
+ email: str
app/core/security.py CHANGED
@@ -1,31 +1,16 @@
1
  import logging
2
- import firebase_admin
3
- from firebase_admin import credentials, auth
4
  from fastapi import HTTPException, status, Security
5
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
6
  from app.core.settings import settings
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
- # Initialize Firebase Admin SDK
11
- try:
12
- if settings.FIREBASE_CREDENTIALS_PATH:
13
- if not firebase_admin._apps:
14
- cred = credentials.Certificate(settings.FIREBASE_CREDENTIALS_PATH)
15
- firebase_admin.initialize_app(cred)
16
- logger.info("Firebase Admin SDK initialized successfully.")
17
- else:
18
- logger.info("Firebase Admin SDK already initialized.")
19
- else:
20
- logger.warning("FIREBASE_CREDENTIALS_PATH not set. Auth will fail if enabled.")
21
- except Exception as e:
22
- logger.error(f"Failed to initialize Firebase: {e}")
23
-
24
  security = HTTPBearer()
25
 
26
  def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
27
  """
28
- Verifies the Firebase ID token.
29
  Returns the decoded token dict if valid.
30
  """
31
  token = credentials.credentials
@@ -39,16 +24,21 @@ def verify_token(credentials: HTTPAuthorizationCredentials = Security(security))
39
  logger.info(f"Using MOCK AUTH for token: {token}")
40
  return {"uid": "dev_user_123", "email": "dev@mathminds.ai"}
41
 
42
- try:
43
- decoded_token = auth.verify_id_token(token)
44
- return decoded_token
45
- except Exception as e:
46
- logger.warning(f"Auth failed: {e}")
47
- raise HTTPException(
48
- status_code=status.HTTP_401_UNAUTHORIZED,
49
- detail="Invalid authentication credentials",
50
- headers={"WWW-Authenticate": "Bearer"},
51
- )
 
 
 
 
 
52
 
53
  def get_current_user(token: dict = Security(verify_token)):
54
  """
 
1
  import logging
 
 
2
  from fastapi import HTTPException, status, Security
3
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4
  from app.core.settings import settings
5
+ from app.core.auth_utils import decode_access_token
6
 
7
  logger = logging.getLogger(__name__)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  security = HTTPBearer()
10
 
11
  def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
12
  """
13
+ Verifies the Local JWT access token.
14
  Returns the decoded token dict if valid.
15
  """
16
  token = credentials.credentials
 
24
  logger.info(f"Using MOCK AUTH for token: {token}")
25
  return {"uid": "dev_user_123", "email": "dev@mathminds.ai"}
26
 
27
+ # Use local JWT verification
28
+ payload = decode_access_token(token)
29
+ if payload:
30
+ # Map 'sub' from JWT to 'uid' to maintain compatibility with existing code
31
+ return {
32
+ "uid": payload.get("sub"),
33
+ "email": payload.get("email")
34
+ }
35
+
36
+ logger.warning(f"Invalid or expired token provided.")
37
+ raise HTTPException(
38
+ status_code=status.HTTP_401_UNAUTHORIZED,
39
+ detail="Invalid or expired authentication credentials",
40
+ headers={"WWW-Authenticate": "Bearer"},
41
+ )
42
 
43
  def get_current_user(token: dict = Security(verify_token)):
44
  """
app/core/settings.py CHANGED
@@ -40,6 +40,11 @@ class Settings(BaseSettings):
40
  SUPABASE_KEY: Optional[str] = None
41
  WOLFRAM_APP_ID: Optional[str] = None
42
 
 
 
 
 
 
43
  model_config = {
44
  "env_file": ".env",
45
  "case_sensitive": True,
 
40
  SUPABASE_KEY: Optional[str] = None
41
  WOLFRAM_APP_ID: Optional[str] = None
42
 
43
+ # Security
44
+ JWT_SECRET_KEY: str = "super_secret_key_change_me"
45
+ JWT_ALGORITHM: str = "HS256"
46
+ ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7 # 1 Week
47
+
48
  model_config = {
49
  "env_file": ".env",
50
  "case_sensitive": True,
app/memory/database.py CHANGED
@@ -55,10 +55,15 @@ class DatabaseManager:
55
 
56
  self.db = self.client[db_name]
57
  self.collection = self.db["solved_problems"]
 
 
58
 
59
- # Ensure index
60
- index = IndexModel([("hash", ASCENDING)], name="hash_index")
61
- self.collection.create_indexes([index])
 
 
 
62
 
63
  logger.info(f"Successfully connected to MongoDB at {self.mongo_uri} (DB: {db_name})")
64
 
@@ -134,18 +139,20 @@ class DatabaseManager:
134
  logger.error(f"Failed to save problem: {e}")
135
  return False
136
 
137
- def create_session(self, session_id: str, title: str = "New Chat") -> bool:
138
  """
139
- Initialize a new chat session.
 
140
  """
141
- if self.db is None:
142
  return False
143
  try:
144
- self.db["chat_sessions"].update_one(
145
  {"session_id": session_id},
146
  {
147
  "$setOnInsert": {
148
  "session_id": session_id,
 
149
  "title": title,
150
  "created_at": datetime.now(timezone.utc),
151
  "messages": []
@@ -154,78 +161,151 @@ class DatabaseManager:
154
  upsert=True
155
  )
156
  return True
157
- except PyMongoError as e:
158
- logger.error(f"Failed to create session {session_id}: {e}")
 
 
 
 
159
  return False
160
 
161
- def get_chat_history(self, session_id: str, limit: int = 10) -> List[Dict[str, Any]]:
162
  """
163
- Retrieve recent messages for a session.
164
  """
165
- if self.db is None:
166
  return []
167
  try:
168
- # Get the session document with sliced messages
169
- doc = self.db["chat_sessions"].find_one(
170
- {"session_id": session_id},
171
- {"messages": {"$slice": -limit}}
172
- )
173
- if doc and "messages" in doc:
174
- return doc["messages"]
175
- return []
176
  except PyMongoError as e:
177
- logger.error(f"Failed to get history for {session_id}: {e}")
178
  return []
179
 
180
- def save_chat_message(self, session_id: str, role: str, content: str) -> bool:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  """
182
  Append a message to the session history.
183
- Also updates the session title if it's the first user message.
184
  """
185
- if self.db is None:
186
  return False
187
  try:
188
  # logic to update title if it's currently "New Chat" and this is a user message
189
  if role == "user":
190
- session = self.db["chat_sessions"].find_one({"session_id": session_id})
191
- if session and session.get("title") == "New Chat":
192
- # Generate title from content (truncate)
193
  new_title = content[:50] + "..." if len(content) > 50 else content
194
- self.db["chat_sessions"].update_one(
195
- {"session_id": session_id},
196
  {"$set": {"title": new_title}}
197
  )
198
 
199
  # Push the new message
200
- self.db["chat_sessions"].update_one(
201
- {"session_id": session_id},
202
- {
203
- "$push": {
204
- "messages": {
205
- "role": role,
206
- "content": content,
207
- "timestamp": datetime.now(timezone.utc)
208
- }
209
- }
210
- },
211
- upsert=True
212
  )
213
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  except PyMongoError as e:
215
- logger.error(f"Failed to save message to {session_id}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  return False
217
 
218
  # -------------------------------------------------------------------------
219
  # User Profile Management
220
  # -------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
222
  """
223
- Retrieve user profile by ID (Firebase UID).
224
  """
225
- if self.db is None:
226
  return None
227
  try:
228
- return self.db["users"].find_one({"user_id": user_id})
229
  except PyMongoError as e:
230
  logger.error(f"Failed to get profile for {user_id}: {e}")
231
  return None
@@ -234,10 +314,10 @@ class DatabaseManager:
234
  """
235
  Update or create user profile.
236
  """
237
- if self.db is None:
238
  return False
239
  try:
240
- self.db["users"].update_one(
241
  {"user_id": user_id},
242
  {"$set": {**data, "updated_at": datetime.now(timezone.utc)}},
243
  upsert=True
 
55
 
56
  self.db = self.client[db_name]
57
  self.collection = self.db["solved_problems"]
58
+ self.sessions_collection = self.db["chat_sessions"]
59
+ self.users_collection = self.db["users"]
60
 
61
+ # Ensure indexes
62
+ self.collection.create_index([("hash", ASCENDING)], name="hash_index")
63
+ self.sessions_collection.create_index([("user_id", ASCENDING)], name="user_id_index")
64
+ self.sessions_collection.create_index([("session_id", ASCENDING)], name="session_id_index", unique=True)
65
+ self.users_collection.create_index([("email", ASCENDING)], name="email_index", unique=True)
66
+ self.users_collection.create_index([("user_id", ASCENDING)], name="user_id_index", unique=True)
67
 
68
  logger.info(f"Successfully connected to MongoDB at {self.mongo_uri} (DB: {db_name})")
69
 
 
139
  logger.error(f"Failed to save problem: {e}")
140
  return False
141
 
142
+ def create_session(self, user_id: str, session_id: str, title: str = "New Chat") -> bool:
143
  """
144
+ Create a new chat session for a user.
145
+ Uses upsert with $setOnInsert to be idempotent.
146
  """
147
+ if self.sessions_collection is None:
148
  return False
149
  try:
150
+ self.sessions_collection.update_one(
151
  {"session_id": session_id},
152
  {
153
  "$setOnInsert": {
154
  "session_id": session_id,
155
+ "user_id": user_id,
156
  "title": title,
157
  "created_at": datetime.now(timezone.utc),
158
  "messages": []
 
161
  upsert=True
162
  )
163
  return True
164
+ except Exception as e:
165
+ # If it's a duplicate key error, it means another thread just inserted it.
166
+ # That's fine, we consider the session "created" or at least existing.
167
+ if "E11000" in str(e) or "duplicate key" in str(e).lower():
168
+ return True
169
+ logger.error(f"Failed to create session {session_id} for user {user_id}: {e}")
170
  return False
171
 
172
+ def list_sessions(self, user_id: str) -> List[Dict[str, Any]]:
173
  """
174
+ Retrieve all sessions for a specific user.
175
  """
176
+ if self.sessions_collection is None:
177
  return []
178
  try:
179
+ cursor = self.sessions_collection.find(
180
+ {"user_id": user_id},
181
+ {"session_id": 1, "title": 1, "created_at": 1, "_id": 0}
182
+ ).sort("created_at", -1)
183
+ return list(cursor)
 
 
 
184
  except PyMongoError as e:
185
+ logger.error(f"Failed to list sessions for user {user_id}: {e}")
186
  return []
187
 
188
+ def get_chat_history(self, user_id: str, session_id: str, limit: int = 50) -> Optional[List[Dict[str, Any]]]:
189
+ """
190
+ Retrieve recent messages for a session, ensuring it belongs to the user.
191
+ Returns None if session not found or not owned by user.
192
+ """
193
+ if self.sessions_collection is None:
194
+ return None
195
+ try:
196
+ doc = self.sessions_collection.find_one(
197
+ {"session_id": session_id, "user_id": user_id},
198
+ {"messages": {"$slice": -limit}, "_id": 0}
199
+ )
200
+ if doc is not None:
201
+ return doc.get("messages", [])
202
+ return None
203
+ except PyMongoError as e:
204
+ logger.error(f"Failed to get history for {session_id} (user: {user_id}): {e}")
205
+ return None
206
+
207
+ def save_chat_message(self, user_id: str, session_id: str, role: str, content: str, **kwargs) -> bool:
208
  """
209
  Append a message to the session history.
210
+ Only succeeds if the session belongs to the user_id.
211
  """
212
+ if self.sessions_collection is None:
213
  return False
214
  try:
215
  # logic to update title if it's currently "New Chat" and this is a user message
216
  if role == "user":
217
+ session = self.sessions_collection.find_one({"session_id": session_id, "user_id": user_id})
218
+ if session and (session.get("title") == "New Chat" or session.get("title") == "New Session" or session.get("title") == "Untitled"):
 
219
  new_title = content[:50] + "..." if len(content) > 50 else content
220
+ self.sessions_collection.update_one(
221
+ {"session_id": session_id, "user_id": user_id},
222
  {"$set": {"title": new_title}}
223
  )
224
 
225
  # Push the new message
226
+ msg = {
227
+ "role": role,
228
+ "content": content,
229
+ "timestamp": datetime.now(timezone.utc)
230
+ }
231
+ msg.update(kwargs)
232
+
233
+ result = self.sessions_collection.update_one(
234
+ {"session_id": session_id, "user_id": user_id},
235
+ {"$push": {"messages": msg}}
 
 
236
  )
237
+ # return True if we found the document to update
238
+ return result.matched_count > 0
239
+ except PyMongoError as e:
240
+ logger.error(f"Failed to save message to {session_id} for user {user_id}: {e}")
241
+ return False
242
+
243
+ def delete_session(self, user_id: str, session_id: str) -> bool:
244
+ """
245
+ Delete a session belonging to a user.
246
+ """
247
+ if self.sessions_collection is None:
248
+ return False
249
+ try:
250
+ result = self.sessions_collection.delete_one({"session_id": session_id, "user_id": user_id})
251
+ return result.deleted_count > 0
252
  except PyMongoError as e:
253
+ logger.error(f"Failed to delete session {session_id} for user {user_id}: {e}")
254
+ return False
255
+
256
+ def rename_session(self, user_id: str, session_id: str, new_title: str) -> bool:
257
+ """
258
+ Rename a session belonging to a user.
259
+ """
260
+ if self.sessions_collection is None:
261
+ return False
262
+ try:
263
+ result = self.sessions_collection.update_one(
264
+ {"session_id": session_id, "user_id": user_id},
265
+ {"$set": {"title": new_title}}
266
+ )
267
+ # return True if we found the document to update (even if title was same)
268
+ return result.matched_count > 0
269
+ except PyMongoError as e:
270
+ logger.error(f"Failed to rename session {session_id} for user {user_id}: {e}")
271
  return False
272
 
273
  # -------------------------------------------------------------------------
274
  # User Profile Management
275
  # -------------------------------------------------------------------------
276
+ def create_user(self, user_data: Dict[str, Any]) -> bool:
277
+ """
278
+ Create a new user in the database.
279
+ """
280
+ if self.users_collection is None:
281
+ return False
282
+ try:
283
+ self.users_collection.insert_one(user_data)
284
+ return True
285
+ except PyMongoError as e:
286
+ logger.error(f"Failed to create user: {e}")
287
+ return False
288
+
289
+ def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
290
+ """
291
+ Retrieve a user by email.
292
+ """
293
+ if self.users_collection is None:
294
+ return None
295
+ try:
296
+ return self.users_collection.find_one({"email": email})
297
+ except PyMongoError as e:
298
+ logger.error(f"Failed to get user by email {email}: {e}")
299
+ return None
300
+
301
  def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
302
  """
303
+ Retrieve user profile by ID (Firebase UID or local UID).
304
  """
305
+ if self.users_collection is None:
306
  return None
307
  try:
308
+ return self.users_collection.find_one({"user_id": user_id})
309
  except PyMongoError as e:
310
  logger.error(f"Failed to get profile for {user_id}: {e}")
311
  return None
 
314
  """
315
  Update or create user profile.
316
  """
317
+ if self.users_collection is None:
318
  return False
319
  try:
320
+ self.users_collection.update_one(
321
  {"user_id": user_id},
322
  {"$set": {**data, "updated_at": datetime.now(timezone.utc)}},
323
  upsert=True
frontend/app.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
  import requests
3
- import json
4
  import base64
5
  from PIL import Image
6
  import io
@@ -13,21 +12,35 @@ from dotenv import load_dotenv
13
  load_dotenv()
14
 
15
  # ── Session state: ALL keys initialized ONCE at the very top ─────────────────
16
- # BUG 1 WAS HERE: The original code had TWO `if "user" not in st.session_state`
17
- # blocks one here and one again at line ~4797 after the CSS/config blocks.
18
- # Streamlit re-runs the whole script top-to-bottom on every rerun. On the rerun
19
- # triggered after login, the second block executed and found "user" ALREADY in
20
- # session_state (because we just set it during login), so it was a no-op —
21
- # BUT on a hard browser refresh the two blocks ran in the same execution pass
22
- # and the second one re-initialized user=None, wiping the login state.
23
- # FIX: One single initialization block here, never again below.
24
  if "is_processing" not in st.session_state:
25
  st.session_state.is_processing = False
26
  if "user" not in st.session_state:
27
- st.session_state.user = None
28
  if "current_view" not in st.session_state:
29
  st.session_state.current_view = "Chat"
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # ====================================================
32
  # Page Config — must come before any st.* calls
33
  # ====================================================
@@ -91,59 +104,177 @@ st.markdown("""
91
  """, unsafe_allow_html=True)
92
 
93
  # ====================================================
94
- # Config & State
95
  # ====================================================
96
- API_URL = "http://localhost:8000/solve"
97
- HISTORY_FILE = "chat_history.json"
98
 
99
- if "chat_sessions" not in st.session_state:
100
- if os.path.exists(HISTORY_FILE):
101
- with open(HISTORY_FILE, "r") as f:
102
- st.session_state.chat_sessions = json.load(f)
103
- else:
104
- st.session_state.chat_sessions = {}
105
 
106
- if "active_session_id" not in st.session_state:
107
- sid = str(uuid.uuid4())
108
- st.session_state.chat_sessions[sid] = {"title": "New Session", "messages": [], "created_at": time.time()}
109
- st.session_state.active_session_id = sid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- # ── IMPORTANT: No second "user" init block here. See top of file. ─────────────
112
 
113
  # ====================================================
114
  # Helper Functions
115
  # ====================================================
116
- def save_history():
117
- with open(HISTORY_FILE, "w") as f:
118
- json.dump(st.session_state.chat_sessions, f, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def get_active_session():
121
- return st.session_state.chat_sessions[st.session_state.active_session_id]
 
 
 
 
122
 
123
  def add_message(role, content, sent_to_api=False, **kwargs):
124
- session = get_active_session()
125
  msg = {"role": role, "content": content, "timestamp": time.time(), "sent_to_api": sent_to_api}
126
  msg.update(kwargs)
127
- session["messages"].append(msg)
128
- save_history()
129
 
130
  def new_chat():
131
- sid = str(uuid.uuid4())
132
- st.session_state.chat_sessions[sid] = {"title": "New Session", "messages": [], "created_at": time.time()}
133
- st.session_state.active_session_id = sid
134
- save_history()
135
- st.rerun()
 
 
 
 
 
 
 
 
 
136
 
137
  def delete_chat(sid):
138
- if sid in st.session_state.chat_sessions:
139
- del st.session_state.chat_sessions[sid]
140
- if st.session_state.active_session_id == sid:
141
- if st.session_state.chat_sessions:
142
- st.session_state.active_session_id = list(st.session_state.chat_sessions.keys())[0]
143
- else:
144
- new_chat()
145
- save_history()
146
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  # ====================================================
149
  # Login Screen
@@ -168,23 +299,29 @@ def login_screen():
168
  password = st.text_input("Password", type="password")
169
  if st.form_submit_button("Sign In", use_container_width=True):
170
  if email and password:
171
- api_key = os.getenv("FIREBASE_WEB_API_KEY")
172
- if not api_key:
173
- st.error("Missing FIREBASE_WEB_API_KEY in .env")
174
- else:
175
- try:
176
- url = f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={api_key}"
177
- r = requests.post(url, json={"email": email, "password": password, "returnSecureToken": True}, timeout=30)
178
- if r.status_code == 200:
179
- d = r.json()
180
- st.session_state.user = {"email": d["email"], "token": d["idToken"], "uid": d["localId"]}
181
- st.success(f"Welcome back, {d['email']}!")
182
- time.sleep(0.5)
183
- st.rerun()
184
- else:
185
- st.error(f"Login Failed: {r.json().get('error',{}).get('message','Unknown error')}")
186
- except Exception as e:
187
- st.error(f"Connection Error: {e}")
 
 
 
 
 
 
188
  else:
189
  st.error("Please enter email and password.")
190
 
@@ -193,49 +330,78 @@ def login_screen():
193
  new_email = st.text_input("New Email", placeholder="new@student.edu")
194
  new_password = st.text_input("New Password", type="password")
195
  confirm_password = st.text_input("Confirm Password", type="password")
 
196
  if st.form_submit_button("Create Account", use_container_width=True):
197
  if new_email and new_password:
198
  if new_password != confirm_password:
199
  st.error("Passwords do not match!")
200
  else:
201
- api_key = os.getenv("FIREBASE_WEB_API_KEY")
202
- if not api_key:
203
- st.error("Missing FIREBASE_WEB_API_KEY in .env")
204
- else:
205
- try:
206
- url = f"https://identitytoolkit.googleapis.com/v1/accounts:signUp?key={api_key}"
207
- r = requests.post(url, json={"email": new_email, "password": new_password, "returnSecureToken": True}, timeout=30)
208
- if r.status_code == 200:
209
- d = r.json()
210
- st.session_state.user = {"email": d["email"], "token": d["idToken"], "uid": d["localId"]}
211
- st.success(f"Account Created! Welcome, {d['email']}!")
212
- time.sleep(0.5)
213
- st.rerun()
214
- else:
215
- st.error(f"Sign Up Failed: {r.json().get('error',{}).get('message','Unknown error')}")
216
- except Exception as e:
217
- st.error(f"Connection Error: {e}")
 
 
 
 
 
 
 
 
 
218
  else:
219
  st.error("Please fill all fields.")
220
 
221
- st.markdown("<p style='text-align:center;font-size:0.8rem;color:#6b7280;'>Powered by Gemini & SymPy</p>", unsafe_allow_html=True)
 
 
 
 
222
 
223
  # ── Auth gate ─────────────────────────────────────────────────────────────────
224
  if not st.session_state.user:
225
  login_screen()
226
  st.stop()
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  # ====================================================
229
  # Profile Interface
230
  # ====================================================
231
  def profile_interface():
232
  st.title("👤 User Profile")
233
  st.markdown("Customize your MathMinds experience.")
234
- headers = {"Authorization": f"Bearer {st.session_state.user['token']}"}
235
 
236
  if "profile_data" not in st.session_state:
237
  try:
238
- r = requests.get(f"{API_URL.replace('/solve','')}/users/profile", headers=headers, timeout=30)
239
  st.session_state.profile_data = r.json() if r.status_code == 200 else {}
240
  except Exception:
241
  st.session_state.profile_data = {}
@@ -246,43 +412,57 @@ def profile_interface():
246
 
247
  with st.form("profile_form"):
248
  display_name = st.text_input("Display Name", value=data.get("display_name", ""))
249
- math_level = st.selectbox("Math Proficiency Level", levels,
250
- index=levels.index(data.get("math_level","Undergraduate"))
251
- if data.get("math_level") in levels else 1)
252
- interests = st.multiselect("Areas of Interest", interests_all,
253
- default=[i for i in data.get("interests",[]) if i in interests_all])
 
 
 
 
254
  if st.form_submit_button("Save Profile", use_container_width=True, type="primary"):
255
  payload = {"display_name": display_name, "math_level": math_level, "interests": interests}
256
  try:
257
- r = requests.post(f"{API_URL.replace('/solve','')}/users/profile", json=payload, headers=headers)
258
  if r.status_code == 200:
259
  st.success("Profile updated!")
260
  st.session_state.profile_data = payload
261
- time.sleep(1); st.rerun()
 
262
  else:
263
  st.error(f"Update failed: {r.text}")
264
  except Exception as e:
265
  st.error(f"Error saving: {e}")
266
 
 
267
  # ====================================================
268
  # Chat Interface
269
  # ====================================================
270
  def chat_interface():
271
- if st.session_state.active_session_id not in st.session_state.chat_sessions:
272
- new_chat()
 
 
 
 
 
273
 
274
- st.title(st.session_state.chat_sessions[st.session_state.active_session_id]["title"])
275
- session = get_active_session()
276
 
277
  # ── 1. Render history ─────────────────────────────────────────────────────
278
- for msg in session["messages"]:
279
- if msg["role"] == "user":
280
- with st.chat_message("user", avatar="👤"):
 
281
  if msg.get("image_data"):
282
- st.image(base64.b64decode(msg["image_data"]), width=300)
 
 
 
283
  st.write(msg["content"])
284
- else:
285
- with st.chat_message("assistant", avatar="🤖"):
286
  meta = msg.get("metadata", {})
287
  if meta:
288
  badges = ""
@@ -293,16 +473,17 @@ def chat_interface():
293
  badges += '<span class="badge badge-blue">💾 CACHED</span>'
294
  elif src in ("google_adk_agent", "agent"):
295
  badges += '<span class="badge badge-purple">🤖 AGENT</span>'
296
- model = meta.get("model_used")
297
  if model:
298
  badges += f'<span class="badge" style="background:rgba(255,255,255,0.1);">{model}</span>'
299
  if badges:
300
  st.markdown(badges, unsafe_allow_html=True)
301
 
302
- content = msg["content"]
303
- if msg.get("reasoning"):
304
  with st.expander("Show Reasoning Steps"):
305
- st.markdown(msg["reasoning"])
 
 
306
  if isinstance(content, dict) and "final_answer" in content:
307
  st.markdown(f"**Answer:**\n\n> {content['final_answer']}")
308
  else:
@@ -316,24 +497,29 @@ def chat_interface():
316
  is_processing = st.session_state.get("is_processing", False)
317
 
318
  with tab_text:
319
- prompt = st.chat_input("Ask a math question...", disabled=is_processing)
 
 
320
 
321
  with tab_draw:
322
  col_canvas, col_controls = st.columns([3, 1])
323
  with col_canvas:
324
- if "canvas_key" not in st.session_state:
325
- st.session_state.canvas_key = "main_canvas"
326
  canvas_result = st_canvas(
327
  stroke_width=3, stroke_color="#FFFFFF", background_color="#000000",
328
- height=300, width=600, drawing_mode="freedraw", key=st.session_state.canvas_key,
 
 
 
 
 
 
329
  )
330
- draw_prompt = st.text_input("Question about drawing (optional)", placeholder="Solve this handwritten problem...")
331
  with col_controls:
332
  st.caption("Controls")
333
- if st.button("Clear Canvas"):
334
  st.session_state.canvas_key = f"canvas_{uuid.uuid4()}"
335
  st.rerun()
336
- if st.button("Solve Drawing", type="primary", disabled=is_processing):
337
  if canvas_result.image_data is not None:
338
  img = Image.fromarray(canvas_result.image_data.astype("uint8"), "RGBA")
339
  bg = Image.new("RGB", img.size, (0, 0, 0))
@@ -341,102 +527,99 @@ def chat_interface():
341
  buf = io.BytesIO()
342
  bg.save(buf, format="PNG")
343
  image_b64 = base64.b64encode(buf.getvalue()).decode()
344
- prompt = draw_prompt or "Solve this handwritten math problem."
345
 
346
  with tab_upload:
347
- uploaded = st.file_uploader("Upload Image", type=["png","jpg"], disabled=is_processing)
348
- upload_prompt = st.text_input("Question about image (optional)", placeholder="Analyze this image...", disabled=is_processing)
349
- if uploaded and st.button("Analyze Image", disabled=is_processing):
350
- image_b64 = base64.b64encode(uploaded.getvalue()).decode()
351
- prompt = upload_prompt or "Analyze this image."
352
 
353
- # ── 3. New user message → optimistic write + rerun ────────────────────────
354
  if prompt:
355
  req_id = str(uuid.uuid4())
356
  add_message("user", prompt, image_data=image_b64, request_id=req_id, sent_to_api=False)
357
  st.session_state.is_processing = True
358
  st.rerun()
359
 
360
- # ── 4. Recovery: if we restarted mid-flight, allow retry ──────────────────
361
- if session["messages"] and session["messages"][-1]["role"] == "user":
362
- last = session["messages"][-1]
363
- if last.get("sent_to_api") and not st.session_state.is_processing:
364
- last["sent_to_api"] = False
365
- save_history()
366
-
367
- # ── 5. Fire API call if last message is unsent user message ───────────────
368
  if (
369
- session["messages"]
370
- and session["messages"][-1]["role"] == "user"
371
- and not session["messages"][-1].get("sent_to_api", False)
372
  ):
373
- last = session["messages"][-1]
374
  current_request_id = last.get("request_id") or str(uuid.uuid4())
375
- last["request_id"] = current_request_id
376
 
377
  with st.chat_message("assistant", avatar="🤖"):
378
- with st.spinner("Agent is thinking..."):
379
- try:
380
- last["sent_to_api"] = True
381
- save_history()
382
-
383
- payload = {
384
- "text": last["content"],
385
- "image": last.get("image_data"),
386
- "model_preference": "agent",
387
- "session_id": st.session_state.active_session_id,
388
- "request_id": current_request_id,
389
- }
390
- headers = {}
391
- if st.session_state.user:
392
- headers["Authorization"] = f"Bearer {st.session_state.user['token']}"
393
-
394
- response = requests.post(API_URL, json=payload, headers=headers, timeout=360)
395
-
396
- if response.status_code == 200:
397
- data = response.json()
398
- if data.get("status") == "success":
399
- answer_raw = data.get("answer") or data.get("explanation") or "⚠️ No answer returned."
400
- meta = data.get("metadata", {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  add_message(
402
- "assistant",
403
- answer_raw,
404
- reasoning=data.get("explanation"),
405
- metadata={
406
- "source": data.get("source", "agent"),
407
- "model_used": meta.get("model", "gemini-2.5-flash"),
408
- "latency": f"{meta.get('latency_ms',0)/1000:.2f}s",
409
- "tools": meta.get("tools_used", []),
410
- },
411
- steps=data.get("steps", [])
412
  )
413
- st.session_state.is_processing = False
414
- st.rerun() # ← BUG 2 FIX: missing rerun caused blank UI
415
-
416
- else:
417
- error_msg = data.get("error", "Unknown error")
418
- add_message("assistant", f"⚠️ Error: {error_msg}")
419
- st.session_state.is_processing = False
420
  st.rerun()
421
-
422
- elif response.status_code in [202, 409]:
423
- st.info("ℹ️ Request already processing.")
424
- st.session_state.is_processing = False
425
  st.rerun()
426
-
427
  else:
428
- try:
429
- err_msg = response.json().get("error", f"HTTP {response.status_code}")
430
- except Exception:
431
- err_msg = f"HTTP {response.status_code}"
432
- add_message("assistant", f"❌ Server Error: {err_msg}")
433
  st.session_state.is_processing = False
434
- st.rerun()
435
 
436
- except Exception as e:
437
- add_message("assistant", f"Connection Failed: {str(e)}")
438
- st.session_state.is_processing = False
439
- st.rerun()
 
440
 
441
  # ====================================================
442
  # Sidebar
@@ -445,16 +628,31 @@ with st.sidebar:
445
  st.markdown("### 🧠 MathMinds")
446
  st.write(f"Logged in as **{st.session_state.user['email']}**")
447
 
448
- view = st.radio("Navigation", ["Chat", "Profile"],
449
- index=0 if st.session_state.current_view == "Chat" else 1)
 
 
450
  if view != st.session_state.current_view:
451
  st.session_state.current_view = view
452
  st.rerun()
453
 
454
  if st.button("Sign Out", type="secondary"):
 
 
 
 
 
455
  st.session_state.user = None
456
  st.rerun()
457
 
 
 
 
 
 
 
 
 
458
  st.divider()
459
 
460
  if st.session_state.current_view == "Chat":
@@ -462,26 +660,45 @@ with st.sidebar:
462
  new_chat()
463
 
464
  st.markdown("#### History")
465
- sorted_sids = sorted(
466
- st.session_state.chat_sessions.keys(),
467
- key=lambda k: st.session_state.chat_sessions[k].get("created_at", 0),
468
- reverse=True
469
- )
470
- for sid in sorted_sids:
471
- sess = st.session_state.chat_sessions[sid]
472
- title = sess.get("title", "Untitled")
473
- isActive = (sid == st.session_state.active_session_id)
474
- col_nav, col_del = st.columns([0.85, 0.15])
475
- with col_nav:
476
- if st.button(f"{'📍 ' if isActive else ''}{title}", key=sid, use_container_width=True):
477
  st.session_state.active_session_id = sid
 
 
 
 
 
 
 
478
  st.rerun()
479
- with col_del:
480
- if isActive and st.button("🗑️", key=f"del_{sid}"):
481
  delete_chat(sid)
482
 
483
- # ── Router ────────────────────────────────────────────────────────────────────
484
- if st.session_state.current_view == "Profile":
485
- profile_interface()
486
- else:
487
- chat_interface()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import requests
 
3
  import base64
4
  from PIL import Image
5
  import io
 
12
  load_dotenv()
13
 
14
  # ── Session state: ALL keys initialized ONCE at the very top ─────────────────
15
+ # CRITICAL: These must be the very first st.session_state accesses, before any
16
+ # st.* UI calls. Streamlit re-runs the entire script on every interaction.
 
 
 
 
 
 
17
  if "is_processing" not in st.session_state:
18
  st.session_state.is_processing = False
19
  if "user" not in st.session_state:
20
+ st.session_state.user = None # None = logged out
21
  if "current_view" not in st.session_state:
22
  st.session_state.current_view = "Chat"
23
 
24
+ # MULTIUSER FIX ─ these three keys must be RESET on logout.
25
+ # They are initialized here so first-run doesn't KeyError.
26
+ if "chat_sessions" not in st.session_state:
27
+ st.session_state.chat_sessions = []
28
+ if "active_session_id" not in st.session_state:
29
+ st.session_state.active_session_id = None
30
+ if "messages" not in st.session_state:
31
+ st.session_state.messages = []
32
+
33
+ # MULTIUSER FIX ─ track WHICH user's data is currently loaded.
34
+ # If this doesn't match st.session_state.user["uid"], we know we need to reload.
35
+ if "loaded_for_user" not in st.session_state:
36
+ st.session_state.loaded_for_user = None
37
+
38
+ if "renaming_session_id" not in st.session_state:
39
+ st.session_state.renaming_session_id = None
40
+
41
+ if "canvas_key" not in st.session_state:
42
+ st.session_state.canvas_key = "main_canvas"
43
+
44
  # ====================================================
45
  # Page Config — must come before any st.* calls
46
  # ====================================================
 
104
  """, unsafe_allow_html=True)
105
 
106
  # ====================================================
107
+ # Config
108
  # ====================================================
109
+ BASE_API_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
110
+ API_URL = f"{BASE_API_URL}/solve"
111
 
 
 
 
 
 
 
112
 
113
+ # ====================================================
114
+ # MULTIUSER ISOLATION — Core helper
115
+ # ====================================================
116
+ def _clear_user_state():
117
+ """
118
+ Wipe ALL per-user data from Streamlit session state.
119
+
120
+ Called on logout and whenever a different user logs in.
121
+
122
+ WHY THIS IS THE MOST IMPORTANT FUNCTION FOR MULTIUSER ISOLATION:
123
+ Streamlit's st.session_state is per browser-tab, not per user. If User A
124
+ logs in, chats, then User B logs in on the same tab, all of User A's
125
+ chat_sessions and messages are still sitting in st.session_state. The
126
+ backend correctly refuses to serve User A's data to User B (every DB query
127
+ filters by user_id), but the frontend would still DISPLAY User A's messages
128
+ briefly until the next API call returns. This function prevents that.
129
+ """
130
+ st.session_state.chat_sessions = []
131
+ st.session_state.active_session_id = None
132
+ st.session_state.messages = []
133
+ st.session_state.loaded_for_user = None
134
+ st.session_state.is_processing = False
135
+ st.session_state.current_view = "Chat"
136
+ st.session_state.renaming_session_id = None
137
+ st.session_state.canvas_key = f"canvas_{uuid.uuid4()}"
138
+ # Also clear profile cache if it exists
139
+ if "profile_data" in st.session_state:
140
+ del st.session_state["profile_data"]
141
 
 
142
 
143
  # ====================================================
144
  # Helper Functions
145
  # ====================================================
146
+ def get_auth_headers():
147
+ if st.session_state.user and "token" in st.session_state.user:
148
+ return {"Authorization": f"Bearer {st.session_state.user['token']}"}
149
+ return {}
150
+
151
+
152
+ def load_sessions():
153
+ """Fetch THIS user's chat sessions from the backend and populate state."""
154
+ try:
155
+ headers = get_auth_headers()
156
+ response = requests.get(f"{BASE_API_URL}/chat/sessions", headers=headers, timeout=30)
157
+ if response.status_code == 200:
158
+ st.session_state.chat_sessions = response.json()
159
+ # Mark that we've successfully loaded data for this specific user
160
+ if st.session_state.user:
161
+ st.session_state.loaded_for_user = st.session_state.user["uid"]
162
+ # Auto-select first session if none active
163
+ if not st.session_state.active_session_id and st.session_state.chat_sessions:
164
+ st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"]
165
+ load_messages(st.session_state.active_session_id)
166
+ elif st.session_state.active_session_id and not any(
167
+ s["session_id"] == st.session_state.active_session_id
168
+ for s in st.session_state.chat_sessions
169
+ ):
170
+ # Active session was deleted — pick first or clear
171
+ if st.session_state.chat_sessions:
172
+ st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"]
173
+ load_messages(st.session_state.active_session_id)
174
+ else:
175
+ st.session_state.active_session_id = None
176
+ st.session_state.messages = []
177
+ elif response.status_code == 401:
178
+ # JWT expired — force re-login
179
+ _clear_user_state()
180
+ st.session_state.user = None
181
+ st.error("Session expired. Please log in again.")
182
+ else:
183
+ st.error(f"Failed to load sessions: {response.status_code}")
184
+ st.session_state.chat_sessions = []
185
+ except Exception as e:
186
+ st.error(f"Error loading sessions: {e}")
187
+ st.session_state.chat_sessions = []
188
+
189
+
190
+ def load_messages(session_id):
191
+ """
192
+ Load messages for a session.
193
+ The backend enforces user ownership — it will 404 if session_id
194
+ doesn't belong to the authenticated user, so this is safe.
195
+ """
196
+ try:
197
+ headers = get_auth_headers()
198
+ response = requests.get(
199
+ f"{BASE_API_URL}/chat/sessions/{session_id}/messages",
200
+ headers=headers, timeout=30
201
+ )
202
+ if response.status_code == 200:
203
+ st.session_state.messages = response.json()
204
+ elif response.status_code == 404:
205
+ # Session doesn't belong to this user — clear silently
206
+ st.session_state.messages = []
207
+ st.session_state.active_session_id = None
208
+ st.warning("Session not found.")
209
+ else:
210
+ st.session_state.messages = []
211
+ st.error(f"Failed to load messages: {response.status_code}")
212
+ except Exception as e:
213
+ st.error(f"Error loading messages: {e}")
214
+ st.session_state.messages = []
215
+
216
 
217
  def get_active_session():
218
+ for s in st.session_state.chat_sessions:
219
+ if s["session_id"] == st.session_state.active_session_id:
220
+ return s
221
+ return None
222
+
223
 
224
  def add_message(role, content, sent_to_api=False, **kwargs):
225
+ """Optimistic UI update only — persistence happens in the backend via /solve."""
226
  msg = {"role": role, "content": content, "timestamp": time.time(), "sent_to_api": sent_to_api}
227
  msg.update(kwargs)
228
+ st.session_state.messages.append(msg)
229
+
230
 
231
  def new_chat():
232
+ try:
233
+ headers = get_auth_headers()
234
+ response = requests.post(f"{BASE_API_URL}/chat/sessions", headers=headers, timeout=30)
235
+ if response.status_code == 200:
236
+ new_s = response.json()
237
+ st.session_state.active_session_id = new_s["session_id"]
238
+ st.session_state.messages = []
239
+ load_sessions()
240
+ st.rerun()
241
+ else:
242
+ st.error("Failed to create new chat")
243
+ except Exception as e:
244
+ st.error(f"Error: {e}")
245
+
246
 
247
  def delete_chat(sid):
248
+ try:
249
+ headers = get_auth_headers()
250
+ response = requests.delete(f"{BASE_API_URL}/chat/sessions/{sid}", headers=headers, timeout=30)
251
+ if response.status_code == 200:
252
+ if st.session_state.active_session_id == sid:
253
+ st.session_state.active_session_id = None
254
+ st.session_state.messages = []
255
+ load_sessions()
256
+ st.rerun()
257
+ else:
258
+ st.error("Failed to delete chat")
259
+ except Exception as e:
260
+ st.error(f"Error: {e}")
261
+
262
+
263
+ def rename_chat(sid, new_title):
264
+ try:
265
+ headers = get_auth_headers()
266
+ response = requests.patch(
267
+ f"{BASE_API_URL}/chat/sessions/{sid}",
268
+ headers=headers, json={"title": new_title}, timeout=30
269
+ )
270
+ if response.status_code == 200:
271
+ load_sessions()
272
+ st.rerun()
273
+ else:
274
+ st.error("Failed to rename chat")
275
+ except Exception as e:
276
+ st.error(f"Error: {e}")
277
+
278
 
279
  # ====================================================
280
  # Login Screen
 
299
  password = st.text_input("Password", type="password")
300
  if st.form_submit_button("Sign In", use_container_width=True):
301
  if email and password:
302
+ try:
303
+ r = requests.post(
304
+ f"{BASE_API_URL}/auth/login",
305
+ json={"email": email, "password": password},
306
+ timeout=30
307
+ )
308
+ if r.status_code == 200:
309
+ d = r.json()
310
+ # MULTIUSER FIX: Clear ALL previous user data
311
+ # BEFORE setting the new user identity.
312
+ _clear_user_state()
313
+ st.session_state.user = {
314
+ "email": d["email"],
315
+ "token": d["access_token"],
316
+ "uid": d["user_id"]
317
+ }
318
+ st.success(f"Welcome back, {d['email']}!")
319
+ time.sleep(0.5)
320
+ st.rerun()
321
+ else:
322
+ st.error(f"Login Failed: {r.json().get('detail', 'Unknown error')}")
323
+ except Exception as e:
324
+ st.error(f"Connection Error: {e}")
325
  else:
326
  st.error("Please enter email and password.")
327
 
 
330
  new_email = st.text_input("New Email", placeholder="new@student.edu")
331
  new_password = st.text_input("New Password", type="password")
332
  confirm_password = st.text_input("Confirm Password", type="password")
333
+ full_name = st.text_input("Full Name", placeholder="Optional")
334
  if st.form_submit_button("Create Account", use_container_width=True):
335
  if new_email and new_password:
336
  if new_password != confirm_password:
337
  st.error("Passwords do not match!")
338
  else:
339
+ try:
340
+ r = requests.post(
341
+ f"{BASE_API_URL}/auth/signup",
342
+ json={
343
+ "email": new_email,
344
+ "password": new_password,
345
+ "full_name": full_name
346
+ },
347
+ timeout=30
348
+ )
349
+ if r.status_code == 200:
350
+ d = r.json()
351
+ # ✅ MULTIUSER FIX: Same as login — clear first
352
+ _clear_user_state()
353
+ st.session_state.user = {
354
+ "email": d["email"],
355
+ "token": d["access_token"],
356
+ "uid": d["user_id"]
357
+ }
358
+ st.success(f"Account Created! Welcome, {d['email']}!")
359
+ time.sleep(0.5)
360
+ st.rerun()
361
+ else:
362
+ st.error(f"Sign Up Failed: {r.json().get('detail', 'Unknown error')}")
363
+ except Exception as e:
364
+ st.error(f"Connection Error: {e}")
365
  else:
366
  st.error("Please fill all fields.")
367
 
368
+ st.markdown(
369
+ "<p style='text-align:center;font-size:0.8rem;color:#6b7280;'>Powered by Gemini & SymPy</p>",
370
+ unsafe_allow_html=True
371
+ )
372
+
373
 
374
  # ── Auth gate ─────────────────────────────────────────────────────────────────
375
  if not st.session_state.user:
376
  login_screen()
377
  st.stop()
378
 
379
+ # ====================================================
380
+ # ✅ MULTIUSER FIX — Per-rerun data isolation check
381
+ # ====================================================
382
+ # At this point we know a user IS logged in.
383
+ # Check: is the data currently in state actually for THIS user?
384
+ # This handles the scenario where User A's browser tab is reused by User B
385
+ # (e.g. token swap, shared kiosk, etc.) without a full page reload.
386
+ _current_uid = st.session_state.user["uid"]
387
+ if st.session_state.loaded_for_user != _current_uid:
388
+ # Data in state belongs to a different user (or nobody) — reload for current user
389
+ _clear_user_state()
390
+ load_sessions()
391
+ # loaded_for_user is set inside load_sessions() on success
392
+
393
+
394
  # ====================================================
395
  # Profile Interface
396
  # ====================================================
397
  def profile_interface():
398
  st.title("👤 User Profile")
399
  st.markdown("Customize your MathMinds experience.")
400
+ headers = get_auth_headers()
401
 
402
  if "profile_data" not in st.session_state:
403
  try:
404
+ r = requests.get(f"{BASE_API_URL}/users/profile", headers=headers, timeout=30)
405
  st.session_state.profile_data = r.json() if r.status_code == 200 else {}
406
  except Exception:
407
  st.session_state.profile_data = {}
 
412
 
413
  with st.form("profile_form"):
414
  display_name = st.text_input("Display Name", value=data.get("display_name", ""))
415
+ math_level = st.selectbox(
416
+ "Math Proficiency Level", levels,
417
+ index=levels.index(data.get("math_level", "Undergraduate"))
418
+ if data.get("math_level") in levels else 1
419
+ )
420
+ interests = st.multiselect(
421
+ "Areas of Interest", interests_all,
422
+ default=[i for i in data.get("interests", []) if i in interests_all]
423
+ )
424
  if st.form_submit_button("Save Profile", use_container_width=True, type="primary"):
425
  payload = {"display_name": display_name, "math_level": math_level, "interests": interests}
426
  try:
427
+ r = requests.post(f"{BASE_API_URL}/users/profile", json=payload, headers=headers)
428
  if r.status_code == 200:
429
  st.success("Profile updated!")
430
  st.session_state.profile_data = payload
431
+ time.sleep(1)
432
+ st.rerun()
433
  else:
434
  st.error(f"Update failed: {r.text}")
435
  except Exception as e:
436
  st.error(f"Error saving: {e}")
437
 
438
+
439
  # ====================================================
440
  # Chat Interface
441
  # ====================================================
442
  def chat_interface():
443
+ if not st.session_state.active_session_id:
444
+ if st.session_state.chat_sessions:
445
+ st.session_state.active_session_id = st.session_state.chat_sessions[0]["session_id"]
446
+ load_messages(st.session_state.active_session_id)
447
+ else:
448
+ new_chat()
449
+ return
450
 
451
+ active_sess = get_active_session()
452
+ st.title(active_sess["title"] if active_sess else "Chat")
453
 
454
  # ── 1. Render history ─────────────────────────────────────────────────────
455
+ for msg in st.session_state.messages:
456
+ role = msg["role"]
457
+ with st.chat_message(role, avatar="👤" if role == "user" else "🤖"):
458
+ if role == "user":
459
  if msg.get("image_data"):
460
+ try:
461
+ st.image(base64.b64decode(msg["image_data"]), width=300)
462
+ except Exception:
463
+ pass
464
  st.write(msg["content"])
465
+ else:
 
466
  meta = msg.get("metadata", {})
467
  if meta:
468
  badges = ""
 
473
  badges += '<span class="badge badge-blue">💾 CACHED</span>'
474
  elif src in ("google_adk_agent", "agent"):
475
  badges += '<span class="badge badge-purple">🤖 AGENT</span>'
476
+ model = meta.get("model_used") or meta.get("model")
477
  if model:
478
  badges += f'<span class="badge" style="background:rgba(255,255,255,0.1);">{model}</span>'
479
  if badges:
480
  st.markdown(badges, unsafe_allow_html=True)
481
 
482
+ if msg.get("reasoning") or msg.get("explanation"):
 
483
  with st.expander("Show Reasoning Steps"):
484
+ st.markdown(msg.get("reasoning") or msg.get("explanation"))
485
+
486
+ content = msg["content"]
487
  if isinstance(content, dict) and "final_answer" in content:
488
  st.markdown(f"**Answer:**\n\n> {content['final_answer']}")
489
  else:
 
497
  is_processing = st.session_state.get("is_processing", False)
498
 
499
  with tab_text:
500
+ text_prompt = st.chat_input("Ask a math question...", disabled=is_processing)
501
+ if text_prompt:
502
+ prompt = text_prompt
503
 
504
  with tab_draw:
505
  col_canvas, col_controls = st.columns([3, 1])
506
  with col_canvas:
 
 
507
  canvas_result = st_canvas(
508
  stroke_width=3, stroke_color="#FFFFFF", background_color="#000000",
509
+ height=300, width=600, drawing_mode="freedraw",
510
+ key=st.session_state.canvas_key,
511
+ )
512
+ draw_prompt_input = st.text_input(
513
+ "Question about drawing (optional)",
514
+ placeholder="Solve this handwritten problem...",
515
+ key="draw_prompt_input"
516
  )
 
517
  with col_controls:
518
  st.caption("Controls")
519
+ if st.button("Clear"):
520
  st.session_state.canvas_key = f"canvas_{uuid.uuid4()}"
521
  st.rerun()
522
+ if st.button("Solve", type="primary", disabled=is_processing):
523
  if canvas_result.image_data is not None:
524
  img = Image.fromarray(canvas_result.image_data.astype("uint8"), "RGBA")
525
  bg = Image.new("RGB", img.size, (0, 0, 0))
 
527
  buf = io.BytesIO()
528
  bg.save(buf, format="PNG")
529
  image_b64 = base64.b64encode(buf.getvalue()).decode()
530
+ prompt = draw_prompt_input or "Solve this handwritten math problem."
531
 
532
  with tab_upload:
533
+ uploaded_file = st.file_uploader("Upload", type=["png", "jpg"], disabled=is_processing)
534
+ upload_prompt_input = st.text_input("Question", placeholder="Analyze...", disabled=is_processing, key="upload_prompt_input")
535
+ if uploaded_file and st.button("Analyze", disabled=is_processing):
536
+ image_b64 = base64.b64encode(uploaded_file.getvalue()).decode()
537
+ prompt = upload_prompt_input or "Analyze this image."
538
 
539
+ # ── 3. New user message → optimistic UI update + rerun ────────────────────
540
  if prompt:
541
  req_id = str(uuid.uuid4())
542
  add_message("user", prompt, image_data=image_b64, request_id=req_id, sent_to_api=False)
543
  st.session_state.is_processing = True
544
  st.rerun()
545
 
546
+ # ── 4. Fire API call if last message is an unsent user message ────────────
 
 
 
 
 
 
 
547
  if (
548
+ st.session_state.messages
549
+ and st.session_state.messages[-1]["role"] == "user"
550
+ and not st.session_state.messages[-1].get("sent_to_api", False)
551
  ):
552
+ last = st.session_state.messages[-1]
553
  current_request_id = last.get("request_id") or str(uuid.uuid4())
554
+ last["request_id"] = current_request_id
555
 
556
  with st.chat_message("assistant", avatar="🤖"):
557
+ status_msg = st.status("Thinking...", expanded=True)
558
+ logic_placeholder = status_msg.empty()
559
+ answer_placeholder = st.empty()
560
+
561
+ full_answer = ""
562
+ logic_trace = []
563
+
564
+ try:
565
+ last["sent_to_api"] = True
566
+ payload = {
567
+ "text": last["content"],
568
+ "image": last.get("image_data"),
569
+ "session_id": st.session_state.active_session_id,
570
+ "request_id": current_request_id,
571
+ }
572
+ headers = get_auth_headers()
573
+
574
+ with requests.post(API_URL, json=payload, headers=headers, stream=True, timeout=360) as r:
575
+ if r.status_code == 200:
576
+ for line in r.iter_lines():
577
+ if line:
578
+ try:
579
+ data = json.loads(line)
580
+ if data["type"] == "thought":
581
+ logic_trace.append(data["content"])
582
+ elif data["type"] == "action":
583
+ logic_trace.append(f"⚙️ {data['content']}")
584
+ elif data["type"] == "observation":
585
+ logic_trace.append(f"👁️ {data['content']}")
586
+ elif data["type"] == "answer":
587
+ full_answer += data["content"]
588
+ answer_placeholder.markdown(full_answer)
589
+ elif data["type"] == "error":
590
+ st.error(data["content"])
591
+
592
+ # Update logic trace UI
593
+ logic_placeholder.markdown("\n".join(logic_trace))
594
+ except Exception:
595
+ continue
596
+
597
+ status_msg.update(label="Solved!", state="complete", expanded=False)
598
+ st.session_state.is_processing = False
599
+
600
+ if full_answer:
601
  add_message(
602
+ "assistant",
603
+ full_answer,
604
+ reasoning="\n".join(logic_trace),
605
+ metadata={"source": "agent"}
 
 
 
 
 
 
606
  )
607
+ load_sessions() # Update titles
 
 
 
 
 
 
608
  st.rerun()
609
+ elif r.status_code == 401:
610
+ _clear_user_state()
611
+ st.session_state.user = None
612
+ st.error("Session expired. Please log in again.")
613
  st.rerun()
 
614
  else:
615
+ st.error(f"Error: {r.status_code}")
 
 
 
 
616
  st.session_state.is_processing = False
 
617
 
618
+ except Exception as e:
619
+ st.error(f"Connection error: {e}")
620
+ st.session_state.is_processing = False
621
+ st.rerun()
622
+
623
 
624
  # ====================================================
625
  # Sidebar
 
628
  st.markdown("### 🧠 MathMinds")
629
  st.write(f"Logged in as **{st.session_state.user['email']}**")
630
 
631
+ view = st.radio(
632
+ "Navigation", ["Chat", "Profile"],
633
+ index=0 if st.session_state.current_view == "Chat" else 1
634
+ )
635
  if view != st.session_state.current_view:
636
  st.session_state.current_view = view
637
  st.rerun()
638
 
639
  if st.button("Sign Out", type="secondary"):
640
+ # ✅ MULTIUSER FIX: Wipe ALL user-specific state first, THEN clear identity.
641
+ # Without _clear_user_state() here, the next user to log in on the same
642
+ # browser tab would see User A's chat history briefly before load_sessions
643
+ # returns, because st.session_state persists across logins within a tab.
644
+ _clear_user_state()
645
  st.session_state.user = None
646
  st.rerun()
647
 
648
+ if st.session_state.is_processing:
649
+ if st.button(
650
+ "🔓 Reset Processing Lock", type="primary",
651
+ help="Use if UI is stuck despite answer finishing."
652
+ ):
653
+ st.session_state.is_processing = False
654
+ st.rerun()
655
+
656
  st.divider()
657
 
658
  if st.session_state.current_view == "Chat":
 
660
  new_chat()
661
 
662
  st.markdown("#### History")
663
+
664
+ for session in st.session_state.chat_sessions:
665
+ sid = session["session_id"]
666
+ title = session["title"]
667
+
668
+ cols = st.columns([0.8, 0.1, 0.1])
669
+ with cols[0]:
670
+ is_active = (st.session_state.active_session_id == sid)
671
+ btn_type = "primary" if is_active else "secondary"
672
+ if st.button(title, key=f"sel_{sid}", use_container_width=True, type=btn_type):
 
 
673
  st.session_state.active_session_id = sid
674
+ load_messages(sid)
675
+ st.rerun()
676
+ with cols[1]:
677
+ if st.button("🖊️", key=f"ren_{sid}", help="Rename"):
678
+ st.session_state.renaming_session_id = (
679
+ sid if st.session_state.renaming_session_id != sid else None
680
+ )
681
  st.rerun()
682
+ with cols[2]:
683
+ if st.button("🗑️", key=f"del_{sid}", help="Delete"):
684
  delete_chat(sid)
685
 
686
+ if st.session_state.renaming_session_id == sid:
687
+ with st.container():
688
+ new_title = st.text_input(
689
+ "New title", value=title,
690
+ key=f"in_{sid}", label_visibility="collapsed"
691
+ )
692
+ if st.button("Save", key=f"save_{sid}", use_container_width=True):
693
+ rename_chat(sid, new_title)
694
+ st.session_state.renaming_session_id = None
695
+ st.rerun()
696
+
697
+
698
+ # ====================================================
699
+ # Main Content Area
700
+ # ====================================================
701
+ if st.session_state.current_view == "Chat":
702
+ chat_interface()
703
+ elif st.session_state.current_view == "Profile":
704
+ profile_interface()
requirements.txt CHANGED
@@ -33,6 +33,9 @@ pandas
33
  gradio
34
  transformers
35
  torch
 
 
 
36
 
37
  selenium
38
  webdriver-manager
 
33
  gradio
34
  transformers
35
  torch
36
+ passlib[bcrypt]
37
+ bcrypt
38
+ PyJWT
39
 
40
  selenium
41
  webdriver-manager