kaushik1064 commited on
Commit
886572e
·
1 Parent(s): 84e9d55

Add backend FastAPI code

Browse files
Files changed (43) hide show
  1. Dockerfile +18 -0
  2. app/__init__.py +6 -0
  3. app/__pycache__/__init__.cpython-313.pyc +0 -0
  4. app/__pycache__/config.cpython-313.pyc +0 -0
  5. app/__pycache__/guardrails.cpython-313.pyc +0 -0
  6. app/__pycache__/logger.cpython-313.pyc +0 -0
  7. app/__pycache__/main.cpython-313.pyc +0 -0
  8. app/__pycache__/schemas.cpython-313.pyc +0 -0
  9. app/config.py +104 -0
  10. app/guardrails.py +100 -0
  11. app/logger.py +37 -0
  12. app/main.py +156 -0
  13. app/schemas.py +117 -0
  14. app/services/__pycache__/kb_updater.cpython-313.pyc +0 -0
  15. app/services/__pycache__/retrieval.cpython-313.pyc +0 -0
  16. app/services/__pycache__/vector_store.cpython-313.pyc +0 -0
  17. app/services/kb_updater.py +56 -0
  18. app/services/retrieval.py +109 -0
  19. app/services/vector_store.py +348 -0
  20. app/tools/__init__.py +2 -0
  21. app/tools/__pycache__/__init__.cpython-313.pyc +0 -0
  22. app/tools/__pycache__/audio.cpython-313.pyc +0 -0
  23. app/tools/__pycache__/dspy_pipeline.cpython-313.pyc +0 -0
  24. app/tools/__pycache__/validator.cpython-313.pyc +0 -0
  25. app/tools/__pycache__/vision.cpython-313.pyc +0 -0
  26. app/tools/__pycache__/web_search.cpython-313.pyc +0 -0
  27. app/tools/audio.py +41 -0
  28. app/tools/dspy_pipeline.py +476 -0
  29. app/tools/validator.py +28 -0
  30. app/tools/vision.py +33 -0
  31. app/tools/web_search.py +395 -0
  32. app/workflows/__init__.py +3 -0
  33. app/workflows/__pycache__/__init__.cpython-313.pyc +0 -0
  34. app/workflows/__pycache__/langgraph_pipeline.cpython-313.pyc +0 -0
  35. app/workflows/langgraph_pipeline.py +251 -0
  36. backend/data/feedback_db.json +176 -0
  37. data/knowledge_base.jsonl +6 -0
  38. mcp_servers/tavily_server.py +65 -0
  39. poetry.lock +0 -0
  40. pyproject.toml +44 -0
  41. render.yaml +30 -0
  42. scripts/build_kb.py +62 -0
  43. scripts/build_weaviate_kb.py +102 -0
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.13.0

WORKDIR /code

# Install Poetry
RUN pip install poetry

# Copy dependency manifests first so the installed-dependencies layer is
# cached and only rebuilt when pyproject.toml / poetry.lock change.
COPY pyproject.toml poetry.lock ./

# --no-root: the project source has not been copied yet, so installing the
# project package itself here would fail; install third-party deps only.
RUN poetry install --no-interaction --no-ansi --no-root

COPY . .

EXPOSE 7860

# The FastAPI instance is defined in app/main.py (module path "app.main"),
# not a top-level main.py, so the uvicorn target must be "app.main:app".
CMD ["poetry", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
"""Math Agent FastAPI application package."""

# Re-export the application factory so callers can write
# `from app import create_app`. Note that importing this package eagerly
# imports app.main and all of its transitive dependencies.
from .main import create_app

__all__ = ["create_app"]
6
+
app/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (259 Bytes). View file
 
app/__pycache__/config.cpython-313.pyc ADDED
Binary file (3.62 kB). View file
 
app/__pycache__/guardrails.cpython-313.pyc ADDED
Binary file (4.77 kB). View file
 
app/__pycache__/logger.cpython-313.pyc ADDED
Binary file (2.03 kB). View file
 
app/__pycache__/main.cpython-313.pyc ADDED
Binary file (7.95 kB). View file
 
app/__pycache__/schemas.cpython-313.pyc ADDED
Binary file (4.86 kB). View file
 
app/config.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Application configuration using Pydantic settings."""

from functools import lru_cache
import os
from typing import List
from pathlib import Path

from pydantic_settings import BaseSettings, SettingsConfigDict

# Resolve .env path robustly so it works whether .env is at repo root or backend/.env
# This file is at backend/app/config.py → repo root is two levels up
_REPO_ROOT = Path(__file__).resolve().parents[2]
_BACKEND_DIR = _REPO_ROOT / "backend"
# Candidate .env locations, checked in priority order.
_ENV_CANDIDATES = [
    _REPO_ROOT / ".env",  # D:/ai_planet/.env
    _BACKEND_DIR / ".env",  # D:/ai_planet/backend/.env
    Path.cwd() / ".env",  # fallback to CWD
]
# First existing candidate wins; fall back to the repo-root path even if
# absent (pydantic-settings tolerates a missing env_file).
_ENV_FILE = next((p for p in _ENV_CANDIDATES if p.exists()), _REPO_ROOT / ".env")


class Settings(BaseSettings):
    """Runtime configuration loaded from environment or .env file."""

    # extra="allow" keeps unknown env vars from raising validation errors.
    model_config = SettingsConfigDict(env_file=str(_ENV_FILE), env_file_encoding="utf-8", extra="allow")

    app_name: str = "Math Agentic RAG"
    environment: str = "local"
    debug: bool = True

    # Vector store configuration (Weaviate)
    weaviate_url: str | None = None
    weaviate_api_key: str | None = None
    weaviate_collection: str = "mathvectors"  # Collection name in Weaviate cloud
    embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    top_k: int = 4
    similarity_threshold: float = 0.80

    # Guardrail keywords for allow-listing math topics
    allowed_subjects: List[str] = [
        "algebra",
        "geometry",
        "calculus",
        "probability",
        "statistics",
        "number theory",
        "trigonometry",
        "combinatorics",
        "math",
    ]

    # Substrings that cause input/output guardrails to reject content.
    blocked_keywords: List[str] = [
        "violence",
        "weapon",
        "politics",
        "hate",
        "self-harm",
        "explicit",
    ]

    # External service credentials
    tavily_api_key: str | None = None
    mcp_tavily_url: str | None = None  # e.g., https://mcp.tavily.com/mcp/?tavilyApiKey=tvly-...
    gemini_api_key: str | None = None
    groq_api_key: str | None = None

    # Model configuration
    gemini_model: str = "gemini-2.5-flash"  # Fast model for search grounding and vision
    dspy_model: str = "gemini-2.5-flash"  # For DSPy pipeline
    whisper_model: str = "whisper-large-v3-turbo"  # For audio transcription
    dspy_max_tokens: int = 2048

    # Feedback storage
    feedback_store_path: str = "backend/data/feedback_db.json"

    # Gateway / guardrails toggles
    enforce_input_guardrails: bool = True
    enforce_output_guardrails: bool = True


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return a cached Settings instance.

    The instance is created once; the threshold override below mutates the
    cached object, so every importer observes the same effective value.
    """

    settings = Settings()

    # Allow alternate env var aliases for KB similarity threshold
    # Priority: KB_SIMILARITY_THRESHOLD -> KB_THRESHOLD -> existing value
    kb_thresh_env = (
        os.getenv("KB_SIMILARITY_THRESHOLD")
        or os.getenv("KB_THRESHOLD")
    )
    if kb_thresh_env:
        try:
            settings.similarity_threshold = float(kb_thresh_env)
        except ValueError:
            # Ignore invalid values and keep existing
            pass

    return settings


# Module-level singleton imported throughout the app (`from .config import settings`).
settings = get_settings()
103
+
104
+
app/guardrails.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Guardrail functions for input/output moderation."""

from __future__ import annotations

import re
from typing import Tuple

from .config import settings
from .logger import get_logger

logger = get_logger(__name__)


class GuardrailViolation(Exception):
    """Raised when a guardrail condition is violated.

    Carries a machine-readable ``code`` alongside the human-readable
    ``message`` so API layers can map violations to HTTP responses.
    """

    def __init__(self, message: str, code: str = "guardrail_violation") -> None:
        super().__init__(message)
        self.code = code
        self.message = message


# Matches runs of characters outside printable ASCII (0x20-0x7E).
# NOTE(review): this also strips legitimate math symbols (e.g. π, ×, −, ²)
# from user input — confirm that is acceptable for a math-focused app.
SANITIZE_REGEX = re.compile(r"[^\x20-\x7E]+")
# Heuristic PII patterns: SSN-style ddd-dd-dddd, bare 16-digit numbers
# (card-like), and Indian PAN-style codes (AAAAA9999A).
PII_REGEX = re.compile(r"\b(\d{3}-?\d{2}-?\d{4}|\d{16}|[A-Z]{5}[0-9]{4}[A-Z]{1})\b")


def sanitize_text(text: str) -> str:
    """Remove non-ascii characters and collapse whitespace."""

    cleaned = SANITIZE_REGEX.sub("", text)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    return cleaned


def run_input_guardrails(user_text: str) -> str:
    """Validate and sanitize user input before routing.

    Note: Since Gemini is used to validate math questions, we're very permissive here.
    Only block obvious non-math content and security issues.

    Returns the sanitized text; raises GuardrailViolation on PII or a
    blocked-keyword match.
    """

    logger.debug("guardrails.input.start", text=user_text)

    text = sanitize_text(user_text)

    # Only block PII - this is a security issue
    if PII_REGEX.search(text):
        logger.warning("guardrails.input.blocked", reason="pii_match")
        raise GuardrailViolation("The request may contain sensitive information. Please remove it and try again.")

    lower = text.lower()

    # Block only obvious non-educational content - let Gemini decide if it's math
    if any(keyword in lower for keyword in settings.blocked_keywords):
        logger.warning("guardrails.input.blocked", reason="blocked_keyword")
        raise GuardrailViolation("I can only help with mathematics-related educational questions.")

    # Very permissive - allow anything that could be math-related
    # Gemini will be the final validator, so we just do basic sanity checks
    # Allow anything with math-like patterns, or let it through and let Gemini decide

    logger.debug("guardrails.input.pass")
    return text


def run_output_guardrails(response_text: str) -> Tuple[str, list[str]]:
    """Apply output guardrails - very permissive since Gemini validates math content.

    Only blocks obvious non-math content. Most validation is done by Gemini.

    Returns a tuple of (sanitized text with PII redacted, list of URLs
    found in the text to surface as citations).
    """

    logger.debug("guardrails.output.start")
    text = sanitize_text(response_text)

    citations: list[str] = []

    # Only block PII - security issue; redact in place rather than reject.
    if PII_REGEX.search(text):
        logger.warning("guardrails.output.flag", reason="pii_detected")
        text = re.sub(PII_REGEX, "[redacted]", text)

    # Extract URLs as citations
    if "http" in text.lower():
        citations = re.findall(r"(https?://\S+)", text)

    # Very permissive - allow almost everything
    # Since Gemini generates the responses, it should be math-related
    # Only block if it's clearly and obviously non-math (e.g., contains blocked keywords)
    lower = text.lower()

    # Only block if it contains obviously inappropriate content
    if any(keyword in lower for keyword in settings.blocked_keywords):
        logger.warning("guardrails.output.blocked", reason="blocked_keyword", preview=text[:100])
        raise GuardrailViolation("The generated response contains inappropriate content.")

    # Everything else passes - let Gemini handle math validation
    logger.debug("guardrails.output.pass", citations=len(citations))
    return text, citations
99
+
100
+
app/logger.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Structured logging helpers."""

import logging
from typing import Any

import structlog


def configure_logging(level: int = logging.INFO) -> None:
    """Configure structlog for the application."""

    # Human-friendly console output when debugging; JSON lines otherwise.
    if level <= logging.DEBUG:
        renderer = structlog.dev.ConsoleRenderer()
    else:
        renderer = structlog.processors.JSONRenderer()

    processor_chain = [
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        renderer,
    ]

    structlog.configure(
        processors=processor_chain,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )

    # Route stdlib logging at the same level; structlog renders the message.
    logging.basicConfig(level=level, format="%(message)s")


def get_logger(*args: Any, **kwargs: Any) -> structlog.stdlib.BoundLogger:
    """Return a configured structlog logger."""

    return structlog.get_logger(*args, **kwargs)
36
+
37
+
app/main.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI entrypoint for the math agent backend."""

from __future__ import annotations

import asyncio
import os
from typing import AsyncIterator

from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from .config import settings
from .guardrails import GuardrailViolation
from .logger import configure_logging, get_logger
from .schemas import AgentResponse, ChatRequest, FeedbackRequest
from .services.retrieval import MathAgent
from .services.vector_store import (
    load_vector_store,
    save_feedback_to_queue,
)
from .services.kb_updater import update_knowledge_base
from .tools.audio import transcribe_audio
# NOTE(review): validate_user_solution is imported but not referenced in
# this module — confirm whether it is needed here or can be dropped.
from .tools.validator import validate_user_solution
from .tools.vision import extract_text_from_image

logger = get_logger(__name__)


def create_app() -> FastAPI:
    """Build and fully configure the FastAPI application instance."""
    configure_logging()

    app = FastAPI(title=settings.app_name, version="0.1.0")

    # Configure CORS. Accept a single FRONTEND_URL or a comma-separated
    # FRONTEND_URLS env var. Default to common dev ports (5173 and 3000).
    frontend_env = os.getenv("FRONTEND_URL", os.getenv("FRONTEND_URLS", "http://localhost:5173,http://localhost:3000"))
    # split and strip any whitespace
    allowed_origins = [o.strip() for o in frontend_env.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[None]:
        """Load the vector store at startup without blocking the event loop."""
        # Startup
        logger.info("app.startup")
        app.state.vector_store = await asyncio.to_thread(load_vector_store)

        yield  # Server is running

        # Shutdown
        logger.info("app.shutdown")
        # Clean up any resources if needed

    # Registered after construction; equivalent to FastAPI(lifespan=lifespan).
    app.router.lifespan_context = lifespan

    def get_agent() -> MathAgent:
        """Dependency: build a MathAgent, lazily loading the vector store
        if the lifespan hook has not populated app.state yet."""
        vector_store = getattr(app.state, "vector_store", None)
        if vector_store is None:
            vector_store = load_vector_store()
            app.state.vector_store = vector_store
        return MathAgent(vector_store=vector_store)

    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):  # type: ignore[override]
        # NOTE(review): despite the name, this only adds the X-App-Env
        # header — no processing time is measured.
        response = await call_next(request)
        response.headers["X-App-Env"] = settings.environment
        return response

    @app.get("/health")
    async def health() -> dict[str, str]:
        """Liveness probe reporting the configured environment."""
        return {"status": "ok", "environment": settings.environment}

    @app.post("/api/chat", response_model=AgentResponse, responses={400: {"description": "Guardrail failure"}})
    async def chat_endpoint(payload: ChatRequest, agent: MathAgent = Depends(get_agent)) -> AgentResponse:
        """Answer a math question; audio/image payloads are transcribed first."""
        logger.info("chat.request", modality=payload.modality)

        query = payload.query
        if payload.modality == "audio" and payload.audio_base64:
            query = transcribe_audio(payload.audio_base64)
        elif payload.modality == "image" and payload.image_base64:
            query = extract_text_from_image(payload.image_base64)

        try:
            response = await agent.handle_query(query)
            return response
        except GuardrailViolation as exc:
            # Guardrail rejections are client errors, not server faults.
            raise HTTPException(status_code=400, detail=exc.message) from exc
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.exception("chat.error", error=str(exc))
            raise HTTPException(status_code=500, detail="Unexpected error handling query") from exc

    @app.post("/api/feedback")
    async def feedback_endpoint(request: FeedbackRequest) -> JSONResponse:
        """Handle user feedback with optional solution upload."""
        logger.info(
            "feedback.received",
            message_id=request.message_id,
            helpful=request.feedback.thumbs_up,
            issue=request.feedback.primary_issue
        )

        # Always save the feedback first
        record = request.model_dump()
        save_feedback_to_queue(record)

        # If it's negative feedback with a solution
        if not request.feedback.thumbs_up and request.feedback.has_better_solution:
            solution = None

            # Get solution based on type
            if request.feedback.solution_type == "text":
                solution = request.feedback.better_solution_text
            elif request.feedback.solution_type == "pdf":
                # TODO: Extract text from PDF
                solution = request.feedback.better_solution_text
            elif request.feedback.solution_type == "image":
                # Use vision model to extract solution from image
                if request.feedback.better_solution_image_base64:
                    solution = extract_text_from_image(
                        request.feedback.better_solution_image_base64,
                        "Extract the mathematical solution from this image."
                    )

            if solution:
                # Validate and update KB
                success = await update_knowledge_base(request.query, solution)
                return JSONResponse({
                    "status": "ok",
                    "feedback_saved": True,
                    "kb_updated": success
                })

        return JSONResponse({
            "status": "ok",
            "feedback_saved": True
        })

    @app.get("/api/vector-store/reload")
    async def reload_vector_store() -> dict[str, str]:
        """Force-reload the vector store (passes True to load_vector_store).

        NOTE(review): this mutates server state but is exposed as GET —
        consider POST to avoid accidental triggering by crawlers/prefetch.
        """
        app.state.vector_store = await asyncio.to_thread(load_vector_store, True)
        return {"status": "reloaded"}

    return app


# Module-level instance referenced by the uvicorn target "app.main:app".
app = create_app()
155
+
156
+
app/schemas.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic schemas for API endpoints."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, List, Optional
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
+ class Step(BaseModel):
11
+ """A solution step with title, content and optional LaTeX expression."""
12
+
13
+ title: str
14
+ content: str
15
+ expression: Optional[str] = None
16
+
17
+
18
+ class Citation(BaseModel):
19
+ """Represents a web citation."""
20
+
21
+ title: str
22
+ url: str
23
+
24
+
25
+ class RetrievalContext(BaseModel):
26
+ """Context snippet from knowledge base."""
27
+
28
+ document_id: str
29
+ question: str
30
+ answer: str
31
+ similarity: float
32
+
33
+
34
+ class AgentResponse(BaseModel):
35
+ """Structured agent response."""
36
+
37
+ answer: str
38
+ steps: List[Step]
39
+ retrieved_from_kb: bool = False
40
+ knowledge_hits: List[RetrievalContext] = Field(default_factory=list)
41
+ citations: List[Citation] = Field(default_factory=list)
42
+ source: str = Field(default="kb") # kb | tavily | gemini
43
+ feedback_required: bool = True
44
+ gateway_trace: List[str] = Field(default_factory=list)
45
+
46
+
47
+ class ChatRequest(BaseModel):
48
+ """Incoming chat message payload."""
49
+
50
+ query: str
51
+ modality: str = Field(default="text", description="text|image|audio")
52
+ image_base64: Optional[str] = None
53
+ audio_base64: Optional[str] = None
54
+
55
+
56
+ class FeedbackMetadata(BaseModel):
57
+ """User feedback payload."""
58
+
59
+ # Minimal - Always Ask (5 seconds)
60
+ thumbs_up: bool = Field(description="👍 Helpful / 👎 Not helpful")
61
+
62
+ # If thumbs_down, expand to next level (10 seconds)
63
+ primary_issue: Optional[str] = Field(
64
+ None,
65
+ description="What went wrong?",
66
+ # These match exactly with the frontend enums
67
+ enum=["wrong-answer", "unclear", "missing-steps", "wrong-method"]
68
+ )
69
+
70
+ # Optional user solution upload (30-60 seconds)
71
+ has_better_solution: bool = False
72
+ solution_type: Optional[str] = Field(
73
+ None,
74
+ description="Type of solution upload",
75
+ enum=["text", "pdf", "image"]
76
+ )
77
+ better_solution_text: Optional[str] = None
78
+ better_solution_pdf_base64: Optional[str] = None
79
+ better_solution_image_base64: Optional[str] = None
80
+
81
+
82
+ class FeedbackRequest(BaseModel):
83
+ """Feedback submission request."""
84
+
85
+ message_id: str
86
+ query: str
87
+ agent_response: AgentResponse
88
+ feedback: FeedbackMetadata
89
+
90
+
91
+ class BenchmarkResult(BaseModel):
92
+ """Result for a single benchmark item."""
93
+
94
+ question_id: str
95
+ reference_answer: str
96
+ agent_answer: str
97
+ score: float
98
+ source: str
99
+
100
+
101
+ class BenchmarkSummary(BaseModel):
102
+ """Aggregate benchmark output."""
103
+
104
+ dataset: str
105
+ total_questions: int
106
+ average_score: float
107
+ details: List[BenchmarkResult]
108
+
109
+
110
+ class ErrorResponse(BaseModel):
111
+ """Error response envelope."""
112
+
113
+ detail: str
114
+ code: str = Field(default="error")
115
+ context: Optional[dict[str, Any]] = None
116
+
117
+
app/services/__pycache__/kb_updater.cpython-313.pyc ADDED
Binary file (2.16 kB). View file
 
app/services/__pycache__/retrieval.cpython-313.pyc ADDED
Binary file (4.97 kB). View file
 
app/services/__pycache__/vector_store.cpython-313.pyc ADDED
Binary file (16 kB). View file
 
app/services/kb_updater.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Knowledge base update service for validated user solutions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from typing import Optional
7
+
8
+ from sentence_transformers import SentenceTransformer
9
+
10
+ from ..config import settings
11
+ from ..logger import get_logger
12
+ from ..tools.validator import validate_user_solution
13
+ from .vector_store import load_vector_store
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
+ async def update_knowledge_base(question: str, solution: str, source: str = "user-feedback") -> bool:
19
+ """Validate and add a new solution to the knowledge base."""
20
+
21
+ try:
22
+ # First validate the solution
23
+ is_valid = await asyncio.to_thread(validate_user_solution, question, solution)
24
+
25
+ if not is_valid:
26
+ logger.warning("kb_update.validation_failed", question=question)
27
+ return False
28
+
29
+ # Load vector store
30
+ vector_store = load_vector_store()
31
+ encoder = SentenceTransformer(settings.embedding_model_name)
32
+
33
+ # Generate embedding
34
+ text = question + "\n" + solution
35
+ embedding = encoder.encode(text)
36
+
37
+ # Add to Weaviate
38
+ data_object = {
39
+ "question": question,
40
+ "answer": solution,
41
+ "source": source
42
+ }
43
+
44
+ # Add with vector
45
+ vector_store.client.data_object.create(
46
+ data_object=data_object,
47
+ class_name=settings.weaviate_class_name,
48
+ vector=embedding.tolist()
49
+ )
50
+
51
+ logger.info("kb_update.success", question=question, source=source)
52
+ return True
53
+
54
+ except Exception as exc:
55
+ logger.error("kb_update.failed", error=str(exc), question=question)
56
+ return False
app/services/retrieval.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Retrieval and routing logic for the math agent."""

from __future__ import annotations

from ..guardrails import GuardrailViolation
from ..logger import get_logger
from ..schemas import AgentResponse, Step
from ..workflows.langgraph_pipeline import build_math_agent_graph
from .vector_store import VectorStore, load_vector_store

logger = get_logger(__name__)


class MathAgent:
    """High-level orchestration for handling a math query."""

    def __init__(self, vector_store: VectorStore | None = None) -> None:
        # Reuse an injected store (e.g. from app.state) or load a fresh one;
        # the LangGraph pipeline is compiled once per agent instance.
        self.vector_store = vector_store or load_vector_store()
        self.graph = build_math_agent_graph(self.vector_store)

    async def handle_query(self, query: str) -> AgentResponse:
        """Run the query through guardrails, routing, and generation.

        Raises GuardrailViolation unchanged so the API layer can map it to
        a 400 response; any other pipeline failure is logged and re-raised.
        """

        try:
            result_state = await self.graph.ainvoke({"query": query, "gateway_trace": []})
        except GuardrailViolation:
            raise
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.exception("math_agent.graph_failure", error=str(exc))
            raise

        steps_raw = result_state.get("steps", []) or []

        # Get KB hits and web hits separately, then combine for display
        kb_hits = result_state.get("kb_hits", []) or []
        web_hits = result_state.get("web_hits", []) or []
        # Combine for knowledge_hits display (show both KB and web)
        knowledge_hits = list(kb_hits) + list(web_hits)
        # If no separate hits, fall back to combined knowledge_hits
        if not knowledge_hits:
            knowledge_hits = result_state.get("knowledge_hits", []) or []

        citations = result_state.get("citations", []) or []
        answer_text = result_state.get("answer", "") or ""

        # If answer is too short (just overview), try to enhance it with steps
        if len(answer_text.strip()) < 50 and steps_raw:
            # Answer might be incomplete - use steps to build a better answer
            step_texts = []
            for step in steps_raw:
                # Steps may arrive as dicts (current format) or tuples (legacy).
                if isinstance(step, dict):
                    step_texts.append(step.get("content", "") or step.get("explanation", ""))
                elif isinstance(step, (list, tuple)) and len(step) > 1:
                    step_texts.append(step[1])

            if step_texts:
                # Find the last step that looks like a conclusion
                for step_text in reversed(step_texts):
                    if any(word in step_text.lower() for word in ["therefore", "answer", "result", "thus", "hence"]):
                        answer_text = step_text
                        break

                # If still no good answer, use the last step
                if len(answer_text.strip()) < 50:
                    answer_text = step_texts[-1] if step_texts else answer_text

        retrieved_from_kb = result_state.get("retrieved_from_kb", False)
        # Default source label is derived from which hit lists are populated.
        source = result_state.get("source", "kb+web" if (kb_hits and web_hits) else ("kb" if retrieved_from_kb else "gemini"))
        gateway_trace = result_state.get("gateway_trace", [])

        # Convert steps to new format with content and optional expressions
        formatted_steps = []
        for step in steps_raw:
            # steps_raw is now a list of dicts with title/content/expression
            if isinstance(step, dict):
                formatted_steps.append(Step(
                    title=step.get("title", ""),
                    content=step.get("content", ""),
                    expression=step.get("expression")
                ))
            else:
                # Handle legacy tuple format for backward compatibility
                title, explanation = step
                formatted_steps.append(Step(
                    title=title,
                    content=explanation
                ))

        response = AgentResponse(
            answer=answer_text,
            steps=formatted_steps,
            retrieved_from_kb=retrieved_from_kb,
            knowledge_hits=knowledge_hits,
            citations=citations,
            source=source,
            gateway_trace=gateway_trace,
        )

        logger.info(
            "math_agent.completed",
            source=response.source,
            retrieved=retrieved_from_kb,
            kb_hits=len(knowledge_hits),
            trace=gateway_trace,
        )

        return response
+
109
+
app/services/vector_store.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Vector store utilities using Weaviate cloud for knowledge base retrieval."""

from __future__ import annotations

import json
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
import weaviate
from sentence_transformers import SentenceTransformer
from weaviate.collections.classes.config import DataType

from ..config import settings
from ..logger import get_logger

logger = get_logger(__name__)


class VectorStore:
    """Wrapper around Weaviate vector store."""

    def __init__(self, client: weaviate.WeaviateClient, encoder: SentenceTransformer) -> None:
        # client: a connected Weaviate v4 client; encoder: the sentence
        # embedding model used for both indexing and querying.
        self.client = client
        self.encoder = encoder
        self.collection = self._ensure_collection()
        self._property_map = self._detect_property_names()

    def _ensure_collection(self) -> weaviate.collections.Collection:
        """Create collection if it doesn't exist.

        Handles the case where the collection already exists under a
        different casing and recovers from create-time 422 conflicts.
        """
        collection_name = settings.weaviate_collection
        collections = self.client.collections.list_all()

        # Check for exact or case-insensitive match to avoid 422 when class
        # already exists with different casing (e.g. 'Mathvectors').
        lower_names = [n.lower() for n in collections]
        if collection_name not in collections and collection_name.lower() not in lower_names:
            try:
                # Vectors are supplied externally (encoder above), so the
                # server-side vectorizer is disabled.
                collection = self.client.collections.create(
                    name=collection_name,
                    vectorizer_config=weaviate.classes.config.Configure.Vectorizer.none(),
                    vector_index_config=weaviate.classes.config.Configure.VectorIndex.hnsw(),
                    properties=[
                        {"name": "question", "data_type": DataType.TEXT},
                        {"name": "answer", "data_type": DataType.TEXT},
                        {"name": "source", "data_type": DataType.TEXT},
                    ]
                )
                logger.info("vector_store.collection.created", name=collection_name)
            except weaviate.exceptions.UnexpectedStatusCodeError as e:
                # If the create failed because the class already exists (422),
                # recover by locating the existing class name (case-insensitive)
                msg = str(e)
                if "already exists" in msg or "class name" in msg:
                    logger.warning("vector_store.collection.create_conflict", msg=msg)
                    # refresh list and try to find the existing class name
                    collections = self.client.collections.list_all()
                    existing = None
                    for n in collections:
                        if n.lower() == collection_name.lower():
                            existing = n
                            break
                    if existing:
                        collection = self.client.collections.get(existing)
                        logger.info("vector_store.collection.exists", name=existing)
                    else:
                        # re-raise if we can't resolve the conflict
                        raise
                else:
                    raise
        else:
            # exact or case-insensitive match found — pick the actual existing name
            if collection_name in collections:
                existing_name = collection_name
            else:
                existing_name = next(n for n in collections if n.lower() == collection_name.lower())
            collection = self.client.collections.get(existing_name)
            logger.info("vector_store.collection.exists", name=existing_name)

        return collection

    def _detect_property_names(self) -> Dict[str, str]:
        """Detect the actual property names in the collection schema.

        Returns a mapping from logical names (question, answer, source) to actual property names.
        Falls back to the legacy input/label/source_file names when the
        schema cannot be introspected.
        """
        property_map = {}
        prop_names = []

        try:
            # Get the collection configuration to see what properties exist
            config = self.collection.config.get()

            # Try different ways to access properties depending on Weaviate v4 structure
            if hasattr(config, 'properties'):
                props = config.properties
                if props:
                    # Properties might be a list or iterable
                    if isinstance(props, (list, tuple)):
                        prop_names = [prop.name if hasattr(prop, 'name') else str(prop) for prop in props]
                    elif hasattr(props, '__iter__') and not isinstance(props, (str, bytes)):
                        prop_names = [prop.name if hasattr(prop, 'name') else str(prop) for prop in props]

            # Also try accessing as attributes
            if not prop_names and hasattr(config, 'properties'):
                try:
                    props_dict = config.properties.__dict__ if hasattr(config.properties, '__dict__') else {}
                    prop_names = list(props_dict.keys()) if props_dict else []
                # NOTE(review): bare except silently swallows everything,
                # including KeyboardInterrupt — consider `except Exception`.
                except:
                    pass

        except Exception as exc:
            logger.warning("vector_store.schema.detection_error", error=str(exc))

        # Try to map to our expected names
        # Check for new schema (question, answer, source) first
        if prop_names:
            if "question" in prop_names:
                property_map["question"] = "question"
            elif "input" in prop_names:
                property_map["question"] = "input"

            if "answer" in prop_names:
                property_map["answer"] = "answer"
            elif "label" in prop_names:
                property_map["answer"] = "label"

            if "source" in prop_names:
                property_map["source"] = "source"
            elif "source_file" in prop_names:
                property_map["source"] = "source_file"
        else:
            # No properties detected - use fallback (old schema)
            logger.warning("vector_store.schema.detection_failed", msg="Could not detect properties, using fallback")
            property_map = {"question": "input", "answer": "label", "source": "source_file"}

        # If detection partially failed, fill in missing mappings
        if "question" not in property_map:
            property_map["question"] = "input"
        if "answer" not in property_map:
            property_map["answer"] = "label"
        if "source" not in property_map:
            property_map["source"] = "source_file"

        logger.info("vector_store.schema.detected", mapping=property_map, found_properties=prop_names)
        return property_map
147
+
148
def search(self, query: str, top_k: int | None = None) -> List[Dict[str, Any]]:
    """Search for similar questions and return contexts.

    Embeds ``query``, runs a Weaviate v4 ``near_vector`` query filtered by
    ``settings.similarity_threshold`` certainty, and returns a list of dicts
    with ``document_id``, ``question``, ``answer``, ``source``, ``similarity``.
    Returns an empty list on any query failure (best-effort retrieval).
    """
    if top_k is None:
        top_k = settings.top_k

    # Generate embedding for query
    embedding = self.encoder.encode(query)
    embedding = embedding / np.linalg.norm(embedding)  # Normalize

    try:
        # Get the actual property names from the detected schema
        question_prop = self._property_map.get("question", "input")
        answer_prop = self._property_map.get("answer", "label")
        source_prop = self._property_map.get("source", "source_file")

        # Build Weaviate v4 query with near_vector using actual property names
        response = self.collection.query.near_vector(
            near_vector=embedding.tolist(),
            certainty=settings.similarity_threshold,
            limit=top_k,
            return_metadata=["certainty"],
            return_properties=[question_prop, answer_prop, source_prop]
        )
    except Exception as exc:  # pragma: no cover - defensive
        # If query failed, try with old schema as fallback.
        # NOTE(review): matching on the error string is brittle — "property"
        # also appears in unrelated errors; confirm against the client's
        # actual exception types.
        error_msg = str(exc).lower()
        if "no such prop" in error_msg or "property" in error_msg:
            logger.warning("vector_store.search.schema_mismatch", error=str(exc), trying="old_schema")
            try:
                # Fallback to old schema property names
                response = self.collection.query.near_vector(
                    near_vector=embedding.tolist(),
                    certainty=settings.similarity_threshold,
                    limit=top_k,
                    return_metadata=["certainty"],
                    return_properties=["input", "label", "source_file"]
                )
                # Update property map for this query (and future ones)
                self._property_map = {"question": "input", "answer": "label", "source": "source_file"}
                question_prop = "input"
                answer_prop = "label"
                source_prop = "source_file"
            except Exception as fallback_exc:
                logger.exception("vector_store.search.fallback_failed", error=str(fallback_exc))
                return []
        else:
            logger.exception("vector_store.search.error", error=str(exc))
            return []

    results = []
    # Use the property names that were successfully used in the query
    # (either from detection or from fallback)
    question_prop = self._property_map.get("question", "input")
    answer_prop = self._property_map.get("answer", "label")
    source_prop = self._property_map.get("source", "source_file")

    # Weaviate v4 returns a QueryReturn with an objects attribute
    for obj in response.objects:

        question = ""
        answer = ""
        source = ""

        if isinstance(obj.properties, dict):
            question = obj.properties.get(question_prop, "")
            answer = obj.properties.get(answer_prop, "")
            source = obj.properties.get(source_prop, "")
        else:
            # Try attribute access (some client versions return objects)
            question = getattr(obj.properties, question_prop, "")
            answer = getattr(obj.properties, answer_prop, "")
            source = getattr(obj.properties, source_prop, "")

        # For old schema where answer might be a label (integer), convert to string
        if answer_prop == "label" and isinstance(answer, (int, float)):
            answer = str(answer)

        # Extract similarity - use certainty if available, otherwise distance
        # (inverted), default to 0.0
        similarity = 0.0
        if obj.metadata:
            if hasattr(obj.metadata, 'certainty') and obj.metadata.certainty is not None:
                similarity = float(obj.metadata.certainty)
            elif hasattr(obj.metadata, 'distance') and obj.metadata.distance is not None:
                # Convert distance to similarity (distance is lower for more similar items)
                # For normalized vectors, distance = 1 - certainty approximately
                similarity = max(0.0, 1.0 - float(obj.metadata.distance))

        results.append({
            "document_id": str(obj.uuid),
            "question": question,
            "answer": answer,
            "source": source,
            "similarity": similarity,
        })

    return results
244
+
245
+ def add_entry(self, question: str, answer: str, source: str = "kb") -> str:
246
+ """Add a new entry to the vector store."""
247
+ # Generate embedding
248
+ text = question + "\n" + answer
249
+ embedding = self.encoder.encode(text)
250
+ embedding = embedding / np.linalg.norm(embedding)
251
+
252
+ # Add to Weaviate
253
+ result = self.collection.data.insert(
254
+ properties={
255
+ "question": question,
256
+ "answer": answer,
257
+ "source": source,
258
+ },
259
+ vector=embedding.tolist()
260
+ )
261
+
262
+ return result.uuid
263
+
264
+
265
+ _vector_store: VectorStore | None = None
266
+
267
+
268
+ class _NullVectorStore:
269
+ """Fallback vector store used when Weaviate is unavailable at startup.
270
+ Provides the minimal API consumed by the workflow.
271
+ """
272
+
273
+ def search(self, query: str, top_k: int | None = None) -> List[Dict[str, Any]]: # type: ignore[override]
274
+ return []
275
+
276
+ def add_entry(self, question: str, answer: str, source: str = "kb") -> str: # pragma: no cover
277
+ logger.warning("vector_store.null.add_entry_ignored")
278
+ return ""
279
+
280
+
281
def load_vector_store(force_reload: bool = False) -> VectorStore:
    """Lazy-load the Weaviate vector store.

    Returns the cached instance unless ``force_reload`` is set. When the
    Weaviate credentials are missing or the connection fails, a null store
    is cached instead so searches simply return empty results.
    """
    global _vector_store

    if not force_reload and _vector_store is not None:
        return _vector_store

    if not (settings.weaviate_url and settings.weaviate_api_key):
        logger.warning("vector_store.config.missing", msg="Starting without KB; searches will return empty.")
        _vector_store = _NullVectorStore()  # type: ignore[assignment]
        return _vector_store  # type: ignore[return-value]

    logger.info("vector_store.load.start", url=settings.weaviate_url)

    # Use the v4 helper to connect to Weaviate Cloud: the legacy v3-style
    # `weaviate.Client(url=..., auth_client_secret=...)` constructor is not
    # supported by the installed client and raises TypeError. The helper
    # builds the correct ConnectionParams (http + grpc hosts) for cloud
    # deployments. The API key is passed as a plain string - the helper
    # parses it into the proper AuthCredentials object, which also avoids
    # installs where accessing `weaviate.Auth` raises AttributeError.
    try:
        client = weaviate.connect_to_weaviate_cloud(
            cluster_url=settings.weaviate_url,
            auth_credentials=settings.weaviate_api_key,
            skip_init_checks=True,
        )
        encoder = SentenceTransformer(settings.embedding_model_name)
        _vector_store = VectorStore(client=client, encoder=encoder)
        logger.info("vector_store.load.success")
        return _vector_store
    except Exception as exc:  # pragma: no cover
        logger.warning(
            "vector_store.load.failed_starting_without_kb",
            error=str(exc),
        )
        _vector_store = _NullVectorStore()  # type: ignore[assignment]
        return _vector_store  # type: ignore[return-value]
321
+
322
+
323
def save_feedback_to_queue(feedback_record: dict) -> None:
    """Append feedback to feedback database.

    The database is a single JSON array on disk; it is read in full,
    extended, and rewritten on every call.
    """

    path = Path(settings.feedback_store_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Load the existing records (or start fresh), append, and rewrite.
    records = json.loads(path.read_text(encoding="utf-8")) if path.exists() else []
    records.append(feedback_record)
    path.write_text(json.dumps(records, ensure_ascii=False, indent=2), encoding="utf-8")
334
+
335
+
336
+ def queue_candidate_kb_entry(question: str, solution: str, source: str) -> None:
337
+ """Save a candidate knowledge base entry for human approval."""
338
+
339
+ queue_path = Path("backend/data/kb_candidate_queue.json")
340
+ queue_path.parent.mkdir(parents=True, exist_ok=True)
341
+ if queue_path.exists():
342
+ queue = json.loads(queue_path.read_text(encoding="utf-8"))
343
+ else:
344
+ queue = []
345
+ queue.append({"question": question, "answer": solution, "source": source})
346
+ queue_path.write_text(json.dumps(queue, ensure_ascii=False, indent=2), encoding="utf-8")
347
+
348
+
app/tools/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Tooling helpers for math agent."""
2
+
app/tools/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (185 Bytes). View file
 
app/tools/__pycache__/audio.cpython-313.pyc ADDED
Binary file (1.93 kB). View file
 
app/tools/__pycache__/dspy_pipeline.cpython-313.pyc ADDED
Binary file (20.7 kB). View file
 
app/tools/__pycache__/validator.cpython-313.pyc ADDED
Binary file (1.57 kB). View file
 
app/tools/__pycache__/vision.cpython-313.pyc ADDED
Binary file (1.75 kB). View file
 
app/tools/__pycache__/web_search.cpython-313.pyc ADDED
Binary file (16.3 kB). View file
 
app/tools/audio.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio transcription using Groq Whisper models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import io
7
+
8
+ from groq import Groq
9
+
10
+ from ..config import settings
11
+
12
+
13
+ def _strip_data_url(data: str) -> str:
14
+ if "," in data and data.startswith("data:"):
15
+ return data.split(",", 1)[1]
16
+ return data
17
+
18
+
19
def transcribe_audio(audio_base64: str, language: str = "en") -> str:
    """Transcribe base64 encoded audio using Groq Whisper.

    Raises:
        RuntimeError: if GROQ_API_KEY is not configured, or the API
            response has an unrecognized shape.
    """

    if not settings.groq_api_key:
        raise RuntimeError("GROQ_API_KEY not configured for speech transcription")

    # Decode the payload into an in-memory WAV "file"; the buffer needs a
    # name so the SDK can infer the upload format.
    raw = base64.b64decode(_strip_data_url(audio_base64))
    buffer = io.BytesIO(raw)
    buffer.name = "input.wav"

    client = Groq(api_key=settings.groq_api_key)
    transcription = client.audio.transcriptions.create(
        model="whisper-large-v3-turbo", file=buffer, response_format="text", language=language
    )

    # Depending on SDK version the response is either an object or a str.
    if hasattr(transcription, "text"):
        return transcription.text
    if isinstance(transcription, str):
        return transcription
    raise RuntimeError("Unexpected response format from Groq Whisper API")
40
+
41
+
app/tools/dspy_pipeline.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DSPy powered reasoning pipeline for math explanations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import re
8
+ from dataclasses import dataclass
9
+ from typing import Iterable, List, Sequence
10
+
11
+ # Normalization helpers to convert LaTeX-like math to readable text (module-wide)
12
+ _LATEX_SIMPLE_REPLACEMENTS = [
13
+ (r"\\cdot", "·"),
14
+ (r"\\times", "×"),
15
+ (r"\\to", "→"),
16
+ (r"\\leq", "≤"),
17
+ (r"\\geq", "≥"),
18
+ (r"\\neq", "≠"),
19
+ (r"\\pm", "±"),
20
+ (r"\\approx", "≈"),
21
+ (r"\\infty", "∞"),
22
+ (r"\\Rightarrow", "⇒"),
23
+ (r"\\Leftarrow", "⇐"),
24
+ (r"\\sin", "sin"),
25
+ (r"\\cos", "cos"),
26
+ (r"\\tan", "tan"),
27
+ (r"\\ln", "ln"),
28
+ (r"\\log", "log"),
29
+ (r"\\lim", "lim"),
30
+ (r"\\mathbb\{R\}", "R"),
31
+ (r"\\mathbb\{R\}\^3", "R^3"),
32
+ (r"\\mathbb\{Z\}", "Z"),
33
+ (r"\\mathbb\{Q\}", "Q"),
34
+ (r"\\sqrt", "√"),
35
+ (r"\\pi", "π"),
36
+ (r"\\theta", "θ"),
37
+ (r"\\alpha", "α"),
38
+ (r"\\beta", "β"),
39
+ (r"\\gamma", "γ"),
40
+ (r"\\Delta", "Δ"),
41
+ ]
42
+
43
+ _FRAC_PATTERN = re.compile(r"\\frac\{([^}]+)\}\{([^}]+)\}")
44
+
45
+ def _normalize_math(text: str) -> str:
46
+ if not text:
47
+ return text
48
+ # Remove inline/display math delimiters
49
+ text = (
50
+ text.replace("$$", "")
51
+ .replace("$", "")
52
+ .replace("\\[", "")
53
+ .replace("\\]", "")
54
+ .replace("\\(", "")
55
+ .replace("\\)", "")
56
+ )
57
+ # Strip common markdown artifacts (bold/headers/backticks)
58
+ text = text.replace("**", "").replace("`", "")
59
+ text = re.sub(r"^\s*#{1,6}\s*", "", text, flags=re.MULTILINE)
60
+ # Drop sizing/spacing commands
61
+ text = re.sub(r"\\(left|right|big|Big|quad|qquad)\\?", "", text)
62
+ # Replace \frac{a}{b} with (a)/(b)
63
+ def _frac_sub(match: re.Match) -> str:
64
+ return f"({match.group(1)})/({match.group(2)})"
65
+ text = _FRAC_PATTERN.sub(_frac_sub, text)
66
+ # Apply simple replacements
67
+ for pattern, repl in _LATEX_SIMPLE_REPLACEMENTS:
68
+ text = re.sub(pattern, repl, text)
69
+ # Normalize powers and tidy operators spacing
70
+ text = re.sub(r"\^2", "²", text)
71
+ text = re.sub(r"\^3", "³", text)
72
+ text = re.sub(r"\s*([=+\-×·*/])\s*", r" \1 ", text)
73
+ # Collapse excessive spaces
74
+ text = re.sub(r"\s+", " ", text).strip()
75
+ return text
76
+
77
+ try:
78
+ import dspy # type: ignore
79
+ import google.generativeai as genai # type: ignore
80
+
81
+ from ..config import settings
82
+ from ..logger import get_logger
83
+
84
+ logger = get_logger(__name__)
85
+
86
class GeminiLM(dspy.LM):
    """Minimal DSPy LM wrapper around Gemini API."""

    def __init__(self, model: str, api_key: str, max_output_tokens: int = 2048) -> None:
        # Configure the SDK globally, then hold a per-instance model handle.
        super().__init__(model=model, max_output_tokens=max_output_tokens)
        genai.configure(api_key=api_key)
        self._model = genai.GenerativeModel(model)

    def __call__(self, prompt: str, **kwargs) -> dspy.Completion:
        # NOTE(review): kwargs (temperature, n, etc.) are accepted but
        # ignored; the raw prompt is forwarded to Gemini unchanged.
        response = self._model.generate_content(prompt)
        text = response.text or ""
        return dspy.Completion(text=text)

    def loglikelihood(self, prompt: str, continuation: str) -> float:
        # Gemini's public API does not expose token log-probabilities.
        raise NotImplementedError("GeminiLM.loglikelihood is not supported")
101
+
102
+
103
class MathTutorSignature(dspy.Signature):
    """DSPy signature describing the desired behavior."""

    # NOTE: in DSPy the class docstring and the field `desc` strings become
    # part of the generated prompt - their wording is runtime behavior, not
    # just documentation, so do not edit them casually.
    query = dspy.InputField(desc="Student's mathematics question - MUST solve completely")
    context = dspy.InputField(desc="Knowledge snippets retrieved from KB or web")
    requirements = dspy.InputField(desc="JSON encoded requirements for the explanation")
    solution = dspy.OutputField(desc="Structured JSON with COMPLETE solution including all steps, all parts answered, and explicit final answer. Must be a complete solution, not just an overview.")
110
+
111
+
112
def _ensure_dspy_configured() -> None:
    """Wire a GeminiLM into DSPy's global settings.

    Raises:
        RuntimeError: if GEMINI_API_KEY is not configured.
    """
    if settings.gemini_api_key is None:
        raise RuntimeError("Gemini API key missing. Set GEMINI_API_KEY in environment.")

    # Called per-request; reconfigures the global DSPy LM each time.
    lm = GeminiLM(model=settings.dspy_model, api_key=settings.gemini_api_key, max_output_tokens=settings.dspy_max_tokens)
    dspy.settings.configure(lm=lm)
118
+
119
+
120
@dataclass
class WebDocument:
    """One retrieved web document (search hit) used for citations."""

    id: str
    title: str
    url: str
    snippet: str  # short excerpt fed to the model as context
    score: float  # presumably the provider's relevance score - confirm against web_search
127
+
128
+
129
@dataclass
class SearchResult:
    """A completed web search: the query, its origin, and the hits."""

    query: str
    source: str  # presumably which search path served it (MCP/Tavily/SDK) - verify against caller
    documents: Sequence[WebDocument]
134
+
135
+
136
async def generate_solution_with_cot(
    query: str,
    contexts: Iterable,
    search_metadata: SearchResult | None = None,
) -> tuple[list[dict[str, str | None]], str]:
    """Return (steps, final_answer) produced by Gemini via DSPy.

    ``steps`` is a list of ``{"title", "content", "expression"}`` dicts
    (the previous ``list[tuple[str, str]]`` annotation did not match what
    the function actually builds below). ``final_answer`` always begins
    with ``"Answer: "``.
    """

    _ensure_dspy_configured()

    # Serialize retrieved contexts as one JSON blob per document.
    context_blocks: List[str] = []
    for ctx in contexts:
        block = json.dumps({
            "document_id": getattr(ctx, "document_id", ""),
            "question": getattr(ctx, "question", ""),
            "answer": getattr(ctx, "answer", ""),
            "similarity": getattr(ctx, "similarity", 0.0),
        }, ensure_ascii=False)
        context_blocks.append(block)

    if search_metadata:
        citations = [doc.__dict__ for doc in search_metadata.documents]
    else:
        citations = []

    # These requirements are serialized into the prompt via the signature's
    # `requirements` input field.
    requirements = {
        "style": "Explain as a friendly mathematics professor using numbered steps.",
        "must_include": [
            "Start the response with a single line: 'Answer: <final value or statement>'",
            "You MUST solve the ENTIRE problem completely - do not stop with just an overview",
            "List each algebraic manipulation explicitly",
            "Show ALL intermediate steps - do not skip any calculations",
            "Conclude with a clear final numeric or symbolic answer",
            "If the problem has multiple parts, answer ALL parts and label them (a), (b), etc.",
            "Avoid LaTeX markers like $...$ or \\frac{}{}; prefer natural language and simple unicode math",
            "State units or interpretation when relevant (probability, geometry, word problems, etc.)",
        ],
        "citations": citations,
    }

    # Build comprehensive context prompt
    context_prompt = ""
    if context_blocks:
        context_prompt = "\n\nCONTEXT FROM KNOWLEDGE BASE AND WEB SEARCH:\n" + "\n".join(context_blocks)

    enhanced_query = (
        f"{query}\n\n"
        "IMPORTANT: Solve this problem COMPLETELY. Provide all steps and the final answer. "
        "Do not just give an overview - actually solve the entire problem." + context_prompt
    )

    predictor = dspy.Predict(MathTutorSignature)

    # dspy.Predict is synchronous; run it off the event loop.
    async def _run_predict() -> dspy.Prediction:
        return await asyncio.to_thread(
            predictor,
            query=enhanced_query,
            context="\n".join(context_blocks) if context_blocks else "No additional context available.",
            requirements=json.dumps(requirements, ensure_ascii=False),
        )

    prediction = await _run_predict()

    try:
        structured = json.loads(prediction.solution)
    except json.JSONDecodeError:
        logger.warning("dspy.solution.parse_failed", solution_preview=prediction.solution[:200] if prediction.solution else "")
        # Try to extract solution from non-JSON response
        solution_text = _normalize_math(prediction.solution or "")
        # If it's not JSON, treat it as a natural language response and parse it
        structured = {
            "steps": [
                {
                    "title": "Complete Solution",
                    "explanation": solution_text,
                }
            ],
            "final_answer": solution_text.split("\n")[-1] if "\n" in solution_text else solution_text,
        }

    steps_raw = structured.get("steps", [])

    # Convert to dict format if it's tuple format
    steps = []
    for idx, step in enumerate(steps_raw):
        if isinstance(step, dict):
            steps.append({
                "title": step.get("title", f"Step {idx+1}"),
                "content": _normalize_math(step.get("explanation", step.get("content", ""))),
                "expression": _normalize_math(step.get("expression")) if step.get("expression") else None
            })
        elif isinstance(step, (list, tuple)) and len(step) >= 2:
            steps.append({
                "title": step[0],
                "content": _normalize_math(step[1]),
                "expression": _normalize_math(step[2]) if len(step) > 2 and step[2] else None
            })
        else:
            # Unknown shape - stringify it as a generic step
            steps.append({
                "title": f"Step {idx+1}",
                "content": _normalize_math(str(step)),
                "expression": None
            })

    final_answer = _normalize_math(structured.get("final_answer", ""))

    # Ensure final_answer is not empty
    if not final_answer and steps:
        # Try to extract from last step
        last_content = steps[-1].get("content", "")
        if last_content:
            # Look for answer patterns, scanning from the end
            for line in reversed(last_content.split("\n")):
                if any(word in line.lower() for word in ["answer", "therefore", "thus", "hence", "result"]):
                    final_answer = _normalize_math(line)
                    break
            if not final_answer:
                final_answer = last_content.split("\n")[-1] if "\n" in last_content else last_content

    # Ultimate fallback
    if not final_answer:
        final_answer = "Please see the steps above for the complete solution."

    # Prepend a single-line explicit answer for UI readability
    if not final_answer.lower().startswith("answer:"):
        final_answer = f"Answer: {final_answer}"

    return steps, final_answer
263
+ except Exception: # pragma: no cover - fallback when dspy isn't installed
264
+ # Provide a lightweight fallback so the application can start even if
265
+ # `dspy` is not installed. This fallback uses the Gemini client directly
266
+ # to generate a JSON-structured explanation.
267
+ import google.generativeai as genai # type: ignore
268
+
269
+ from ..config import settings
270
+ from ..logger import get_logger
271
+
272
+ logger = get_logger(__name__)
273
+
274
+
275
@dataclass
class WebDocument:
    """One retrieved web document (search hit) used for citations.

    Fallback-branch duplicate of the definition in the dspy-enabled path,
    kept so the module exports the same names when dspy is absent.
    """

    id: str
    title: str
    url: str
    snippet: str  # short excerpt fed to the model as context
    score: float  # presumably the provider's relevance score - confirm against web_search
282
+
283
+
284
@dataclass
class SearchResult:
    """A completed web search (fallback-branch duplicate, see dspy path)."""

    query: str
    source: str  # presumably which search path served it - verify against caller
    documents: Sequence[WebDocument]
289
+
290
+
291
async def generate_solution_with_cot(
    query: str,
    contexts: Iterable,
    search_metadata: SearchResult | None = None,
) -> tuple[list[dict[str, str]], str]:
    """Fallback: use Gemini directly to produce a JSON with `steps` and `final_answer`.

    This is intentionally simple: it asks the model to return a JSON object. If parsing
    fails, we return the raw text as a single reasoning step.

    Note: despite the docstring above (kept from the original), the prompt
    below actually requests prose, and the response is parsed line-by-line
    into ``{"title", "content", "expression"}`` step dicts.
    """

    if settings.gemini_api_key is None:
        raise RuntimeError("Gemini API key missing. Set GEMINI_API_KEY in environment.")

    genai.configure(api_key=settings.gemini_api_key)
    model = genai.GenerativeModel(settings.dspy_model)

    # Serialize retrieved contexts as one JSON blob per document.
    context_text = []
    for ctx in contexts:
        context_text.append(json.dumps({
            "document_id": getattr(ctx, "document_id", ""),
            "question": getattr(ctx, "question", ""),
            "answer": getattr(ctx, "answer", ""),
            "similarity": getattr(ctx, "similarity", 0.0),
        }, ensure_ascii=False))

    context_section = ""
    if context_text:
        context_section = "\n\nCONTEXT FROM KNOWLEDGE BASE AND WEB SEARCH:\n" + "\n".join(context_text) + "\n"

    prompt = (
        "You are an expert mathematics tutor across ALL math domains (arithmetic, algebra, geometry, trigonometry, calculus, linear algebra, probability, statistics, number theory, optimization, etc.). Provide clear, human-friendly solutions.\n\n"
        "CRITICAL REQUIREMENTS - YOU MUST FOLLOW THESE:\n"
        "1. Start your response with a single line: 'Answer: <final value or statement>'.\n"
        "2. Solve the ENTIRE problem - do NOT just give an overview or stop midway.\n"
        "3. Show EVERY step of your work - no skipped calculations.\n"
        "4. If there are multiple parts, answer ALL parts, labeled (a), (b), etc.\n"
        "5. Provide a clear, explicit FINAL ANSWER at the very end as well.\n"
        "6. Avoid LaTeX markers like $...$ or \\frac{}{}; write math in natural language or simple unicode (e.g., 1/2, ·, ×, →).\n"
        "7. Use concise, readable sentences; keep symbols understandable to non-experts.\n\n"
        "RESPONSE STRUCTURE:\n"
        "1. Brief Overview (1-2 sentences about the goal)\n"
        "2. Step-by-Step Solution (numbered) with clear titles and calculations\n"
        "3. Final Answer Section (repeat the answer clearly)\n\n"
        "STUDENT'S QUESTION:\n" + query + context_section +
        "\n\nNOW SOLVE THIS PROBLEM COMPLETELY.\n"
    )

    # Generate with more tokens to ensure complete answers
    from google.generativeai.types import GenerationConfig

    generation_config = GenerationConfig(
        temperature=0.3,  # Lower temperature for more focused responses
        max_output_tokens=8192,  # Increased for complete solutions
        top_p=0.95,
    )

    resp = model.generate_content(prompt, generation_config=generation_config)
    text = (resp.text or "").strip()
    text = _normalize_math(text)

    if not text:
        logger.warning("dspy.fallback.empty_response", query=query)
        text = "I apologize, but I couldn't generate a response. Please try rephrasing your question."

    try:
        # Parse the natural language response into structured steps.
        # State machine: `current_step` is the step being filled, `buffer`
        # accumulates its prose lines until the next header/blank line.
        lines = text.strip().split("\n")
        steps = []
        current_step = None
        final_answer = ""
        buffer = []
        overview_added = False

        for line in lines:
            line = line.strip()
            if not line:
                # Blank line flushes the buffered prose into the open step.
                if current_step and buffer:
                    current_step["content"] = "\n".join(buffer)
                    buffer = []
                continue

            # Detect step headers (more flexible patterns)
            step_match = False
            if line.lower().startswith("step ") or re.match(r"^\d+[\.\)]\s+", line, re.IGNORECASE):
                step_match = True
                if current_step:
                    current_step["content"] = "\n".join(buffer)
                    steps.append(current_step)
                    buffer = []
                # Extract step number and title
                step_parts = re.split(r"[:\.]", line, 1)
                step_title = step_parts[1].strip() if len(step_parts) > 1 else line
                current_step = {
                    "title": step_title if step_title else f"Step {len(steps) + 1}",
                    "content": "",
                    "expression": None
                }
            # Detect final answer section
            elif any(word in line.lower() for word in ["therefore", "thus", "finally", "hence", "answer:", "the answer is", "final answer"]):
                if current_step:
                    current_step["content"] = "\n".join(buffer)
                    steps.append(current_step)
                    buffer = []
                    current_step = None
                # Extract final answer
                for word in ["therefore", "thus", "finally", "hence", "answer:", "the answer is", "final answer"]:
                    if word in line.lower():
                        final_answer = line.split(":", 1)[1].strip() if ":" in line else line
                        break
                if not final_answer:
                    final_answer = line
            # Handle equations (lines with = or mathematical operators)
            elif current_step and ("=" in line or any(s in line for s in "+-*/^√∫∑∈ℝ")):
                if buffer:  # Save accumulated explanation first
                    current_step["content"] = "\n".join(buffer)
                    buffer = []
                if not current_step["expression"]:
                    current_step["expression"] = _normalize_math(line)
                else:
                    current_step["expression"] += "\n" + _normalize_math(line)
            # Add to current step's content
            elif current_step:
                buffer.append(line)
            # Handle text before first step (overview)
            elif not overview_added and not step_match:
                steps.append({
                    "title": "Overview",
                    "content": line,
                    "expression": None
                })
                overview_added = True

        # Don't forget the last step
        if current_step:
            if buffer:
                current_step["content"] = "\n".join(buffer)
            steps.append(current_step)

        # If no steps were parsed, create steps from the text
        if not steps:
            # Split by paragraphs or double newlines
            paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
            for idx, para in enumerate(paragraphs):
                if idx == 0:
                    steps.append({
                        "title": "Solution",
                        "content": _normalize_math(para),
                        "expression": None
                    })
                else:
                    steps.append({
                        "title": f"Step {idx}",
                        "content": _normalize_math(para),
                        "expression": None
                    })

        # Extract final answer if not found
        if not final_answer:
            # Look for answer patterns in the last few lines
            for line in reversed(lines[-10:]):
                line_lower = line.lower()
                if any(word in line_lower for word in ["answer", "therefore", "thus", "hence", "result"]):
                    final_answer = line
                    break

            # If still no answer, use the last meaningful line or paragraph
            if not final_answer and steps:
                last_content = steps[-1].get("content", "")
                if last_content:
                    final_answer = last_content.split("\n")[-1] if "\n" in last_content else last_content

        # Ultimate fallback
        if not final_answer:
            final_answer = "Please refer to the steps above for the complete solution."

        return steps, final_answer
    except Exception:
        logger.warning("dspy.fallback.parse_failed")
        # Even for plain text, maintain consistent structure
        return [{
            "title": "Solution",
            "content": text
        }], text
475
+
476
+
app/tools/validator.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Validation helpers for user-provided solutions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import google.generativeai as genai
6
+
7
+ from ..config import settings
8
+
9
+
10
def validate_user_solution(question: str, proposed_solution: str) -> bool:
    """Use Gemini to validate a user-uploaded solution.

    Returns True only when the model's verdict begins with "valid".

    Raises:
        RuntimeError: if GEMINI_API_KEY is not configured.
    """

    if not settings.gemini_api_key:
        raise RuntimeError("Gemini API key required for solution validation")

    genai.configure(api_key=settings.gemini_api_key)
    model = genai.GenerativeModel(settings.gemini_model)

    # Constrain the model to a single-word verdict so parsing stays trivial.
    prompt = (
        "You are an expert mathematics professor. Validate the student's solution. "
        "Return ONLY 'VALID' if the reasoning is mathematically correct. "
        "Return ONLY 'INVALID' otherwise.\n\n"
        f"Question: {question}\nStudent solution:\n{proposed_solution}"
    )

    verdict = (model.generate_content(prompt).text or "").strip().lower()
    # "invalid" does not start with "valid", so this cleanly separates the two.
    return verdict.startswith("valid")
+
28
+
app/tools/vision.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image understanding helpers using Gemini multimodal models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+
7
+ import google.generativeai as genai
8
+
9
+ from ..config import settings
10
+
11
+
12
+ def _strip_data_url(data: str) -> str:
13
+ if "," in data and data.startswith("data:"):
14
+ return data.split(",", 1)[1]
15
+ return data
16
+
17
+
18
def extract_text_from_image(image_base64: str, prompt: str = "Extract all mathematics text from this image.") -> str:
    """Extract math text from a base64-encoded image via a Gemini multimodal call.

    Accepts either a raw base64 string or a full ``data:`` URL.
    Raises RuntimeError when no Gemini API key is configured.
    """
    if not settings.gemini_api_key:
        raise RuntimeError("Gemini API key not configured for image understanding")

    genai.configure(api_key=settings.gemini_api_key)
    # Decode after dropping any data-URL prefix; payload is sent as PNG bytes.
    payload = base64.b64decode(_strip_data_url(image_base64))
    gemini = genai.GenerativeModel(settings.gemini_model)
    result = gemini.generate_content(
        [prompt, {"mime_type": "image/png", "data": payload}]
    )
    return result.text or ""
32
+
33
+
app/tools/web_search.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web search workflow: MCP → LangChain Tavily → SDK fallback."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import os
8
+ from dataclasses import dataclass
9
+ from typing import List, Optional
10
+
11
+ import aiohttp
12
+
13
+ from ..config import settings
14
+ from ..logger import get_logger
15
+ from .dspy_pipeline import SearchResult, WebDocument
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
@dataclass
class MCPDocument:
    """Normalized search hit shared by the MCP, LangChain, and SDK backends."""

    id: str  # provider-assigned id, or the result's list position as a string
    title: str  # result title; defaults to "Untitled" when the provider omits it
    url: str  # source URL; may be "" for results without one
    snippet: str  # result body text (mapped from Tavily's "content" field)
    score: float  # provider relevance score; 0.0 when absent
27
+
28
+
29
async def _parse_tavily_response(result: dict | list) -> List[MCPDocument]:
    """Parse a Tavily API response into ``MCPDocument`` objects.

    Accepts either the full response dict (items under the ``"results"`` key)
    or a bare list of result items, and tolerates missing fields on each item.

    :param result: decoded Tavily response (dict) or raw result list.
    :returns: one ``MCPDocument`` per result item (possibly empty).
    """
    # Bug fix: check the list case *first*. The previous code called
    # result.get(...) before the isinstance(result, list) check, so passing
    # the bare-list form it claimed to support raised AttributeError.
    if isinstance(result, list):
        results_list = result
    else:
        results_list = result.get("results", [])

    documents: List[MCPDocument] = []
    for idx, item in enumerate(results_list):
        documents.append(
            MCPDocument(
                id=item.get("id", str(idx)),
                title=item.get("title", "Untitled"),
                url=item.get("url", ""),
                snippet=item.get("content", ""),
                score=item.get("score", 0.0),
            )
        )
    return documents
49
+
50
+
51
async def _invoke_mcp_tavily(query: str, max_results: int = 5) -> List[MCPDocument]:
    """Invoke Tavily via an MCP server, trying several JSON-RPC payload formats.

    Strategy: HEAD health-check the server, then POST each candidate payload
    shape in turn, handling both SSE (``text/event-stream``) and plain JSON
    responses. Returns on the first format that yields parsed documents.

    :param query: search query forwarded to the Tavily tool.
    :param max_results: maximum result count requested from the server.
    :returns: parsed documents from the first successful payload format.
    :raises RuntimeError: when no MCP URL is configured, the server is
        unreachable, or every payload format fails (message carries the
        last observed error).
    """
    mcp_url = settings.mcp_tavily_url or os.getenv("MCP_TAVILY_URL")

    if not mcp_url:
        logger.debug("web_search.mcp_not_configured")
        raise RuntimeError("MCP URL not configured")

    # Candidate request bodies: MCP servers differ in which JSON-RPC method
    # name they expose for tool invocation, so we probe the common variants.
    payload_formats = [
        # Format 1: standard MCP protocol "tools/call"
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {
                "name": "tavily_search",
                "arguments": {
                    "query": query,
                    "max_results": max_results
                }
            }
        },
        # Format 2: tool name used directly as the RPC method
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tavily_search",
            "params": {
                "query": query,
                "max_results": max_results
            }
        },
        # Format 3: legacy "call_tool" method
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "call_tool",
            "params": {
                "name": "tavily_search",
                "arguments": {
                    "query": query,
                    "max_results": max_results
                }
            }
        },
    ]

    last_error = None  # remembers the most recent failure for the final error message

    async with aiohttp.ClientSession() as session:
        # Quick health check: a failed HEAD is logged but only a raised
        # exception (connection-level failure) aborts the whole attempt.
        try:
            async with session.head(mcp_url, timeout=10) as health_resp:
                if health_resp.status not in (200, 405):  # 405 = Method Not Allowed for HEAD
                    logger.warning("web_search.mcp_health_check_failed",
                                   status=health_resp.status)
        except Exception as health_exc:
            logger.warning("web_search.mcp_unreachable",
                           error=str(health_exc),
                           url=mcp_url)
            raise RuntimeError(f"MCP server unreachable: {health_exc}")

        # Try each payload format until one returns documents.
        for idx, payload in enumerate(payload_formats, 1):
            try:
                logger.debug("web_search.mcp_attempting_format",
                             format_number=idx,
                             method=payload.get("method"))

                headers = {
                    "Accept": "text/event-stream, application/json",
                    "Content-Type": "application/json",
                }

                async with session.post(mcp_url, headers=headers, json=payload, timeout=30) as resp:
                    ctype = resp.headers.get("Content-Type", "")
                    status = resp.status

                    if status != 200:
                        response_text = await resp.text()
                        logger.debug("web_search.mcp_format_failed",
                                     format_number=idx,
                                     status=status,
                                     response=response_text[:200])
                        last_error = f"HTTP {status}: {response_text[:100]}"
                        continue

                    # SSE response: keep only the final non-terminator data frame.
                    if "text/event-stream" in ctype:
                        last_json = None
                        async for raw in resp.content:
                            line = raw.decode("utf-8", errors="ignore").strip()
                            if line.startswith("data:"):
                                data_str = line[5:].strip()
                                if data_str and data_str != "[DONE]":
                                    last_json = data_str

                        if not last_json:
                            logger.debug("web_search.mcp_empty_sse", format_number=idx)
                            last_error = "Empty SSE stream"
                            continue

                        data = json.loads(last_json)
                    else:
                        # Plain JSON response.
                        data = await resp.json()

                    # A JSON-RPC error object means this format is not accepted.
                    if "error" in data:
                        error_detail = data["error"]
                        logger.debug("web_search.mcp_format_error",
                                     format_number=idx,
                                     error=error_detail)
                        last_error = f"RPC error: {error_detail}"
                        continue

                    result = data.get("result", {})

                    # MCP protocol shape: result.content is a list of typed parts;
                    # Tavily's JSON payload is embedded in the "text" parts.
                    if "content" in result and isinstance(result["content"], list):
                        for content_item in result["content"]:
                            if content_item.get("type") == "text":
                                text_data = content_item.get("text", "")
                                try:
                                    tavily_result = json.loads(text_data)
                                    documents = await _parse_tavily_response(tavily_result)
                                    if documents:
                                        logger.info("web_search.mcp_success",
                                                    format_number=idx,
                                                    document_count=len(documents))
                                        return documents
                                except json.JSONDecodeError:
                                    continue

                    # Alternative shape: the server returned raw Tavily JSON.
                    if "results" in result:
                        documents = await _parse_tavily_response(result)
                        if documents:
                            logger.info("web_search.mcp_success",
                                        format_number=idx,
                                        document_count=len(documents))
                            return documents

                    last_error = f"No results in response structure: {list(result.keys())}"

            except asyncio.TimeoutError:
                logger.debug("web_search.mcp_timeout", format_number=idx)
                last_error = "Request timeout"
                continue
            except Exception as exc:
                logger.debug("web_search.mcp_format_exception",
                             format_number=idx,
                             error=str(exc),
                             error_type=type(exc).__name__)
                last_error = str(exc)
                continue

    # All formats failed: surface the last error to the fallback chain.
    logger.warning("web_search.mcp_all_formats_failed",
                   last_error=last_error,
                   formats_tried=len(payload_formats))
    raise RuntimeError(f"MCP failed: {last_error}")
215
+
216
+
217
# LangChain Tavily tool (preferred fallback).
# Optional dependency: when the package is missing, TavilySearch stays None
# and the LangChain path raises a descriptive RuntimeError at call time.
try:
    from langchain_tavily import TavilySearch  # type: ignore
except Exception:
    TavilySearch = None  # type: ignore
222
+
223
+
224
async def _run_tavily_langchain(query: str, max_results: int = 5) -> List[MCPDocument]:
    """Search via the LangChain Tavily tool, Tavily's recommended integration.

    Raises RuntimeError when the package is not installed, the API key is
    missing, or the underlying tool call fails.
    """
    if TavilySearch is None:
        raise RuntimeError("LangChain Tavily not installed. Run: pip install -U langchain-tavily")
    if not settings.tavily_api_key:
        raise RuntimeError("TAVILY_API_KEY not configured")

    # LangChain reads the key from the environment, so mirror it there.
    os.environ["TAVILY_API_KEY"] = settings.tavily_api_key

    try:
        search_tool = TavilySearch(
            max_results=max_results,
            topic="general",
            search_depth="basic",
        )

        # The tool is synchronous; run it off the event loop thread.
        raw = await asyncio.to_thread(search_tool.invoke, {"query": query})

        return [
            MCPDocument(
                id=item.get("id", str(position)),
                title=item.get("title", "Untitled"),
                url=item.get("url", ""),
                snippet=item.get("content", ""),
                score=item.get("score", 0.0),
            )
            for position, item in enumerate(raw.get("results", []))
        ]
    except Exception as exc:
        logger.error("web_search.langchain_execution_error",
                     error=str(exc),
                     error_type=type(exc).__name__)
        raise RuntimeError(f"LangChain Tavily failed: {exc}")
267
+
268
+
269
# Direct Tavily SDK (last resort fallback).
# Optional dependency: when missing, TavilyClient stays None and the SDK
# path raises a descriptive RuntimeError at call time.
try:
    from tavily import TavilyClient  # type: ignore
except Exception:
    TavilyClient = None  # type: ignore
274
+
275
+
276
async def _run_tavily_sdk(query: str, max_results: int = 5) -> List[MCPDocument]:
    """Query Tavily directly through its Python SDK (final fallback).

    Raises RuntimeError when the SDK is not installed or the API key is missing.
    """
    if TavilyClient is None:
        raise RuntimeError("Tavily SDK not installed. Run: pip install tavily-python")
    if not settings.tavily_api_key:
        raise RuntimeError("TAVILY_API_KEY not configured")

    # The SDK client is synchronous; push the blocking call onto a worker thread.
    sdk_client = TavilyClient(api_key=settings.tavily_api_key)
    raw = await asyncio.to_thread(
        sdk_client.search,
        query=query,
        max_results=max_results,
        include_images=False
    )

    return [
        MCPDocument(
            id=item.get("id", str(position)),
            title=item.get("title", "Untitled"),
            url=item.get("url", ""),
            snippet=item.get("content", ""),
            score=item.get("score", 0.0),
        )
        for position, item in enumerate(raw.get("results", []))
    ]
304
+
305
+
306
+ async def run_web_search_with_fallback(query: str) -> Optional[SearchResult]:
307
+ """Run web search with fallback chain: MCP → LangChain Tavily → SDK.
308
+
309
+ Priority order:
310
+ 1. MCP Tavily (if configured) - Remote MCP server integration
311
+ 2. LangChain Tavily - Official LangChain integration (recommended by Tavily)
312
+ 3. Tavily SDK - Direct Python SDK as last resort
313
+ """
314
+
315
+ all_documents: List[MCPDocument] = []
316
+ sources_used: List[str] = []
317
+
318
+ # 1) Try MCP Tavily first (if configured)
319
+ mcp_url = settings.mcp_tavily_url or os.getenv("MCP_TAVILY_URL")
320
+ if mcp_url:
321
+ try:
322
+ tavily_docs = await _invoke_mcp_tavily(query, max_results=5)
323
+ if tavily_docs:
324
+ all_documents.extend(tavily_docs[:5])
325
+ sources_used.append("tavily-mcp")
326
+ logger.info("web_search.mcp_completed", document_count=len(tavily_docs))
327
+ except Exception as exc:
328
+ logger.warning("web_search.mcp_failed", error=str(exc))
329
+
330
+ # 2) Fallback to LangChain Tavily (preferred method)
331
+ if len(all_documents) < 3:
332
+ try:
333
+ lc_docs = await _run_tavily_langchain(query, max_results=5)
334
+ existing_urls = {doc.url for doc in all_documents if doc.url}
335
+
336
+ added_count = 0
337
+ for doc in lc_docs:
338
+ if doc.url not in existing_urls:
339
+ all_documents.append(doc)
340
+ added_count += 1
341
+ if len(all_documents) >= 5:
342
+ break
343
+
344
+ if added_count > 0:
345
+ sources_used.append("tavily-langchain")
346
+ logger.info("web_search.langchain_completed",
347
+ document_count=added_count,
348
+ total_documents=len(all_documents))
349
+ except Exception as lc_exc:
350
+ logger.warning("web_search.langchain_failed", error=str(lc_exc))
351
+
352
+ # 3) Last resort: Direct Tavily SDK
353
+ if len(all_documents) < 3:
354
+ try:
355
+ tavily_sdk_docs = await _run_tavily_sdk(query, max_results=5)
356
+ existing_urls = {doc.url for doc in all_documents if doc.url}
357
+
358
+ added_count = 0
359
+ for doc in tavily_sdk_docs:
360
+ if doc.url not in existing_urls:
361
+ all_documents.append(doc)
362
+ added_count += 1
363
+ if len(all_documents) >= 5:
364
+ break
365
+
366
+ if added_count > 0:
367
+ sources_used.append("tavily-sdk")
368
+ logger.info("web_search.sdk_completed",
369
+ document_count=added_count,
370
+ total_documents=len(all_documents))
371
+ except Exception as sdk_exc:
372
+ logger.warning("web_search.sdk_failed", error=str(sdk_exc))
373
+
374
+ if not all_documents:
375
+ logger.error("web_search.all_sources_failed", query=query)
376
+ return None
377
+
378
+ web_documents = [
379
+ WebDocument(
380
+ id=doc.id,
381
+ title=doc.title,
382
+ url=doc.url,
383
+ snippet=doc.snippet,
384
+ score=doc.score,
385
+ )
386
+ for doc in all_documents[:5]
387
+ ]
388
+
389
+ source = "+".join(sources_used) if sources_used else "unknown"
390
+ logger.info("web_search.completed",
391
+ source=source,
392
+ total_documents=len(web_documents),
393
+ query=query)
394
+
395
+ return SearchResult(query=query, source=source, documents=web_documents)
app/workflows/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Workflow graphs for the math agent."""
2
+
3
+
app/workflows/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (193 Bytes). View file
 
app/workflows/__pycache__/langgraph_pipeline.cpython-313.pyc ADDED
Binary file (11.9 kB). View file
 
app/workflows/langgraph_pipeline.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph implementation of the math agent workflow."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Literal, TypedDict
6
+
7
+ from langgraph.graph import END, StateGraph
8
+
9
+ from ..config import settings
10
+ from ..guardrails import run_input_guardrails, run_output_guardrails
11
+ from ..logger import get_logger
12
+ from ..schemas import Citation, RetrievalContext
13
+ from ..services.vector_store import VectorStore
14
+ from ..tools.dspy_pipeline import SearchResult, generate_solution_with_cot
15
+ from ..tools.web_search import run_web_search_with_fallback
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
class AgentGraphState(TypedDict, total=False):
    """Mutable state threaded through the LangGraph nodes (all keys optional)."""

    query: str  # raw user query as received
    sanitized_query: str  # query after input guardrails
    gateway_trace: list[str]  # ordered audit trail of node transitions
    knowledge_hits: list[RetrievalContext]  # Combined KB + web for generation
    kb_hits: list[RetrievalContext]  # KB hits only (for UI)
    web_hits: list[RetrievalContext]  # Web hits only (for UI)
    search_result: SearchResult | None  # raw web-search result, when any backend succeeded
    # NOTE(review): downstream nodes build steps as dicts with "title"/"content"
    # keys, not (title, content) tuples — confirm and tighten this annotation.
    steps: list[tuple[str, str]]
    answer: str  # final answer text after generation/guardrails
    source: str  # provenance label, e.g. "kb+web", "tavily-sdk", "direct"
    citations: list[Citation]  # URLs surfaced to the user
    retrieved_from_kb: bool  # True when at least one KB hit was available
33
+
34
+
35
def build_math_agent_graph(vector_store: VectorStore) -> StateGraph:
    """Compile the LangGraph workflow for the math agent.

    Topology: input_guardrails -> retrieve_kb -> (conditional router) ->
    web_search -> search_generate -> output_guardrails -> END.

    NOTE(review): ``route_after_retrieval`` returns "search" on every path, so
    the "kb" branch (``kb_generate``) is currently unreachable; it is kept only
    in case routing changes.

    :param vector_store: knowledge-base store queried by the retrieval node.
    :returns: the compiled LangGraph application (result of ``graph.compile()``).
    """

    graph = StateGraph(AgentGraphState)

    async def input_guardrails_node(state: AgentGraphState) -> AgentGraphState:
        # Sanitize the raw query unless guardrails are disabled in settings.
        sanitized = run_input_guardrails(state["query"]) if settings.enforce_input_guardrails else state["query"]
        trace = state.get("gateway_trace", []) + ["input_guardrails_pass"]
        return {"sanitized_query": sanitized, "gateway_trace": trace}

    async def retrieve_kb_node(state: AgentGraphState) -> AgentGraphState:
        # Query the vector store and normalize hits into RetrievalContext.
        query = state["sanitized_query"]
        kb_results = vector_store.search(query)
        kb_contexts: list[RetrievalContext] = []
        for result in kb_results:
            # kb_results is a list of dicts with: document_id, question, answer, source, similarity
            kb_contexts.append(
                RetrievalContext(
                    document_id=result.get("document_id", ""),
                    question=result.get("question", ""),
                    answer=result.get("answer", ""),
                    similarity=result.get("similarity", 0.0) or 0.0,
                )
            )
        trace = state.get("gateway_trace", []) + ["kb_search_complete"]
        # Store KB contexts separately so we can combine them later
        return {
            "knowledge_hits": kb_contexts,
            "kb_hits": kb_contexts,  # Keep KB hits separate
            "gateway_trace": trace
        }

    def route_after_retrieval(state: AgentGraphState) -> Literal["kb", "search"]:
        # Always proceed to both KB and web search, then generate with Gemini.
        # NOTE(review): every path below returns "search"; the "kb" outcome is
        # declared in the Literal but never produced. The branching here only
        # selects which log line is emitted.
        contexts = state.get("knowledge_hits", []) or []
        if contexts and len(contexts) > 0:
            top_similarity = contexts[0].similarity if contexts[0].similarity is not None else 0.0
            if top_similarity >= settings.similarity_threshold:
                logger.info("router.has_good_kb_match", similarity=top_similarity, threshold=settings.similarity_threshold)
                # Still go to web search to enhance with additional context
                return "search"
            logger.info("router.kb_below_threshold", top_similarity=top_similarity, threshold=settings.similarity_threshold)
        else:
            logger.info("router.no_kb_results")
        # Always proceed to web search to get additional context
        return "search"

    async def kb_generation_node(state: AgentGraphState) -> AgentGraphState:
        # Deprecated: the router never selects this node (see route_after_retrieval).
        # Kept for backward compatibility in case routing changes.
        contexts = state.get("knowledge_hits", [])
        # Also fetch web search to combine with KB
        query = state["sanitized_query"]
        search_result = await run_web_search_with_fallback(query)
        web_contexts = []
        if search_result:
            web_contexts = [
                RetrievalContext(
                    document_id=document.id,
                    question=document.title,
                    answer=document.snippet,
                    similarity=document.score or 0.0,
                )
                for document in search_result.documents
            ]

        # Combine KB and web contexts before generation.
        all_contexts = contexts + web_contexts
        steps, answer_text = await generate_solution_with_cot(state["sanitized_query"], all_contexts, search_metadata=search_result)
        trace = state.get("gateway_trace", []) + ["router->kb+web"]
        return {
            "steps": steps,
            "answer": answer_text,
            "source": "kb+web",
            "retrieved_from_kb": True,
            "citations": state.get("citations", []),
            "gateway_trace": trace,
            "search_result": search_result,
        }

    async def web_search_node(state: AgentGraphState) -> AgentGraphState:
        query = state["sanitized_query"]
        # Get KB contexts from the previous node (preserved separately in kb_hits).
        kb_contexts = state.get("kb_hits", []) or state.get("knowledge_hits", [])

        # Always attempt web search
        search_result = await run_web_search_with_fallback(query)

        # Combine KB contexts with web search contexts for Gemini generation
        all_contexts = list(kb_contexts) if kb_contexts else []
        web_contexts = []

        if search_result and search_result.documents:
            web_contexts = [
                RetrievalContext(
                    document_id=document.id,
                    question=document.title,
                    answer=document.snippet,
                    similarity=document.score or 0.0,
                )
                for document in search_result.documents
            ]
            all_contexts.extend(web_contexts)

            # Only documents with a URL become user-visible citations.
            citations = [
                Citation(title=document.title, url=document.url)
                for document in search_result.documents
                if document.url
            ]
            source = search_result.source
        else:
            logger.warning("web_search.no_results", query=query)
            citations = []
            source = "kb-only" if kb_contexts else "direct"

        trace = state.get("gateway_trace", []) + ["web_search_complete", f"source={source}"]
        return {
            "search_result": search_result,
            "knowledge_hits": all_contexts,  # Combined KB + web contexts for generation
            "kb_hits": kb_contexts,  # Keep KB hits separate for UI
            "web_hits": web_contexts,  # Keep web hits separate for UI
            "citations": citations,
            "retrieved_from_kb": len(kb_contexts) > 0,
            "source": source,
            "gateway_trace": trace,
        }

    async def search_generation_node(state: AgentGraphState) -> AgentGraphState:
        # Generate the final answer with Gemini from the combined contexts.
        contexts = state.get("knowledge_hits", [])  # Already combined KB + web contexts
        search_metadata = state.get("search_result")

        logger.info("generation.starting", contexts_count=len(contexts), has_search_metadata=search_metadata is not None)

        try:
            # Always generate with Gemini using all available context
            steps, answer_text = await generate_solution_with_cot(
                state["sanitized_query"],
                contexts,
                search_metadata=search_metadata
            )

            # Salvage an answer when the model returned steps but no final text.
            if not answer_text or not answer_text.strip():
                logger.warning("generation.empty_response", query=state["sanitized_query"])
                # Try to construct answer from steps if available
                if steps:
                    # Steps may be dicts ({"title", "content"/"explanation"}) or
                    # (title, content) pairs; handle both shapes defensively.
                    if isinstance(steps[-1], dict):
                        answer_text = steps[-1].get("content", "") or steps[-1].get("explanation", "")
                    elif isinstance(steps[-1], (list, tuple)) and len(steps[-1]) > 1:
                        answer_text = steps[-1][1]
                    else:
                        answer_text = str(steps[-1])

                    if not answer_text:
                        # Combine all step contents as a last resort.
                        answer_parts = []
                        for step in steps:
                            if isinstance(step, dict):
                                content = step.get("content", "") or step.get("explanation", "")
                                if content:
                                    answer_parts.append(content)
                            elif isinstance(step, (list, tuple)) and len(step) > 1:
                                answer_parts.append(step[1])
                        answer_text = "\n\n".join(answer_parts) if answer_parts else "Please see the steps above for the solution."
                else:
                    answer_text = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
                    steps = [{"title": "Error", "content": "Unable to generate solution with available context."}]

            # Log the response length to debug
            logger.info("generation.completed", answer_length=len(answer_text), steps_count=len(steps))
        except Exception as exc:
            logger.exception("generation.failed", error=str(exc))
            answer_text = "I encountered an error while processing your question. Please try again."
            steps = [{"title": "Error", "content": str(exc)}]

        # Always return with all context information - let UI decide what to show
        return {
            "steps": steps,
            "answer": answer_text,
            # Keep all context for UI display
        }

    async def output_guardrails_node(state: AgentGraphState) -> AgentGraphState:
        if not settings.enforce_output_guardrails:
            return {}
        filtered_answer, urls = run_output_guardrails(state["answer"])
        citations = list(state.get("citations", []))
        for url in urls:
            # Add guardrail-extracted URLs that are not already cited.
            if not any(citation.url == url for citation in citations):
                citations.append(Citation(title="Source", url=url))

        current_trace = state.get("gateway_trace", [])
        trace = current_trace + ["output_guardrails_pass"]

        return {"answer": filtered_answer, "citations": citations, "gateway_trace": trace}

    # Wire nodes and edges; see docstring for the effective topology.
    graph.add_node("input_guardrails", input_guardrails_node)
    graph.add_node("retrieve_kb", retrieve_kb_node)
    graph.add_node("kb_generate", kb_generation_node)
    graph.add_node("web_search", web_search_node)
    graph.add_node("search_generate", search_generation_node)
    graph.add_node("output_guardrails", output_guardrails_node)

    graph.set_entry_point("input_guardrails")
    graph.add_edge("input_guardrails", "retrieve_kb")
    graph.add_conditional_edges("retrieve_kb", route_after_retrieval, {"kb": "kb_generate", "search": "web_search"})
    graph.add_edge("kb_generate", "output_guardrails")
    graph.add_edge("web_search", "search_generate")
    graph.add_edge("search_generate", "output_guardrails")
    graph.add_edge("output_guardrails", END)

    return graph.compile()
+ return graph.compile()
250
+
251
+
backend/data/feedback_db.json ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "message_id": "assistant-1762103842320",
4
+ "query": "Find the derivative of the function f(x)=x^3sin(x^2) and simplify your answer.",
5
+ "agent_response": {
6
+ "answer": "```json{ \"overview\": \"To find the derivative of the given function, we apply the product rule and the chain rule.\", \"steps\": [ { \"title\": \"Step 1: Identify the Rules for Differentiation\", \"content\": \"The function $f(x) = x^3 \\sin(x^2)$ is a product of two functions, $u(x) = x^3$ and $v(x) = \\sin(x^2)$. Therefore, we will use the product rule for differentiation, which states that if $f(x) = u(x)v(x)$, then $f'(x) = u'(x)v(x) + u(x)v'(x)$.\\n\\nAdditionally, to find the derivative of $v(x) = \\sin(x^2)$, we will need to apply the chain rule, which states that if $y = g(h(x))$, then $y' = g'(h(x))h'(x)$.\", \"expression\": \"f'(x) = u'(x)v(x) + u(x)v'(x)\" }, { \"title\": \"Step 2: Differentiate Each Part\", \"content\": \"First, let's find the derivative of $u(x) = x^3$:\\n$u'(x) = \\\\frac{d}{dx}(x^3) = 3x^2$\\n\\nNext, let's find the derivative of $v(x) = \\\\sin(x^2)$ using the chain rule. Let $w = x^2$, so $v(x) = \\\\sin(w)$.\\n$\\\\frac{dv}{dw} = \\\\cos(w)$\\n$\\\\frac{dw}{dx} = \\\\frac{d}{dx}(x^2) = 2x$\\n\\nApplying the chain rule, $v'(x) = \\\\frac{dv}{dw} \\\\cdot \\\\frac{dw}{dx} = \\\\cos(x^2) \\\\cdot 2x = 2x \\\\cos(x^2)$.\", \"expression\": \"u'(x) = 3x^2 \\\\quad \\\\text{and} \\\\quad v'(x) = 2x \\\\cos(x^2)\" }, { \"title\": \"Step 3: Apply the Product Rule\", \"content\": \"Now, we substitute $u(x)$, $u'(x)$, $v(x)$, and $v'(x)$ into the product rule formula:\\n$f'(x) = u'(x)v(x) + u(x)v'(x)$\\n$f'(x) = (3x^2)(\\\\sin(x^2)) + (x^3)(2x \\\\cos(x^2))$\\n$f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)$\", \"expression\": \"f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)\" }, { \"title\": \"Step 4: Simplify the Answer\", \"content\": \"We can simplify the expression by factoring out the common term $x^2$ from both terms:\\n$f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\", \"expression\": \"f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))\" } ], \"final_answer\": \"$f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\"}```",
7
+ "steps": [
8
+ {
9
+ "title": "Solution",
10
+ "content": "```json\n{\n \"overview\": \"To find the derivative of the given function, we apply the product rule and the chain rule.\",\n \"steps\": [\n {\n \"title\": \"Step 1: Identify the Rules for Differentiation\",\n \"content\": \"The function $f(x) = x^3 \\sin(x^2)$ is a product of two functions, $u(x) = x^3$ and $v(x) = \\sin(x^2)$. Therefore, we will use the product rule for differentiation, which states that if $f(x) = u(x)v(x)$, then $f'(x) = u'(x)v(x) + u(x)v'(x)$.\\n\\nAdditionally, to find the derivative of $v(x) = \\sin(x^2)$, we will need to apply the chain rule, which states that if $y = g(h(x))$, then $y' = g'(h(x))h'(x)$.\",\n \"expression\": \"f'(x) = u'(x)v(x) + u(x)v'(x)\"\n },\n {\n \"title\": \"Step 2: Differentiate Each Part\",\n \"content\": \"First, let's find the derivative of $u(x) = x^3$:\\n$u'(x) = \\\\frac{d}{dx}(x^3) = 3x^2$\\n\\nNext, let's find the derivative of $v(x) = \\\\sin(x^2)$ using the chain rule. Let $w = x^2$, so $v(x) = \\\\sin(w)$.\\n$\\\\frac{dv}{dw} = \\\\cos(w)$\\n$\\\\frac{dw}{dx} = \\\\frac{d}{dx}(x^2) = 2x$\\n\\nApplying the chain rule, $v'(x) = \\\\frac{dv}{dw} \\\\cdot \\\\frac{dw}{dx} = \\\\cos(x^2) \\\\cdot 2x = 2x \\\\cos(x^2)$.\",\n \"expression\": \"u'(x) = 3x^2 \\\\quad \\\\text{and} \\\\quad v'(x) = 2x \\\\cos(x^2)\"\n },\n {\n \"title\": \"Step 3: Apply the Product Rule\",\n \"content\": \"Now, we substitute $u(x)$, $u'(x)$, $v(x)$, and $v'(x)$ into the product rule formula:\\n$f'(x) = u'(x)v(x) + u(x)v'(x)$\\n$f'(x) = (3x^2)(\\\\sin(x^2)) + (x^3)(2x \\\\cos(x^2))$\\n$f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)$\",\n \"expression\": \"f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)\"\n },\n {\n \"title\": \"Step 4: Simplify the Answer\",\n \"content\": \"We can simplify the expression by factoring out the common term $x^2$ from both terms:\\n$f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\",\n \"expression\": \"f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))\"\n }\n ],\n \"final_answer\": \"$f'(x) 
= x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\"\n}\n```",
11
+ "expression": null
12
+ }
13
+ ],
14
+ "retrieved_from_kb": false,
15
+ "knowledge_hits": [
16
+ {
17
+ "document_id": "0",
18
+ "question": "Find the derivative of f(x)=x^3 sin(x) - MyTutor",
19
+ "answer": "The derivative of f(x)=x^3 sin(x) is x^2(3 sin(x) + x cos(x)), using the product rule.",
20
+ "similarity": 0.79119295
21
+ },
22
+ {
23
+ "document_id": "1",
24
+ "question": "Solved The derivative of f(x) can be expressed in the form",
25
+ "answer": "The derivative of f(x) can be expressed in the form f'(x)=g(x)(3sinx+2)2, where g(x) is some function of x. Find g(x).Answer: g(x)=(b",
26
+ "similarity": 0.78512776
27
+ },
28
+ {
29
+ "document_id": "2",
30
+ "question": "What is the first derivative of the function f(x) =x^3sinx?",
31
+ "answer": "The first derivative is: u'v+uv'= 3x^2 sin x + x^3 cos x, or, x^2 (3 sin x + x cos x). However, if you mean x^(3 sin x),",
32
+ "similarity": 0.7173491
33
+ },
34
+ {
35
+ "document_id": "3",
36
+ "question": "The derivative of the function f is f'(x)=x-3sin x^2. Which ...",
37
+ "answer": "The derivative of the function f is f'(x)=x-3sin x^2. Which interval contains the greatest number of relative minimums of f? A. (-2,2) B. (-3,1) C. (-4",
38
+ "similarity": 0.6929958
39
+ },
40
+ {
41
+ "document_id": "4",
42
+ "question": "Larson Calculus 2.3 #44: Differentiate f(x) = sin(x)/x^3 ...",
43
+ "answer": "Let me show you how I find the derivative with the quotient rule.",
44
+ "similarity": 0.63452125
45
+ }
46
+ ],
47
+ "citations": [
48
+ {
49
+ "title": "Find the derivative of f(x)=x^3 sin(x) - MyTutor",
50
+ "url": "https://www.mytutor.co.uk/answers/17399/A-Level/Maths/Find-the-derivative-of-f-x-x-3-sin-x/"
51
+ },
52
+ {
53
+ "title": "Solved The derivative of f(x) can be expressed in the form",
54
+ "url": "https://www.chegg.com/homework-help/questions-and-answers/derivative-f-x-expressed-form-f-x-g-x-3sinx-2-2-g-x-function-x-find-g-x--answer-g-x-b-let--q147428016"
55
+ },
56
+ {
57
+ "title": "What is the first derivative of the function f(x) =x^3sinx?",
58
+ "url": "https://www.quora.com/What-is-the-first-derivative-of-the-function-f-x-x-3sinx"
59
+ },
60
+ {
61
+ "title": "The derivative of the function f is f'(x)=x-3sin x^2. Which ...",
62
+ "url": "https://www.gauthmath.com/solution/1813181098343573/116-Multiple-Choice-The-derivative-of-the-function-is-f-x-x-3sin-x2-Which-interv"
63
+ },
64
+ {
65
+ "title": "Larson Calculus 2.3 #44: Differentiate f(x) = sin(x)/x^3 ...",
66
+ "url": "https://www.youtube.com/watch?v=9qbaO_Jihzg"
67
+ }
68
+ ],
69
+ "source": "tavily-sdk",
70
+ "feedback_required": true,
71
+ "gateway_trace": [
72
+ "input_guardrails_pass",
73
+ "router.inspect",
74
+ "router->search",
75
+ "search.source=tavily-sdk",
76
+ "output_guardrails_pass"
77
+ ]
78
+ },
79
+ "feedback": {
80
+ "thumbs_up": true,
81
+ "primary_issue": null,
82
+ "has_better_solution": false,
83
+ "solution_type": null,
84
+ "better_solution_text": null,
85
+ "better_solution_pdf_base64": null,
86
+ "better_solution_image_base64": null
87
+ }
88
+ },
89
+ {
90
+ "message_id": "assistant-1762145260594",
91
+ "query": "Find the derivative of the function: f(x)=3x^4−5x^2+2squareroot(x)−7/x^3",
92
+ "agent_response": {
93
+ "answer": "Thus, the function can be rewritten as:",
94
+ "steps": [
95
+ {
96
+ "title": "Overview",
97
+ "content": "Let's find the derivative of the function $f(x)$ step by step.",
98
+ "expression": null
99
+ }
100
+ ],
101
+ "retrieved_from_kb": false,
102
+ "knowledge_hits": [
103
+ {
104
+ "document_id": "0",
105
+ "question": "Derivative Calculator - Mathway",
106
+ "answer": "Upgrade Calculators Help Sign In Sign Up # Derivative Calculator **Step 1:** Enter the function you want to find the derivative of in the editor. > The Derivative Calculator supports solving first, second...., fourth derivatives, as well as implicit differentiation and finding the zeros/roots. You can also get a better visual and understanding of the function by using our graphing tool. > Chain Rule: ddx[f(g(x))]=f'(g(x))g'(x) **Step 2:** Click the blue arrow to submit. Choose **\"Find the Derivative\"** from the topic selector and click to see the result! (2x−9)10 using Amazon.Auth.AccessControlPolicy; Mathway requires javascript and a modern browser. Please ensure that your password is at least 8 characters and contains each of the following: * a special character: @$#!%\\*?&",
107
+ "similarity": 0.98579
108
+ },
109
+ {
110
+ "document_id": "1",
111
+ "question": "Derivative Calculator - Math Portal",
112
+ "answer": "* Calculators Math Calculators, Lessons and Formulas * Calculators * Calculators * Derivative Calculator # Derivative calculator * Fraction Calculator * Percentage Calculator Polynomials * Factoring Polynomials * Simplify Polynomials * Multiplication / Division Solving Equations * Polynomial Equations * Square Calculator * Rectangle Calculator * Circle Calculator * Hexagon Calculator * Rhombus Calculator * Trapezoid Calculator * Distance calculator * Midpoint Calculator * Triangle Calculator * Division * Determinant Calculator Calculus Calculators * Limit Calculator * Derivative Calculator * Integral Calculator Equations * Probability Calculator * T-Test Calculator Financial Calculators * Amortization Calculator * Annuity Calculator Other Calculators Related calculators Limit calculator Integral calculator Derivative calculator – Widget Code **Widget preview:** Derivative calculator Online Math Calculators",
113
+ "similarity": 0.98278
114
+ },
115
+ {
116
+ "document_id": "2",
117
+ "question": "Derivative Calculator - Symbolab",
118
+ "answer": "*Example* : Find the derivative of $f\\left(x\\right)=\\frac{1}{x}$. scientific calculator inverse calculator simplify calculator distance calculator fractions calculator interval notation calculator cross product calculator probability calculator derivative calculator series calculator ratios calculator statistics calculator integral calculator inverse laplace transform calculator rounding calculator gcf calculator algebra calculator tangent line calculator trigonometry calculator log calculator standard deviation calculator linear equation calculator antiderivative calculator laplace transform calculator quadratic equation calculator domain calculator decimals calculator limit calculator equation solver definite integral calculator matrix inverse calculator matrix calculator system of equations calculator calculus calculator slope calculator long division calculator factors calculator polynomial calculator square root calculator implicit differentiation calculator word problem solver differential equation calculator average calculator synthetic division calculator",
119
+ "similarity": 0.97997
120
+ },
121
+ {
122
+ "document_id": "3",
123
+ "question": "Calculus - How to find the derivative of a function using the power rule",
124
+ "answer": "Calculus - How to find the derivative of a function using the power rule\nMySecretMathTutor\n241000 subscribers\n6381 likes\n582542 views\n5 Sep 2013\nThis video shows how to find the derivative of a function using the power rule. Remember that this rule only works on functions of the form x^n where n is not equal to zero. For more videos please visit http://www.mysecretmathtutor.com\n189 comments\n",
125
+ "similarity": 0.97435
126
+ },
127
+ {
128
+ "document_id": "4",
129
+ "question": "Finding the Derivative of a Square Root Function Using Definition of ...",
130
+ "answer": "Finding the Derivative of a Square Root Function Using Definition of a Derivative\nPatrick J\n1400000 subscribers\n1825 likes\n268324 views\n11 Dec 2011\nIn this video we compute the derivative of a square root function, namely f(x) = sqrt (2x + 1) by using the definition of a derivative. We could use the power rule along with the chain rule as well of course, but this what one is often asked in the beginning of a calculus course.\nNote that to help simplify the expression for the derivative, we will make use of the conjugate, a common technique to help simplify these types of expressions with square roots.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n#math, #algebra, #khanacademy, #khan, #geometry, #education, #montessori, #calculus, #homeschooling, #functions, #notation, #organicchemistry, #organicchemistrytutor, #calculus, #derivative, #powerrule, #limits, #limit, #graph, #graphing, #conjugate, #infinity, #patrickjmt, #rational, #shortcut, #trick, #commondenominator, #fraction, #university, #admissions, #test #continuity #onesided #infinite #absolutevalue #tangents #tangentsandnormals #tangentline #definition #hindi #india\n126 comments\n",
131
+ "similarity": 0.96945
132
+ }
133
+ ],
134
+ "citations": [
135
+ {
136
+ "title": "Derivative Calculator - Mathway",
137
+ "url": "https://www.mathway.com/Calculator/derivative-calculator"
138
+ },
139
+ {
140
+ "title": "Derivative Calculator - Math Portal",
141
+ "url": "https://www.mathportal.org/calculators/calculus/derivative-calculator.php"
142
+ },
143
+ {
144
+ "title": "Derivative Calculator - Symbolab",
145
+ "url": "https://www.symbolab.com/solver/derivative-calculator"
146
+ },
147
+ {
148
+ "title": "Calculus - How to find the derivative of a function using the power rule",
149
+ "url": "https://www.youtube.com/watch?v=pBc4Udqw330"
150
+ },
151
+ {
152
+ "title": "Finding the Derivative of a Square Root Function Using Definition of ...",
153
+ "url": "https://www.youtube.com/watch?v=3bmw9UQxbRI"
154
+ }
155
+ ],
156
+ "source": "tavily-sdk",
157
+ "feedback_required": true,
158
+ "gateway_trace": [
159
+ "input_guardrails_pass",
160
+ "router.inspect",
161
+ "router->search",
162
+ "search.source=tavily-sdk",
163
+ "output_guardrails_pass"
164
+ ]
165
+ },
166
+ "feedback": {
167
+ "thumbs_up": false,
168
+ "primary_issue": "unclear",
169
+ "has_better_solution": false,
170
+ "solution_type": null,
171
+ "better_solution_text": null,
172
+ "better_solution_pdf_base64": null,
173
+ "better_solution_image_base64": null
174
+ }
175
+ }
176
+ ]
data/knowledge_base.jsonl ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"id": "gsm8k-001", "question": "Solve for x: 2x + 5 = 17.", "answer": "x = 6", "step_by_step": ["Subtract 5 from both sides: 2x = 12", "Divide both sides by 2: x = 6"]}
2
+ {"id": "gsm8k-002", "question": "A rectangle has a length of 12 cm and width 5 cm. What is its perimeter?", "answer": "Perimeter = 34 cm", "step_by_step": ["Perimeter of rectangle = 2*(length + width)", "Compute length + width: 12 + 5 = 17", "Multiply by 2: 2*17 = 34 cm"]}
3
+ {"id": "gsm8k-003", "question": "If 3/4 of a number is 18, what is the number?", "answer": "The number is 24", "step_by_step": ["Let the number be n", "3/4 * n = 18", "Multiply both sides by 4/3: n = 18 * 4/3", "Compute: 18 * 4 = 72", "Divide by 3: 72 / 3 = 24"]}
4
+ {"id": "gsm8k-004", "question": "Simplify: (x^2 - 9)/(x - 3)", "answer": "x + 3", "step_by_step": ["Factor numerator: x^2 - 9 = (x - 3)(x + 3)", "Cancel common factor (x - 3)", "Result: x + 3, for x ≠ 3"]}
5
+ {"id": "gsm8k-005", "question": "Evaluate the derivative d/dx of f(x) = 3x^3 - 4x^2 + x - 7.", "answer": "f'(x) = 9x^2 - 8x + 1", "step_by_step": ["Apply power rule to each term", "Derivative of 3x^3 is 9x^2", "Derivative of -4x^2 is -8x", "Derivative of x is 1", "Derivative of constant -7 is 0", "Combine terms: 9x^2 - 8x + 1"]}
6
+
mcp_servers/tavily_server.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal MCP-compatible stdio server wrapping Tavily search."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import sys
7
+ from typing import Any
8
+
9
+ from tavily import TavilyClient
10
+
11
+
12
+ def _respond(response: dict[str, Any]) -> None:
13
+ sys.stdout.write(json.dumps(response) + "\n")
14
+ sys.stdout.flush()
15
+
16
+
17
def main() -> None:
    """Serve JSON-RPC requests from stdin until a ``shutdown`` request arrives.

    Supported methods: ``initialize`` (handshake), ``shutdown`` (terminates
    the loop), and ``tavily_search`` (proxies to the Tavily SDK). Anything
    else receives a JSON-RPC "method not found" error (-32601).
    """
    client = TavilyClient()

    for raw in sys.stdin:
        raw = raw.strip()
        if not raw:
            continue
        try:
            request = json.loads(raw)
        except json.JSONDecodeError:
            # Malformed frames are silently dropped; the peer gets no reply.
            continue

        method = request.get("method")
        req_id = request.get("id")

        if method == "initialize":
            _respond({
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "serverInfo": {"name": "tavily-mcp", "version": "0.1.0"},
                    "capabilities": {"tools": True},
                },
            })
        elif method == "shutdown":
            _respond({"jsonrpc": "2.0", "id": req_id, "result": None})
            break
        elif method == "tavily_search":
            params = request.get("params", {})
            try:
                result = client.search(
                    query=params.get("query", ""),
                    max_results=params.get("max_results", 5),
                    include_images=params.get("include_images", False),
                )
            except Exception as exc:  # pragma: no cover
                _respond({"jsonrpc": "2.0", "id": req_id, "error": {"code": -32000, "message": str(exc)}})
            else:
                _respond({"jsonrpc": "2.0", "id": req_id, "result": result})
        else:
            _respond({"jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": "Unknown method"}})
60
+
61
+
62
+ if __name__ == "__main__":
63
+ main()
64
+
65
+
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "math-agent-backend"
3
+ version = "0.1.0"
4
+ description = "Agentic RAG math tutoring backend with guardrails, MCP search, and feedback loop"
5
+ authors = ["AI Assistant <assistant@example.com>"]
6
+ packages = [{ include = "app" }]
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.13"
10
+ fastapi = "*"
11
+ uvicorn = { extras = ["standard"], version = "*" }
12
+ pydantic = "*"
13
+ pydantic-settings = "*"
14
+ numpy = "*"
15
+ pandas = "*"
16
+ weaviate-client = "*"
17
+ sentence-transformers = "*"
18
+ scikit-learn = "*"
19
+ httpx = "*"
20
+ aiohttp = "*"
21
+ tavily-python = "*"
22
+ python-dotenv = "*"
23
+ google-generativeai = "*"
24
+ groq = "*"
25
+ # dspy-ai removed from core dependencies to avoid complex transitive
26
+ # dependency conflicts (fastapi/litellm). If you need DSPy features,
27
+ # add it to an optional extras group after verifying compatibility.
28
+ # dspy-ai = { version = "^2.6.0", python = "*" }
29
+ structlog = "*"
30
+ tenacity = "*"
31
+ langgraph = "*"
32
+
33
+ [tool.poetry.group.dev.dependencies]
34
+ pytest = "*"
35
+ pytest-asyncio = "*"
36
+ pytest-cov = "*"
37
+ ruff = "*"
38
+ black = "*"
39
+ mypy = "*"
40
+
41
+ [build-system]
42
+ requires = ["poetry-core>=1.0.0"]
43
+ build-backend = "poetry.core.masonry.api"
44
+
render.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ - type: web
3
+ name: math-agent-backend
4
+ env: python
5
+ region: oregon
6
+ plan: free
7
+ buildCommand: |
8
+ pip install poetry
9
+ poetry install --no-interaction --no-ansi --no-root
10
+ startCommand: |
11
+ poetry run uvicorn app.main:app --host 0.0.0.0 --port $PORT
12
+ autoDeploy: true
13
+ envVars:
14
+ - key: WEAVIATE_URL
15
+ sync: false
16
+ - key: WEAVIATE_API_KEY
17
+ sync: false
18
+ - key: TAVILY_API_KEY
19
+ sync: false
20
+ - key: MCP_TAVILY_URL
21
+ sync: false
22
+ - key: GEMINI_API_KEY
23
+ sync: false
24
+ - key: SIMILARITY_THRESHOLD
25
+ value: "0.80"
26
+ - key: ENFORCE_INPUT_GUARDRAILS
27
+ value: "true"
28
+ - key: ENFORCE_OUTPUT_GUARDRAILS
29
+ value: "true"
30
+
scripts/build_kb.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to build FAISS vector store from knowledge base dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+ import faiss
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sentence_transformers import SentenceTransformer
13
+
14
+ from app.config import settings
15
+
16
+
17
+ def load_dataset(path: Path) -> list[dict]:
18
+ data: list[dict] = []
19
+ for line in path.read_text(encoding="utf-8").splitlines():
20
+ if line.strip():
21
+ data.append(json.loads(line))
22
+ return data
23
+
24
+
25
+ def build_index(records: list[dict]) -> tuple[faiss.Index, np.ndarray]:
26
+ encoder = SentenceTransformer(settings.embedding_model_name)
27
+ corpus = [record["question"] + "\n" + record.get("answer", "") for record in records]
28
+ embeddings = encoder.encode(corpus, show_progress_bar=True)
29
+ embeddings = np.array(embeddings).astype("float32")
30
+ dimension = embeddings.shape[1]
31
+ index = faiss.IndexFlatIP(dimension)
32
+ faiss.normalize_L2(embeddings)
33
+ index.add(embeddings)
34
+ return index, embeddings
35
+
36
+
37
+ def main(dataset_path: Path, output_index: Path, output_metadata: Path) -> None:
38
+ records = load_dataset(dataset_path)
39
+ index, _ = build_index(records)
40
+
41
+ output_index.parent.mkdir(parents=True, exist_ok=True)
42
+ output_metadata.parent.mkdir(parents=True, exist_ok=True)
43
+
44
+ faiss.write_index(index, str(output_index))
45
+
46
+ metadata = pd.DataFrame(records)
47
+ metadata.to_parquet(output_metadata, index=False)
48
+
49
+ print(f"Vector store written to {output_index}")
50
+ print(f"Metadata written to {output_metadata}")
51
+
52
+
53
+ if __name__ == "__main__":
54
+ parser = argparse.ArgumentParser(description="Build FAISS vector store for math agent")
55
+ parser.add_argument("--dataset", type=Path, default=Path("backend/data/knowledge_base.jsonl"))
56
+ parser.add_argument("--index", type=Path, default=Path(settings.vector_store_path))
57
+ parser.add_argument("--metadata", type=Path, default=Path(settings.vector_store_metadata_path))
58
+ args = parser.parse_args()
59
+
60
+ main(args.dataset, args.index, args.metadata)
61
+
62
+
scripts/build_weaviate_kb.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to build Weaviate vector store from knowledge base dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+ from typing import List, Dict, Any
9
+
10
+ import numpy as np
11
+ import weaviate
12
+ from sentence_transformers import SentenceTransformer
13
+ from tqdm import tqdm
14
+
15
+ from app.config import settings
16
+
17
+
18
+ def load_dataset(path: Path) -> List[Dict[str, Any]]:
19
+ data: List[Dict[str, Any]] = []
20
+ for line in path.read_text(encoding="utf-8").splitlines():
21
+ if line.strip():
22
+ data.append(json.loads(line))
23
+ return data
24
+
25
+
26
+ def get_weaviate_client() -> weaviate.Client:
27
+ """Initialize Weaviate client with authentication."""
28
+ auth_config = weaviate.auth.AuthApiKey(api_key=settings.weaviate_api_key)
29
+ client = weaviate.Client(
30
+ url=settings.weaviate_url,
31
+ auth_client_secret=auth_config,
32
+ )
33
+ return client
34
+
35
+
36
+ def setup_schema(client: weaviate.Client) -> None:
37
+ """Create the schema if it doesn't exist."""
38
+ if not client.schema.exists(settings.weaviate_class_name):
39
+ class_obj = {
40
+ "class": settings.weaviate_class_name,
41
+ "vectorizer": "none", # We'll provide our own vectors
42
+ "properties": [
43
+ {"name": "question", "dataType": ["text"]},
44
+ {"name": "answer", "dataType": ["text"]},
45
+ {"name": "source", "dataType": ["text"]},
46
+ ]
47
+ }
48
+ client.schema.create_class(class_obj)
49
+
50
+
51
+ def build_index(client: weaviate.Client, records: List[Dict[str, Any]], batch_size: int = 100) -> None:
52
+ """Build the vector index by importing records in batches."""
53
+ encoder = SentenceTransformer(settings.embedding_model_name)
54
+
55
+ with client.batch as batch:
56
+ batch.batch_size = batch_size
57
+
58
+ for record in tqdm(records, desc="Importing records"):
59
+ # Generate embedding from question and answer
60
+ text = record["question"] + "\n" + record.get("answer", "")
61
+ embedding = encoder.encode(text)
62
+
63
+ # Normalize the embedding
64
+ embedding = embedding / np.linalg.norm(embedding)
65
+
66
+ # Prepare properties
67
+ properties = {
68
+ "question": record["question"],
69
+ "answer": record.get("answer", ""),
70
+ "source": record.get("source", "")
71
+ }
72
+
73
+ # Import the object with its vector
74
+ batch.add_data_object(
75
+ data_object=properties,
76
+ class_name=settings.weaviate_class_name,
77
+ vector=embedding.tolist()
78
+ )
79
+
80
+
81
+ def main(dataset_path: Path) -> None:
82
+ print(f"Loading dataset from {dataset_path}")
83
+ records = load_dataset(dataset_path)
84
+
85
+ print("Initializing Weaviate client")
86
+ client = get_weaviate_client()
87
+
88
+ print("Setting up schema")
89
+ setup_schema(client)
90
+
91
+ print("Building index")
92
+ build_index(client, records)
93
+
94
+ print(f"Successfully imported {len(records)} records to Weaviate")
95
+
96
+
97
+ if __name__ == "__main__":
98
+ parser = argparse.ArgumentParser(description="Build Weaviate vector store for math agent")
99
+ parser.add_argument("--dataset", type=Path, default=Path("backend/data/knowledge_base.jsonl"))
100
+ args = parser.parse_args()
101
+
102
+ main(args.dataset)