Spaces:
Build error
Build error
Commit ·
886572e
1
Parent(s): 84e9d55
Add backend FastAPI code
Browse files- Dockerfile +18 -0
- app/__init__.py +6 -0
- app/__pycache__/__init__.cpython-313.pyc +0 -0
- app/__pycache__/config.cpython-313.pyc +0 -0
- app/__pycache__/guardrails.cpython-313.pyc +0 -0
- app/__pycache__/logger.cpython-313.pyc +0 -0
- app/__pycache__/main.cpython-313.pyc +0 -0
- app/__pycache__/schemas.cpython-313.pyc +0 -0
- app/config.py +104 -0
- app/guardrails.py +100 -0
- app/logger.py +37 -0
- app/main.py +156 -0
- app/schemas.py +117 -0
- app/services/__pycache__/kb_updater.cpython-313.pyc +0 -0
- app/services/__pycache__/retrieval.cpython-313.pyc +0 -0
- app/services/__pycache__/vector_store.cpython-313.pyc +0 -0
- app/services/kb_updater.py +56 -0
- app/services/retrieval.py +109 -0
- app/services/vector_store.py +348 -0
- app/tools/__init__.py +2 -0
- app/tools/__pycache__/__init__.cpython-313.pyc +0 -0
- app/tools/__pycache__/audio.cpython-313.pyc +0 -0
- app/tools/__pycache__/dspy_pipeline.cpython-313.pyc +0 -0
- app/tools/__pycache__/validator.cpython-313.pyc +0 -0
- app/tools/__pycache__/vision.cpython-313.pyc +0 -0
- app/tools/__pycache__/web_search.cpython-313.pyc +0 -0
- app/tools/audio.py +41 -0
- app/tools/dspy_pipeline.py +476 -0
- app/tools/validator.py +28 -0
- app/tools/vision.py +33 -0
- app/tools/web_search.py +395 -0
- app/workflows/__init__.py +3 -0
- app/workflows/__pycache__/__init__.cpython-313.pyc +0 -0
- app/workflows/__pycache__/langgraph_pipeline.cpython-313.pyc +0 -0
- app/workflows/langgraph_pipeline.py +251 -0
- backend/data/feedback_db.json +176 -0
- data/knowledge_base.jsonl +6 -0
- mcp_servers/tavily_server.py +65 -0
- poetry.lock +0 -0
- pyproject.toml +44 -0
- render.yaml +30 -0
- scripts/build_kb.py +62 -0
- scripts/build_weaviate_kb.py +102 -0
Dockerfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build image for the FastAPI math-agent backend.
FROM python:3.13.0

WORKDIR /code

# Install Poetry
RUN pip install poetry

# Copy dependency manifests first so Docker layer caching keeps installs fast
COPY pyproject.toml poetry.lock ./

# Install project dependencies.
# FIX: --no-root is required here — at this point only the manifests have been
# copied, so installing the project package itself would fail the build.
RUN poetry install --no-interaction --no-ansi --no-root

COPY . .

EXPOSE 7860

# FIX: the ASGI app object lives in app/main.py (package "app"), not in a
# top-level main.py — "main:app" fails with ModuleNotFoundError at startup.
CMD ["poetry", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Math Agent FastAPI application package."""

# Re-export the application factory so callers can write
# `from app import create_app` without knowing the module layout.
# NOTE(review): importing this package executes app.main, which builds the
# FastAPI app at import time — confirm that is intended.
from .main import create_app

__all__ = ["create_app"]
app/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (259 Bytes). View file
|
|
|
app/__pycache__/config.cpython-313.pyc
ADDED
|
Binary file (3.62 kB). View file
|
|
|
app/__pycache__/guardrails.cpython-313.pyc
ADDED
|
Binary file (4.77 kB). View file
|
|
|
app/__pycache__/logger.cpython-313.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
app/__pycache__/main.cpython-313.pyc
ADDED
|
Binary file (7.95 kB). View file
|
|
|
app/__pycache__/schemas.cpython-313.pyc
ADDED
|
Binary file (4.86 kB). View file
|
|
|
app/config.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Application configuration using Pydantic settings."""

from functools import lru_cache
import os
from typing import List
from pathlib import Path

from pydantic_settings import BaseSettings, SettingsConfigDict

# Resolve the .env path robustly so configuration works whether the file
# lives at the repository root or under backend/.  This module sits at
# backend/app/config.py, so the repo root is two directory levels up.
_REPO_ROOT = Path(__file__).resolve().parents[2]
_BACKEND_DIR = _REPO_ROOT / "backend"
_ENV_CANDIDATES = [
    _REPO_ROOT / ".env",    # <repo>/.env
    _BACKEND_DIR / ".env",  # <repo>/backend/.env
    Path.cwd() / ".env",    # fallback: current working directory
]

# First existing candidate wins; default to the repo-root path so
# pydantic-settings still has a well-defined (if absent) file to consult.
_ENV_FILE = _REPO_ROOT / ".env"
for _candidate in _ENV_CANDIDATES:
    if _candidate.exists():
        _ENV_FILE = _candidate
        break
class Settings(BaseSettings):
    """Runtime configuration loaded from environment or .env file."""

    # extra="allow" tolerates unrelated variables in the .env file instead of
    # raising validation errors; env_file is the path resolved at import time.
    model_config = SettingsConfigDict(env_file=str(_ENV_FILE), env_file_encoding="utf-8", extra="allow")

    app_name: str = "Math Agentic RAG"
    environment: str = "local"
    debug: bool = True

    # Vector store configuration (Weaviate)
    weaviate_url: str | None = None
    weaviate_api_key: str | None = None
    # NOTE(review): other modules must read this field as `weaviate_collection`;
    # kb_updater currently references `weaviate_class_name`, which is not
    # defined here — confirm and align the attribute name.
    weaviate_collection: str = "mathvectors"  # Collection name in Weaviate cloud
    embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    top_k: int = 4
    # Minimum KB similarity for a hit; overridable via env aliases in get_settings().
    similarity_threshold: float = 0.80

    # Guardrail keywords for allow-listing math topics.
    # (Pydantic copies list defaults per instance, so mutable defaults are safe.)
    allowed_subjects: List[str] = [
        "algebra",
        "geometry",
        "calculus",
        "probability",
        "statistics",
        "number theory",
        "trigonometry",
        "combinatorics",
        "math",
    ]

    # Keywords that cause input/output to be rejected outright by guardrails.
    blocked_keywords: List[str] = [
        "violence",
        "weapon",
        "politics",
        "hate",
        "self-harm",
        "explicit",
    ]

    # External service credentials
    tavily_api_key: str | None = None
    mcp_tavily_url: str | None = None  # e.g., https://mcp.tavily.com/mcp/?tavilyApiKey=tvly-...
    gemini_api_key: str | None = None
    groq_api_key: str | None = None

    # Model configuration
    gemini_model: str = "gemini-2.5-flash"  # Fast model for search grounding and vision
    dspy_model: str = "gemini-2.5-flash"  # For DSPy pipeline
    whisper_model: str = "whisper-large-v3-turbo"  # For audio transcription
    dspy_max_tokens: int = 2048

    # Feedback storage (path is relative to the process working directory)
    feedback_store_path: str = "backend/data/feedback_db.json"

    # Gateway / guardrails toggles
    enforce_input_guardrails: bool = True
    enforce_output_guardrails: bool = True
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return a cached Settings instance.

    Honours alternate env-var spellings for the KB similarity threshold,
    checked in priority order: KB_SIMILARITY_THRESHOLD, then KB_THRESHOLD,
    then the value already loaded on the Settings instance.
    """
    cfg = Settings()

    raw_threshold = os.getenv("KB_SIMILARITY_THRESHOLD") or os.getenv("KB_THRESHOLD")
    if raw_threshold:
        try:
            cfg.similarity_threshold = float(raw_threshold)
        except ValueError:
            # Malformed override: keep the configured value.
            pass

    return cfg


settings = get_settings()
app/guardrails.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Guardrail functions for input/output moderation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
|
| 8 |
+
from .config import settings
|
| 9 |
+
from .logger import get_logger
|
| 10 |
+
|
| 11 |
+
logger = get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
class GuardrailViolation(Exception):
    """Raised when a guardrail condition is violated."""

    def __init__(self, message: str, code: str = "guardrail_violation") -> None:
        # Keep the message on the instance so HTTP handlers can surface it
        # without parsing str(exc); `code` is a machine-readable category.
        super().__init__(message)
        self.message = message
        self.code = code
| 21 |
+
|
| 22 |
+
|
# Strips every character outside printable ASCII (0x20-0x7E), including
# control characters, tabs and newlines.
# NOTE(review): this also removes math symbols such as pi, superscripts and
# the multiplication sign — confirm that is intended for a math agent.
SANITIZE_REGEX = re.compile(r"[^\x20-\x7E]+")
# Matches SSN-like digit groups, bare 16-digit card numbers, or
# PAN-style IDs (five letters, four digits, one letter).
PII_REGEX = re.compile(r"\b(\d{3}-?\d{2}-?\d{4}|\d{16}|[A-Z]{5}[0-9]{4}[A-Z]{1})\b")


def sanitize_text(text: str) -> str:
    """Remove non-ascii characters and collapse whitespace."""
    ascii_only = SANITIZE_REGEX.sub("", text)
    return re.sub(r"\s+", " ", ascii_only).strip()
| 33 |
+
|
| 34 |
+
|
def run_input_guardrails(user_text: str) -> str:
    """Validate and sanitize user input before routing.

    Deliberately permissive: Gemini later decides whether the question is
    actually mathematical, so this layer only screens for PII and an explicit
    keyword block-list. Returns the sanitized text.
    """
    logger.debug("guardrails.input.start", text=user_text)

    cleaned = sanitize_text(user_text)

    # PII is a hard stop — never forward sensitive identifiers downstream.
    if PII_REGEX.search(cleaned):
        logger.warning("guardrails.input.blocked", reason="pii_match")
        raise GuardrailViolation("The request may contain sensitive information. Please remove it and try again.")

    lowered = cleaned.lower()

    # Block only obviously non-educational content; Gemini is the final
    # arbiter of whether the question is math-related.
    for keyword in settings.blocked_keywords:
        if keyword in lowered:
            logger.warning("guardrails.input.blocked", reason="blocked_keyword")
            raise GuardrailViolation("I can only help with mathematics-related educational questions.")

    # Everything else passes through to the downstream validator.
    logger.debug("guardrails.input.pass")
    return cleaned
| 64 |
+
|
| 65 |
+
|
def run_output_guardrails(response_text: str) -> Tuple[str, list[str]]:
    """Apply output guardrails - very permissive since Gemini validates math content.

    Redacts PII in place, extracts any URLs as citations, and rejects
    responses containing block-listed keywords. Returns the (possibly
    redacted) text together with the list of citation URLs.
    """
    logger.debug("guardrails.output.start")
    cleaned = sanitize_text(response_text)

    citations: list[str] = []

    # Redact rather than reject: the rest of the answer may still be useful.
    if PII_REGEX.search(cleaned):
        logger.warning("guardrails.output.flag", reason="pii_detected")
        cleaned = re.sub(PII_REGEX, "[redacted]", cleaned)

    # Collect URLs so the frontend can render them as sources.
    if "http" in cleaned.lower():
        citations = re.findall(r"(https?://\S+)", cleaned)

    # Last-resort content filter; Gemini generated the text, so anything
    # beyond the explicit block-list is allowed through.
    lowered = cleaned.lower()
    if any(word in lowered for word in settings.blocked_keywords):
        logger.warning("guardrails.output.blocked", reason="blocked_keyword", preview=cleaned[:100])
        raise GuardrailViolation("The generated response contains inappropriate content.")

    logger.debug("guardrails.output.pass", citations=len(citations))
    return cleaned, citations
| 99 |
+
|
| 100 |
+
|
app/logger.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Structured logging helpers."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import structlog
|
| 7 |
+
|
| 8 |
+
|
def configure_logging(level: int = logging.INFO) -> None:
    """Configure structlog for the application.

    Uses a human-friendly console renderer when running at DEBUG level and
    JSON lines otherwise, and routes stdlib logging through the same level.
    """
    # ISO-8601 timestamps on every event.
    timestamper = structlog.processors.TimeStamper(fmt="iso")

    renderer = (
        structlog.dev.ConsoleRenderer()
        if level <= logging.DEBUG
        else structlog.processors.JSONRenderer()
    )

    pipeline = [
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        timestamper,
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        renderer,
    ]

    structlog.configure(
        processors=pipeline,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )

    # Bare-message format: structlog's renderer already produced the line.
    logging.basicConfig(level=level, format="%(message)s")
| 30 |
+
|
| 31 |
+
|
def get_logger(*args: Any, **kwargs: Any) -> structlog.stdlib.BoundLogger:
    """Return a configured structlog logger.

    Thin passthrough to ``structlog.get_logger`` so call sites depend on
    this module rather than on structlog directly.
    """

    return structlog.get_logger(*args, **kwargs)
| 36 |
+
|
| 37 |
+
|
app/main.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""FastAPI entrypoint for the math agent backend."""

from __future__ import annotations

import asyncio
import os
from typing import AsyncIterator

from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from .config import settings
from .guardrails import GuardrailViolation
from .logger import configure_logging, get_logger
from .schemas import AgentResponse, ChatRequest, FeedbackRequest
from .services.retrieval import MathAgent
from .services.vector_store import (
    load_vector_store,
    save_feedback_to_queue,
)
from .services.kb_updater import update_knowledge_base
from .tools.audio import transcribe_audio
from .tools.validator import validate_user_solution
from .tools.vision import extract_text_from_image

logger = get_logger(__name__)


def create_app() -> FastAPI:
    """Build and configure the FastAPI application (CORS, lifespan, routes)."""
    configure_logging()

    app = FastAPI(title=settings.app_name, version="0.1.0")

    # Configure CORS. Accept a single FRONTEND_URL or a comma-separated
    # FRONTEND_URLS env var. Default to common dev ports (5173 and 3000).
    frontend_env = os.getenv("FRONTEND_URL", os.getenv("FRONTEND_URLS", "http://localhost:5173,http://localhost:3000"))
    # split and strip any whitespace
    allowed_origins = [o.strip() for o in frontend_env.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[None]:
        # Startup: load the vector store off the event loop (the load is blocking).
        logger.info("app.startup")
        app.state.vector_store = await asyncio.to_thread(load_vector_store)

        yield  # Server is running

        # Shutdown
        logger.info("app.shutdown")
        # Clean up any resources if needed

    # NOTE(review): assigning router.lifespan_context after construction is a
    # non-public hook; prefer passing FastAPI(lifespan=lifespan) — confirm
    # before changing.
    app.router.lifespan_context = lifespan

    def get_agent() -> MathAgent:
        # Dependency provider: reuse the store loaded at startup, or load it
        # lazily if the lifespan hook has not run (e.g. during tests).
        vector_store = getattr(app.state, "vector_store", None)
        if vector_store is None:
            vector_store = load_vector_store()
            app.state.vector_store = vector_store
        return MathAgent(vector_store=vector_store)

    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):  # type: ignore[override]
        # Tag every response with the running environment for debugging.
        response = await call_next(request)
        response.headers["X-App-Env"] = settings.environment
        return response

    @app.get("/health")
    async def health() -> dict[str, str]:
        # Liveness probe used by the hosting platform.
        return {"status": "ok", "environment": settings.environment}

    @app.post("/api/chat", response_model=AgentResponse, responses={400: {"description": "Guardrail failure"}})
    async def chat_endpoint(payload: ChatRequest, agent: MathAgent = Depends(get_agent)) -> AgentResponse:
        """Answer a math question; audio/image payloads are converted to text first."""
        logger.info("chat.request", modality=payload.modality)

        query = payload.query
        if payload.modality == "audio" and payload.audio_base64:
            query = transcribe_audio(payload.audio_base64)
        elif payload.modality == "image" and payload.image_base64:
            query = extract_text_from_image(payload.image_base64)

        try:
            response = await agent.handle_query(query)
            return response
        except GuardrailViolation as exc:
            # Guardrail refusals surface as 400s carrying the violation message.
            raise HTTPException(status_code=400, detail=exc.message) from exc
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.exception("chat.error", error=str(exc))
            raise HTTPException(status_code=500, detail="Unexpected error handling query") from exc

    @app.post("/api/feedback")
    async def feedback_endpoint(request: FeedbackRequest) -> JSONResponse:
        """Handle user feedback with optional solution upload."""
        logger.info(
            "feedback.received",
            message_id=request.message_id,
            helpful=request.feedback.thumbs_up,
            issue=request.feedback.primary_issue
        )

        # Always save the feedback first
        record = request.model_dump()
        save_feedback_to_queue(record)

        # If it's negative feedback with a solution
        if not request.feedback.thumbs_up and request.feedback.has_better_solution:
            solution = None

            # Get solution based on type
            if request.feedback.solution_type == "text":
                solution = request.feedback.better_solution_text
            elif request.feedback.solution_type == "pdf":
                # TODO: Extract text from PDF
                solution = request.feedback.better_solution_text
            elif request.feedback.solution_type == "image":
                # Use vision model to extract solution from image
                if request.feedback.better_solution_image_base64:
                    solution = extract_text_from_image(
                        request.feedback.better_solution_image_base64,
                        "Extract the mathematical solution from this image."
                    )

            if solution:
                # Validate and update KB
                success = await update_knowledge_base(request.query, solution)
                return JSONResponse({
                    "status": "ok",
                    "feedback_saved": True,
                    "kb_updated": success
                })

        return JSONResponse({
            "status": "ok",
            "feedback_saved": True
        })

    @app.get("/api/vector-store/reload")
    async def reload_vector_store() -> dict[str, str]:
        # NOTE(review): this mutates server state; consider POST semantics.
        app.state.vector_store = await asyncio.to_thread(load_vector_store, True)
        return {"status": "reloaded"}

    return app


app = create_app()
app/schemas.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Pydantic schemas for API endpoints."""

from __future__ import annotations

from typing import Any, List, Optional

from pydantic import BaseModel, Field


class Step(BaseModel):
    """A solution step with title, content and optional LaTeX expression."""

    title: str
    content: str
    expression: Optional[str] = None  # LaTeX, rendered by the frontend when present


class Citation(BaseModel):
    """Represents a web citation."""

    title: str
    url: str


class RetrievalContext(BaseModel):
    """Context snippet from knowledge base."""

    document_id: str
    question: str
    answer: str
    similarity: float  # similarity score reported by the vector store


class AgentResponse(BaseModel):
    """Structured agent response."""

    answer: str
    steps: List[Step]
    retrieved_from_kb: bool = False
    knowledge_hits: List[RetrievalContext] = Field(default_factory=list)
    citations: List[Citation] = Field(default_factory=list)
    source: str = Field(default="kb")  # kb | tavily | gemini
    feedback_required: bool = True
    # Ordered trace of gateway/guardrail stages the request passed through.
    gateway_trace: List[str] = Field(default_factory=list)


class ChatRequest(BaseModel):
    """Incoming chat message payload."""

    query: str
    modality: str = Field(default="text", description="text|image|audio")
    image_base64: Optional[str] = None  # required when modality == "image"
    audio_base64: Optional[str] = None  # required when modality == "audio"


class FeedbackMetadata(BaseModel):
    """User feedback payload."""

    # Minimal - Always Ask (5 seconds)
    thumbs_up: bool = Field(description="👍 Helpful / 👎 Not helpful")

    # If thumbs_down, expand to next level (10 seconds)
    # NOTE(review): `enum=` is not a validating Pydantic Field argument — it is
    # only forwarded into the JSON schema; use Literal[...] to enforce values.
    primary_issue: Optional[str] = Field(
        None,
        description="What went wrong?",
        # These match exactly with the frontend enums
        enum=["wrong-answer", "unclear", "missing-steps", "wrong-method"]
    )

    # Optional user solution upload (30-60 seconds)
    has_better_solution: bool = False
    solution_type: Optional[str] = Field(
        None,
        description="Type of solution upload",
        enum=["text", "pdf", "image"]
    )
    better_solution_text: Optional[str] = None
    better_solution_pdf_base64: Optional[str] = None
    better_solution_image_base64: Optional[str] = None


class FeedbackRequest(BaseModel):
    """Feedback submission request."""

    message_id: str
    query: str
    agent_response: AgentResponse
    feedback: FeedbackMetadata


class BenchmarkResult(BaseModel):
    """Result for a single benchmark item."""

    question_id: str
    reference_answer: str
    agent_answer: str
    score: float
    source: str


class BenchmarkSummary(BaseModel):
    """Aggregate benchmark output."""

    dataset: str
    total_questions: int
    average_score: float
    details: List[BenchmarkResult]


class ErrorResponse(BaseModel):
    """Error response envelope."""

    detail: str
    code: str = Field(default="error")
    context: Optional[dict[str, Any]] = None  # optional structured debug info
app/services/__pycache__/kb_updater.cpython-313.pyc
ADDED
|
Binary file (2.16 kB). View file
|
|
|
app/services/__pycache__/retrieval.cpython-313.pyc
ADDED
|
Binary file (4.97 kB). View file
|
|
|
app/services/__pycache__/vector_store.cpython-313.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
app/services/kb_updater.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Knowledge base update service for validated user solutions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
|
| 10 |
+
from ..config import settings
|
| 11 |
+
from ..logger import get_logger
|
| 12 |
+
from ..tools.validator import validate_user_solution
|
| 13 |
+
from .vector_store import load_vector_store
|
| 14 |
+
|
| 15 |
+
logger = get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
async def update_knowledge_base(question: str, solution: str, source: str = "user-feedback") -> bool:
|
| 19 |
+
"""Validate and add a new solution to the knowledge base."""
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
# First validate the solution
|
| 23 |
+
is_valid = await asyncio.to_thread(validate_user_solution, question, solution)
|
| 24 |
+
|
| 25 |
+
if not is_valid:
|
| 26 |
+
logger.warning("kb_update.validation_failed", question=question)
|
| 27 |
+
return False
|
| 28 |
+
|
| 29 |
+
# Load vector store
|
| 30 |
+
vector_store = load_vector_store()
|
| 31 |
+
encoder = SentenceTransformer(settings.embedding_model_name)
|
| 32 |
+
|
| 33 |
+
# Generate embedding
|
| 34 |
+
text = question + "\n" + solution
|
| 35 |
+
embedding = encoder.encode(text)
|
| 36 |
+
|
| 37 |
+
# Add to Weaviate
|
| 38 |
+
data_object = {
|
| 39 |
+
"question": question,
|
| 40 |
+
"answer": solution,
|
| 41 |
+
"source": source
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
# Add with vector
|
| 45 |
+
vector_store.client.data_object.create(
|
| 46 |
+
data_object=data_object,
|
| 47 |
+
class_name=settings.weaviate_class_name,
|
| 48 |
+
vector=embedding.tolist()
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
logger.info("kb_update.success", question=question, source=source)
|
| 52 |
+
return True
|
| 53 |
+
|
| 54 |
+
except Exception as exc:
|
| 55 |
+
logger.error("kb_update.failed", error=str(exc), question=question)
|
| 56 |
+
return False
|
app/services/retrieval.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Retrieval and routing logic for the math agent."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from ..guardrails import GuardrailViolation
|
| 6 |
+
from ..logger import get_logger
|
| 7 |
+
from ..schemas import AgentResponse, Step
|
| 8 |
+
from ..workflows.langgraph_pipeline import build_math_agent_graph
|
| 9 |
+
from .vector_store import VectorStore, load_vector_store
|
| 10 |
+
|
| 11 |
+
logger = get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MathAgent:
|
| 15 |
+
"""High-level orchestration for handling a math query."""
|
| 16 |
+
|
    def __init__(self, vector_store: VectorStore | None = None) -> None:
        # Reuse an injected store (shared app state) or lazily load one.
        self.vector_store = vector_store or load_vector_store()
        # Compiled LangGraph pipeline bound to this vector store.
        self.graph = build_math_agent_graph(self.vector_store)
| 20 |
+
|
| 21 |
+
async def handle_query(self, query: str) -> AgentResponse:
|
| 22 |
+
"""Run the query through guardrails, routing, and generation."""
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
result_state = await self.graph.ainvoke({"query": query, "gateway_trace": []})
|
| 26 |
+
except GuardrailViolation:
|
| 27 |
+
raise
|
| 28 |
+
except Exception as exc: # pragma: no cover - defensive logging
|
| 29 |
+
logger.exception("math_agent.graph_failure", error=str(exc))
|
| 30 |
+
raise
|
| 31 |
+
|
| 32 |
+
steps_raw = result_state.get("steps", []) or []
|
| 33 |
+
|
| 34 |
+
# Get KB hits and web hits separately, then combine for display
|
| 35 |
+
kb_hits = result_state.get("kb_hits", []) or []
|
| 36 |
+
web_hits = result_state.get("web_hits", []) or []
|
| 37 |
+
# Combine for knowledge_hits display (show both KB and web)
|
| 38 |
+
knowledge_hits = list(kb_hits) + list(web_hits)
|
| 39 |
+
# If no separate hits, fall back to combined knowledge_hits
|
| 40 |
+
if not knowledge_hits:
|
| 41 |
+
knowledge_hits = result_state.get("knowledge_hits", []) or []
|
| 42 |
+
|
| 43 |
+
citations = result_state.get("citations", []) or []
|
| 44 |
+
answer_text = result_state.get("answer", "") or ""
|
| 45 |
+
|
| 46 |
+
# If answer is too short (just overview), try to enhance it with steps
|
| 47 |
+
if len(answer_text.strip()) < 50 and steps_raw:
|
| 48 |
+
# Answer might be incomplete - use steps to build a better answer
|
| 49 |
+
step_texts = []
|
| 50 |
+
for step in steps_raw:
|
| 51 |
+
if isinstance(step, dict):
|
| 52 |
+
step_texts.append(step.get("content", "") or step.get("explanation", ""))
|
| 53 |
+
elif isinstance(step, (list, tuple)) and len(step) > 1:
|
| 54 |
+
step_texts.append(step[1])
|
| 55 |
+
|
| 56 |
+
if step_texts:
|
| 57 |
+
# Find the last step that looks like a conclusion
|
| 58 |
+
for step_text in reversed(step_texts):
|
| 59 |
+
if any(word in step_text.lower() for word in ["therefore", "answer", "result", "thus", "hence"]):
|
| 60 |
+
answer_text = step_text
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
# If still no good answer, use the last step
|
| 64 |
+
if len(answer_text.strip()) < 50:
|
| 65 |
+
answer_text = step_texts[-1] if step_texts else answer_text
|
| 66 |
+
|
| 67 |
+
retrieved_from_kb = result_state.get("retrieved_from_kb", False)
|
| 68 |
+
source = result_state.get("source", "kb+web" if (kb_hits and web_hits) else ("kb" if retrieved_from_kb else "gemini"))
|
| 69 |
+
gateway_trace = result_state.get("gateway_trace", [])
|
| 70 |
+
|
| 71 |
+
# Convert steps to new format with content and optional expressions
|
| 72 |
+
formatted_steps = []
|
| 73 |
+
for step in steps_raw:
|
| 74 |
+
# steps_raw is now a list of dicts with title/content/expression
|
| 75 |
+
if isinstance(step, dict):
|
| 76 |
+
formatted_steps.append(Step(
|
| 77 |
+
title=step.get("title", ""),
|
| 78 |
+
content=step.get("content", ""),
|
| 79 |
+
expression=step.get("expression")
|
| 80 |
+
))
|
| 81 |
+
else:
|
| 82 |
+
# Handle legacy tuple format for backward compatibility
|
| 83 |
+
title, explanation = step
|
| 84 |
+
formatted_steps.append(Step(
|
| 85 |
+
title=title,
|
| 86 |
+
content=explanation
|
| 87 |
+
))
|
| 88 |
+
|
| 89 |
+
response = AgentResponse(
|
| 90 |
+
answer=answer_text,
|
| 91 |
+
steps=formatted_steps,
|
| 92 |
+
retrieved_from_kb=retrieved_from_kb,
|
| 93 |
+
knowledge_hits=knowledge_hits,
|
| 94 |
+
citations=citations,
|
| 95 |
+
source=source,
|
| 96 |
+
gateway_trace=gateway_trace,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
logger.info(
|
| 100 |
+
"math_agent.completed",
|
| 101 |
+
source=response.source,
|
| 102 |
+
retrieved=retrieved_from_kb,
|
| 103 |
+
kb_hits=len(knowledge_hits),
|
| 104 |
+
trace=gateway_trace,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
return response
|
| 108 |
+
|
| 109 |
+
|
app/services/vector_store.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Vector store utilities using Weaviate cloud for knowledge base retrieval."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import weaviate
|
| 11 |
+
from sentence_transformers import SentenceTransformer
|
| 12 |
+
from weaviate.collections.classes.config import DataType
|
| 13 |
+
|
| 14 |
+
from ..config import settings
|
| 15 |
+
from ..logger import get_logger
|
| 16 |
+
|
| 17 |
+
logger = get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VectorStore:
    """Wrapper around Weaviate vector store.

    Owns a connected Weaviate v4 client plus a SentenceTransformer encoder,
    ensures the configured collection exists, and adapts between two KB
    schemas: the new one (question/answer/source) and a legacy one
    (input/label/source_file).
    """

    def __init__(self, client: weaviate.WeaviateClient, encoder: SentenceTransformer) -> None:
        # client: an already-connected Weaviate v4 client.
        # encoder: model used for both indexing and query embeddings.
        self.client = client
        self.encoder = encoder
        self.collection = self._ensure_collection()
        # Logical name -> actual Weaviate property name (schema adapter).
        self._property_map = self._detect_property_names()

    def _ensure_collection(self) -> weaviate.collections.Collection:
        """Create collection if it doesn't exist.

        Returns the collection handle, matching the configured name
        case-insensitively so a pre-existing class with different casing
        (e.g. 'Mathvectors') is reused instead of triggering a 422.
        """
        collection_name = settings.weaviate_collection
        collections = self.client.collections.list_all()

        # Check for exact or case-insensitive match to avoid 422 when class
        # already exists with different casing (e.g. 'Mathvectors').
        lower_names = [n.lower() for n in collections]
        if collection_name not in collections and collection_name.lower() not in lower_names:
            try:
                # Vectors are supplied by us (Vectorizer.none) and indexed with HNSW.
                collection = self.client.collections.create(
                    name=collection_name,
                    vectorizer_config=weaviate.classes.config.Configure.Vectorizer.none(),
                    vector_index_config=weaviate.classes.config.Configure.VectorIndex.hnsw(),
                    properties=[
                        {"name": "question", "data_type": DataType.TEXT},
                        {"name": "answer", "data_type": DataType.TEXT},
                        {"name": "source", "data_type": DataType.TEXT},
                    ]
                )
                logger.info("vector_store.collection.created", name=collection_name)
            except weaviate.exceptions.UnexpectedStatusCodeError as e:
                # If the create failed because the class already exists (422),
                # recover by locating the existing class name (case-insensitive)
                msg = str(e)
                if "already exists" in msg or "class name" in msg:
                    logger.warning("vector_store.collection.create_conflict", msg=msg)
                    # refresh list and try to find the existing class name
                    collections = self.client.collections.list_all()
                    existing = None
                    for n in collections:
                        if n.lower() == collection_name.lower():
                            existing = n
                            break
                    if existing:
                        collection = self.client.collections.get(existing)
                        logger.info("vector_store.collection.exists", name=existing)
                    else:
                        # re-raise if we can't resolve the conflict
                        raise
                else:
                    raise
        else:
            # exact or case-insensitive match found — pick the actual existing name
            if collection_name in collections:
                existing_name = collection_name
            else:
                existing_name = next(n for n in collections if n.lower() == collection_name.lower())
            collection = self.client.collections.get(existing_name)
            logger.info("vector_store.collection.exists", name=existing_name)

        return collection

    def _detect_property_names(self) -> Dict[str, str]:
        """Detect the actual property names in the collection schema.

        Returns a mapping from logical names (question, answer, source) to actual property names.
        Falls back to the legacy schema (input/label/source_file) whenever
        detection fails or a logical name cannot be matched.
        """
        property_map = {}
        prop_names = []

        try:
            # Get the collection configuration to see what properties exist
            config = self.collection.config.get()

            # Try different ways to access properties depending on Weaviate v4 structure
            # NOTE(review): property access shape varies across client versions — the
            # branches below are defensive probes, not a documented API contract.
            if hasattr(config, 'properties'):
                props = config.properties
                if props:
                    # Properties might be a list or iterable
                    if isinstance(props, (list, tuple)):
                        prop_names = [prop.name if hasattr(prop, 'name') else str(prop) for prop in props]
                    elif hasattr(props, '__iter__') and not isinstance(props, (str, bytes)):
                        prop_names = [prop.name if hasattr(prop, 'name') else str(prop) for prop in props]

            # Also try accessing as attributes
            if not prop_names and hasattr(config, 'properties'):
                try:
                    props_dict = config.properties.__dict__ if hasattr(config.properties, '__dict__') else {}
                    prop_names = list(props_dict.keys()) if props_dict else []
                except:
                    pass

        except Exception as exc:
            logger.warning("vector_store.schema.detection_error", error=str(exc))

        # Try to map to our expected names
        # Check for new schema (question, answer, source) first
        if prop_names:
            if "question" in prop_names:
                property_map["question"] = "question"
            elif "input" in prop_names:
                property_map["question"] = "input"

            if "answer" in prop_names:
                property_map["answer"] = "answer"
            elif "label" in prop_names:
                property_map["answer"] = "label"

            if "source" in prop_names:
                property_map["source"] = "source"
            elif "source_file" in prop_names:
                property_map["source"] = "source_file"
        else:
            # No properties detected - use fallback (old schema)
            logger.warning("vector_store.schema.detection_failed", msg="Could not detect properties, using fallback")
            property_map = {"question": "input", "answer": "label", "source": "source_file"}

        # If detection partially failed, fill in missing mappings
        if "question" not in property_map:
            property_map["question"] = "input"
        if "answer" not in property_map:
            property_map["answer"] = "label"
        if "source" not in property_map:
            property_map["source"] = "source_file"

        logger.info("vector_store.schema.detected", mapping=property_map, found_properties=prop_names)
        return property_map

    def search(self, query: str, top_k: int | None = None) -> List[Dict[str, Any]]:
        """Search for similar questions and return contexts.

        Args:
            query: free-text question to embed and match.
            top_k: maximum hits; defaults to ``settings.top_k``.

        Returns:
            A list of dicts with keys document_id, question, answer, source,
            similarity (0.0 when no certainty/distance metadata is present).
            Returns [] on any query failure (logged, never raised).
        """
        if top_k is None:
            top_k = settings.top_k

        # Generate embedding for query
        embedding = self.encoder.encode(query)
        embedding = embedding / np.linalg.norm(embedding)  # Normalize

        try:
            # Get the actual property names from the detected schema
            question_prop = self._property_map.get("question", "input")
            answer_prop = self._property_map.get("answer", "label")
            source_prop = self._property_map.get("source", "source_file")

            # Build Weaviate v4 query with near_vector using actual property names
            response = self.collection.query.near_vector(
                near_vector=embedding.tolist(),
                certainty=settings.similarity_threshold,
                limit=top_k,
                return_metadata=["certainty"],
                return_properties=[question_prop, answer_prop, source_prop]
            )
        except Exception as exc:  # pragma: no cover - defensive
            # If query failed, try with old schema as fallback
            error_msg = str(exc).lower()
            if "no such prop" in error_msg or "property" in error_msg:
                logger.warning("vector_store.search.schema_mismatch", error=str(exc), trying="old_schema")
                try:
                    # Fallback to old schema property names
                    response = self.collection.query.near_vector(
                        near_vector=embedding.tolist(),
                        certainty=settings.similarity_threshold,
                        limit=top_k,
                        return_metadata=["certainty"],
                        return_properties=["input", "label", "source_file"]
                    )
                    # Update property map for this query
                    # (sticky: subsequent searches use the legacy names too)
                    self._property_map = {"question": "input", "answer": "label", "source": "source_file"}
                    question_prop = "input"
                    answer_prop = "label"
                    source_prop = "source_file"
                except Exception as fallback_exc:
                    logger.exception("vector_store.search.fallback_failed", error=str(fallback_exc))
                    return []
            else:
                logger.exception("vector_store.search.error", error=str(exc))
                return []

        results = []
        # Use the property names that were successfully used in the query
        # (either from detection or from fallback)
        question_prop = self._property_map.get("question", "input")
        answer_prop = self._property_map.get("answer", "label")
        source_prop = self._property_map.get("source", "source_file")

        # Weaviate v4 returns a QueryReturn with an objects attribute
        for obj in response.objects:

            question = ""
            answer = ""
            source = ""

            if isinstance(obj.properties, dict):
                question = obj.properties.get(question_prop, "")
                answer = obj.properties.get(answer_prop, "")
                source = obj.properties.get(source_prop, "")
            else:
                # Try attribute access
                question = getattr(obj.properties, question_prop, "")
                answer = getattr(obj.properties, answer_prop, "")
                source = getattr(obj.properties, source_prop, "")

            # For old schema where answer might be a label (integer), convert to string
            if answer_prop == "label" and isinstance(answer, (int, float)):
                answer = str(answer)

            # Extract similarity - use certainty if available, otherwise distance (inverted), default to 0.0
            similarity = 0.0
            if obj.metadata:
                if hasattr(obj.metadata, 'certainty') and obj.metadata.certainty is not None:
                    similarity = float(obj.metadata.certainty)
                elif hasattr(obj.metadata, 'distance') and obj.metadata.distance is not None:
                    # Convert distance to similarity (distance is lower for more similar items)
                    # For normalized vectors, distance = 1 - certainty approximately
                    similarity = max(0.0, 1.0 - float(obj.metadata.distance))

            results.append({
                "document_id": str(obj.uuid),
                "question": question,
                "answer": answer,
                "source": source,
                "similarity": similarity,
            })

        return results

    def add_entry(self, question: str, answer: str, source: str = "kb") -> str:
        """Add a new entry to the vector store.

        Embeds "question\\nanswer" with the shared encoder (L2-normalized)
        and inserts it under the NEW schema property names; returns the
        created object's UUID.
        """
        # Generate embedding
        text = question + "\n" + answer
        embedding = self.encoder.encode(text)
        embedding = embedding / np.linalg.norm(embedding)

        # Add to Weaviate
        result = self.collection.data.insert(
            properties={
                "question": question,
                "answer": answer,
                "source": source,
            },
            vector=embedding.tolist()
        )

        return result.uuid
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
_vector_store: VectorStore | None = None
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class _NullVectorStore:
|
| 269 |
+
"""Fallback vector store used when Weaviate is unavailable at startup.
|
| 270 |
+
Provides the minimal API consumed by the workflow.
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
def search(self, query: str, top_k: int | None = None) -> List[Dict[str, Any]]: # type: ignore[override]
|
| 274 |
+
return []
|
| 275 |
+
|
| 276 |
+
def add_entry(self, question: str, answer: str, source: str = "kb") -> str: # pragma: no cover
|
| 277 |
+
logger.warning("vector_store.null.add_entry_ignored")
|
| 278 |
+
return ""
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def load_vector_store(force_reload: bool = False) -> VectorStore:
    """Lazy-load the Weaviate vector store.

    Caches the store in a module-level singleton. When configuration is
    missing or the connection fails, a ``_NullVectorStore`` is installed
    instead so the app starts without a KB (searches return empty) rather
    than crashing.

    Args:
        force_reload: rebuild the singleton even if one already exists.

    Returns:
        The cached ``VectorStore`` (or the null fallback on failure).
    """
    global _vector_store
    if _vector_store is not None and not force_reload:
        return _vector_store

    # No credentials configured: degrade gracefully to the null store.
    if not settings.weaviate_url or not settings.weaviate_api_key:
        logger.warning("vector_store.config.missing", msg="Starting without KB; searches will return empty.")
        _vector_store = _NullVectorStore()  # type: ignore[assignment]
        return _vector_store  # type: ignore[return-value]

    logger.info("vector_store.load.start", url=settings.weaviate_url)

    # Use the v4 helper to connect to Weaviate Cloud. The older v3-style
    # `weaviate.Client(url=..., auth_client_secret=...)` constructor is not
    # supported by the installed weaviate client and raises a TypeError.
    # connect_to_weaviate_cloud will construct the correct ConnectionParams
    # for cloud deployments (http + grpc hosts) and return a connected client.
    # pass the API key string directly — the helper will parse it into the
    # proper AuthCredentials object. Accessing `weaviate.Auth` at module
    # level can raise AttributeError in some installs; passing a string is
    # supported and avoids that problem.
    try:
        client = weaviate.connect_to_weaviate_cloud(
            cluster_url=settings.weaviate_url,
            auth_credentials=settings.weaviate_api_key,
            # Skip gRPC health checks so startup works behind restrictive networks.
            skip_init_checks=True,
        )

        encoder = SentenceTransformer(settings.embedding_model_name)
        _vector_store = VectorStore(client=client, encoder=encoder)
        logger.info("vector_store.load.success")
        return _vector_store
    except Exception as exc:  # pragma: no cover
        # Any connection/model failure falls back to the null store; the
        # warning keeps the cause visible in logs.
        logger.warning(
            "vector_store.load.failed_starting_without_kb",
            error=str(exc),
        )
        _vector_store = _NullVectorStore()  # type: ignore[assignment]
        return _vector_store  # type: ignore[return-value]
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def save_feedback_to_queue(feedback_record: dict) -> None:
    """Append feedback to feedback database.

    Reads the JSON list at ``settings.feedback_store_path`` (creating parent
    directories as needed), appends ``feedback_record``, and writes the list
    back pretty-printed.

    Bug fix: previously a corrupted or empty feedback file made
    ``json.loads`` raise on every subsequent call, permanently breaking
    feedback capture. A malformed file is now logged and reset to a fresh
    list instead.
    """
    feedback_path = Path(settings.feedback_store_path)
    feedback_path.parent.mkdir(parents=True, exist_ok=True)

    data: list = []
    if feedback_path.exists():
        try:
            data = json.loads(feedback_path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError) as exc:
            logger.warning("feedback_store.corrupt_reset", error=str(exc))
            data = []
        if not isinstance(data, list):
            # Defend against a file whose top-level value is not a list.
            logger.warning("feedback_store.unexpected_shape", found=type(data).__name__)
            data = []

    data.append(feedback_record)
    feedback_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def queue_candidate_kb_entry(question: str, solution: str, source: str) -> None:
|
| 337 |
+
"""Save a candidate knowledge base entry for human approval."""
|
| 338 |
+
|
| 339 |
+
queue_path = Path("backend/data/kb_candidate_queue.json")
|
| 340 |
+
queue_path.parent.mkdir(parents=True, exist_ok=True)
|
| 341 |
+
if queue_path.exists():
|
| 342 |
+
queue = json.loads(queue_path.read_text(encoding="utf-8"))
|
| 343 |
+
else:
|
| 344 |
+
queue = []
|
| 345 |
+
queue.append({"question": question, "answer": solution, "source": source})
|
| 346 |
+
queue_path.write_text(json.dumps(queue, ensure_ascii=False, indent=2), encoding="utf-8")
|
| 347 |
+
|
| 348 |
+
|
app/tools/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tooling helpers for math agent."""
|
| 2 |
+
|
app/tools/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
app/tools/__pycache__/audio.cpython-313.pyc
ADDED
|
Binary file (1.93 kB). View file
|
|
|
app/tools/__pycache__/dspy_pipeline.cpython-313.pyc
ADDED
|
Binary file (20.7 kB). View file
|
|
|
app/tools/__pycache__/validator.cpython-313.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
app/tools/__pycache__/vision.cpython-313.pyc
ADDED
|
Binary file (1.75 kB). View file
|
|
|
app/tools/__pycache__/web_search.cpython-313.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
app/tools/audio.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio transcription using Groq Whisper models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import io
|
| 7 |
+
|
| 8 |
+
from groq import Groq
|
| 9 |
+
|
| 10 |
+
from ..config import settings
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _strip_data_url(data: str) -> str:
|
| 14 |
+
if "," in data and data.startswith("data:"):
|
| 15 |
+
return data.split(",", 1)[1]
|
| 16 |
+
return data
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def transcribe_audio(audio_base64: str, language: str = "en") -> str:
    """Transcribe base64 encoded audio using Groq Whisper.

    Args:
        audio_base64: base64 audio payload, optionally wrapped as a
            ``data:`` URL.
        language: ISO language hint forwarded to Whisper.

    Returns:
        The transcribed text.

    Raises:
        RuntimeError: when GROQ_API_KEY is missing or the API response has
            an unexpected shape.
    """
    if not settings.groq_api_key:
        raise RuntimeError("GROQ_API_KEY not configured for speech transcription")

    client = Groq(api_key=settings.groq_api_key)

    payload = base64.b64decode(_strip_data_url(audio_base64))
    buffer = io.BytesIO(payload)
    # The SDK infers the container format from the file name.
    buffer.name = "input.wav"

    result = client.audio.transcriptions.create(
        model="whisper-large-v3-turbo",
        file=buffer,
        response_format="text",
        language=language,
    )

    if hasattr(result, "text"):
        return result.text
    if isinstance(result, str):
        return result
    raise RuntimeError("Unexpected response format from Groq Whisper API")
|
| 40 |
+
|
| 41 |
+
|
app/tools/dspy_pipeline.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DSPy powered reasoning pipeline for math explanations."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Iterable, List, Sequence
|
| 10 |
+
|
| 11 |
+
# Normalization helpers to convert LaTeX-like math to readable text (module-wide)
|
| 12 |
+
_LATEX_SIMPLE_REPLACEMENTS = [
|
| 13 |
+
(r"\\cdot", "·"),
|
| 14 |
+
(r"\\times", "×"),
|
| 15 |
+
(r"\\to", "→"),
|
| 16 |
+
(r"\\leq", "≤"),
|
| 17 |
+
(r"\\geq", "≥"),
|
| 18 |
+
(r"\\neq", "≠"),
|
| 19 |
+
(r"\\pm", "±"),
|
| 20 |
+
(r"\\approx", "≈"),
|
| 21 |
+
(r"\\infty", "∞"),
|
| 22 |
+
(r"\\Rightarrow", "⇒"),
|
| 23 |
+
(r"\\Leftarrow", "⇐"),
|
| 24 |
+
(r"\\sin", "sin"),
|
| 25 |
+
(r"\\cos", "cos"),
|
| 26 |
+
(r"\\tan", "tan"),
|
| 27 |
+
(r"\\ln", "ln"),
|
| 28 |
+
(r"\\log", "log"),
|
| 29 |
+
(r"\\lim", "lim"),
|
| 30 |
+
(r"\\mathbb\{R\}", "R"),
|
| 31 |
+
(r"\\mathbb\{R\}\^3", "R^3"),
|
| 32 |
+
(r"\\mathbb\{Z\}", "Z"),
|
| 33 |
+
(r"\\mathbb\{Q\}", "Q"),
|
| 34 |
+
(r"\\sqrt", "√"),
|
| 35 |
+
(r"\\pi", "π"),
|
| 36 |
+
(r"\\theta", "θ"),
|
| 37 |
+
(r"\\alpha", "α"),
|
| 38 |
+
(r"\\beta", "β"),
|
| 39 |
+
(r"\\gamma", "γ"),
|
| 40 |
+
(r"\\Delta", "Δ"),
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
_FRAC_PATTERN = re.compile(r"\\frac\{([^}]+)\}\{([^}]+)\}")
|
| 44 |
+
|
| 45 |
+
def _normalize_math(text: str) -> str:
|
| 46 |
+
if not text:
|
| 47 |
+
return text
|
| 48 |
+
# Remove inline/display math delimiters
|
| 49 |
+
text = (
|
| 50 |
+
text.replace("$$", "")
|
| 51 |
+
.replace("$", "")
|
| 52 |
+
.replace("\\[", "")
|
| 53 |
+
.replace("\\]", "")
|
| 54 |
+
.replace("\\(", "")
|
| 55 |
+
.replace("\\)", "")
|
| 56 |
+
)
|
| 57 |
+
# Strip common markdown artifacts (bold/headers/backticks)
|
| 58 |
+
text = text.replace("**", "").replace("`", "")
|
| 59 |
+
text = re.sub(r"^\s*#{1,6}\s*", "", text, flags=re.MULTILINE)
|
| 60 |
+
# Drop sizing/spacing commands
|
| 61 |
+
text = re.sub(r"\\(left|right|big|Big|quad|qquad)\\?", "", text)
|
| 62 |
+
# Replace \frac{a}{b} with (a)/(b)
|
| 63 |
+
def _frac_sub(match: re.Match) -> str:
|
| 64 |
+
return f"({match.group(1)})/({match.group(2)})"
|
| 65 |
+
text = _FRAC_PATTERN.sub(_frac_sub, text)
|
| 66 |
+
# Apply simple replacements
|
| 67 |
+
for pattern, repl in _LATEX_SIMPLE_REPLACEMENTS:
|
| 68 |
+
text = re.sub(pattern, repl, text)
|
| 69 |
+
# Normalize powers and tidy operators spacing
|
| 70 |
+
text = re.sub(r"\^2", "²", text)
|
| 71 |
+
text = re.sub(r"\^3", "³", text)
|
| 72 |
+
text = re.sub(r"\s*([=+\-×·*/])\s*", r" \1 ", text)
|
| 73 |
+
# Collapse excessive spaces
|
| 74 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 75 |
+
return text
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
import dspy # type: ignore
|
| 79 |
+
import google.generativeai as genai # type: ignore
|
| 80 |
+
|
| 81 |
+
from ..config import settings
|
| 82 |
+
from ..logger import get_logger
|
| 83 |
+
|
| 84 |
+
logger = get_logger(__name__)
|
| 85 |
+
|
| 86 |
+
class GeminiLM(dspy.LM):
    """Minimal DSPy LM wrapper around Gemini API.

    Adapts google-generativeai's ``GenerativeModel`` to DSPy's LM interface
    so the pipeline can run against Gemini.
    """

    def __init__(self, model: str, api_key: str, max_output_tokens: int = 2048) -> None:
        # model: Gemini model name; api_key: Google Generative AI key.
        # NOTE(review): max_output_tokens is forwarded to dspy.LM only — it is
        # not passed to generate_content below, so Gemini's own output cap is
        # not actually constrained here. Confirm whether a generation_config
        # should be supplied.
        super().__init__(model=model, max_output_tokens=max_output_tokens)
        genai.configure(api_key=api_key)
        self._model = genai.GenerativeModel(model)

    def __call__(self, prompt: str, **kwargs) -> dspy.Completion:
        # kwargs (temperature, stop sequences, ...) are accepted for API
        # compatibility but ignored by this wrapper.
        response = self._model.generate_content(prompt)
        text = response.text or ""
        return dspy.Completion(text=text)

    def loglikelihood(self, prompt: str, continuation: str) -> float:
        # Gemini's API does not expose the token logprobs this would need.
        raise NotImplementedError("GeminiLM.loglikelihood is not supported")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MathTutorSignature(dspy.Signature):
    """DSPy signature describing the desired behavior.

    Inputs: the student's query, retrieved context snippets, and a
    JSON-encoded requirements spec. Output: a structured JSON solution.
    """

    query = dspy.InputField(desc="Student's mathematics question - MUST solve completely")
    context = dspy.InputField(desc="Knowledge snippets retrieved from KB or web")
    requirements = dspy.InputField(desc="JSON encoded requirements for the explanation")
    solution = dspy.OutputField(desc="Structured JSON with COMPLETE solution including all steps, all parts answered, and explicit final answer. Must be a complete solution, not just an overview.")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _ensure_dspy_configured() -> None:
    """Install a Gemini-backed LM as DSPy's global language model.

    Raises:
        RuntimeError: if ``GEMINI_API_KEY`` is not configured.
    """
    if settings.gemini_api_key is None:
        raise RuntimeError("Gemini API key missing. Set GEMINI_API_KEY in environment.")

    # NOTE(review): a fresh GeminiLM is built on every call (per request) —
    # presumably cheap, but confirm genai.configure() is safe to re-run.
    lm = GeminiLM(model=settings.dspy_model, api_key=settings.gemini_api_key, max_output_tokens=settings.dspy_max_tokens)
    dspy.settings.configure(lm=lm)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass
class WebDocument:
    """One web-search hit as consumed by the DSPy pipeline."""

    id: str       # provider id, or positional index when the provider omits one
    title: str    # page title ("Untitled" when missing upstream)
    url: str      # source URL (may be empty)
    snippet: str  # extracted page content/snippet
    score: float  # provider relevance score
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
class SearchResult:
    """Aggregated outcome of a web search for a single query."""

    query: str                        # the query string that was searched
    source: str                       # backend that produced the documents
    documents: Sequence[WebDocument]  # the retrieved hits
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
async def generate_solution_with_cot(
    query: str,
    contexts: Iterable,
    search_metadata: SearchResult | None = None,
) -> tuple[list[dict[str, str | None]], str]:
    """Return (steps, final_answer) produced by Gemini via DSPy.

    Each step is a dict with keys ``title``, ``content`` and ``expression``
    (``expression`` may be None). The annotation previously claimed
    ``list[tuple[str, str]]``; the code below always builds dicts.

    Args:
        query: The student's math question.
        contexts: Retrieval hits; only ``document_id``/``question``/``answer``/
            ``similarity`` attributes are read (via ``getattr`` with defaults).
        search_metadata: Optional web-search result whose documents are passed
            to the model as citations.
    """

    _ensure_dspy_configured()

    # Serialize each retrieval hit as one JSON line for the model's context.
    context_blocks: List[str] = []
    for ctx in contexts:
        block = json.dumps({
            "document_id": getattr(ctx, "document_id", ""),
            "question": getattr(ctx, "question", ""),
            "answer": getattr(ctx, "answer", ""),
            "similarity": getattr(ctx, "similarity", 0.0),
        }, ensure_ascii=False)
        context_blocks.append(block)

    if search_metadata:
        citations = [doc.__dict__ for doc in search_metadata.documents]
    else:
        citations = []

    # Formatting contract passed to the model as a JSON "requirements" field.
    requirements = {
        "style": "Explain as a friendly mathematics professor using numbered steps.",
        "must_include": [
            "Start the response with a single line: 'Answer: <final value or statement>'",
            "You MUST solve the ENTIRE problem completely - do not stop with just an overview",
            "List each algebraic manipulation explicitly",
            "Show ALL intermediate steps - do not skip any calculations",
            "Conclude with a clear final numeric or symbolic answer",
            "If the problem has multiple parts, answer ALL parts and label them (a), (b), etc.",
            "Avoid LaTeX markers like $...$ or \\frac{}{}; prefer natural language and simple unicode math",
            "State units or interpretation when relevant (probability, geometry, word problems, etc.)",
        ],
        "citations": citations,
    }

    # Build comprehensive context prompt
    context_prompt = ""
    if context_blocks:
        context_prompt = "\n\nCONTEXT FROM KNOWLEDGE BASE AND WEB SEARCH:\n" + "\n".join(context_blocks)

    enhanced_query = (
        f"{query}\n\n"
        "IMPORTANT: Solve this problem COMPLETELY. Provide all steps and the final answer. "
        "Do not just give an overview - actually solve the entire problem." + context_prompt
    )

    predictor = dspy.Predict(MathTutorSignature)

    # dspy.Predict is synchronous; run it off the event loop.
    async def _run_predict() -> dspy.Prediction:
        return await asyncio.to_thread(
            predictor,
            query=enhanced_query,
            context="\n".join(context_blocks) if context_blocks else "No additional context available.",
            requirements=json.dumps(requirements, ensure_ascii=False),
        )

    prediction = await _run_predict()

    # The model is asked for JSON but may answer in prose; fall back to
    # wrapping the raw text as a single step.
    try:
        structured = json.loads(prediction.solution)
    except json.JSONDecodeError:
        logger.warning("dspy.solution.parse_failed", solution_preview=prediction.solution[:200] if prediction.solution else "")
        # Try to extract solution from non-JSON response
        solution_text = _normalize_math(prediction.solution or "")
        # If it's not JSON, treat it as a natural language response and parse it
        structured = {
            "steps": [
                {
                    "title": "Complete Solution",
                    "explanation": solution_text,
                }
            ],
            "final_answer": solution_text.split("\n")[-1] if "\n" in solution_text else solution_text,
        }

    steps_raw = structured.get("steps", [])

    # Convert to dict format if it's tuple format
    steps = []
    for idx, step in enumerate(steps_raw):
        if isinstance(step, dict):
            # Accept either "explanation" (our fallback) or "content" keys.
            steps.append({
                "title": step.get("title", f"Step {idx+1}"),
                "content": _normalize_math(step.get("explanation", step.get("content", ""))),
                "expression": _normalize_math(step.get("expression")) if step.get("expression") else None
            })
        elif isinstance(step, (list, tuple)) and len(step) >= 2:
            steps.append({
                "title": step[0],
                "content": _normalize_math(step[1]),
                "expression": _normalize_math(step[2]) if len(step) > 2 and step[2] else None
            })
        else:
            steps.append({
                "title": f"Step {idx+1}",
                "content": _normalize_math(str(step)),
                "expression": None
            })

    final_answer = _normalize_math(structured.get("final_answer", ""))

    # Ensure final_answer is not empty
    if not final_answer and steps:
        # Try to extract from last step
        last_content = steps[-1].get("content", "")
        if last_content:
            # Look for answer patterns
            for line in reversed(last_content.split("\n")):
                if any(word in line.lower() for word in ["answer", "therefore", "thus", "hence", "result"]):
                    final_answer = _normalize_math(line)
                    break
            if not final_answer:
                final_answer = last_content.split("\n")[-1] if "\n" in last_content else last_content

    # Ultimate fallback
    if not final_answer:
        final_answer = "Please see the steps above for the complete solution."

    # Prepend a single-line explicit answer for UI readability
    if not final_answer.lower().startswith("answer:"):
        final_answer = f"Answer: {final_answer}"

    return steps, final_answer
|
| 263 |
+
except Exception: # pragma: no cover - fallback when dspy isn't installed
|
| 264 |
+
# Provide a lightweight fallback so the application can start even if
|
| 265 |
+
# `dspy` is not installed. This fallback uses the Gemini client directly
|
| 266 |
+
# to generate a JSON-structured explanation.
|
| 267 |
+
import google.generativeai as genai # type: ignore
|
| 268 |
+
|
| 269 |
+
from ..config import settings
|
| 270 |
+
from ..logger import get_logger
|
| 271 |
+
|
| 272 |
+
logger = get_logger(__name__)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@dataclass
class WebDocument:
    """One web-search hit (fallback copy, mirroring the dspy-path definition)."""

    id: str       # provider id, or positional index when missing
    title: str    # page title
    url: str      # source URL
    snippet: str  # extracted page content/snippet
    score: float  # provider relevance score
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@dataclass
class SearchResult:
    """Aggregated web-search outcome (fallback copy of the dspy-path class)."""

    query: str                        # the query string that was searched
    source: str                       # backend that produced the documents
    documents: Sequence[WebDocument]  # the retrieved hits
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
async def generate_solution_with_cot(
    query: str,
    contexts: Iterable,
    search_metadata: SearchResult | None = None,
) -> tuple[list[dict[str, str]], str]:
    """Fallback: use Gemini directly to produce a JSON with `steps` and `final_answer`.

    This is intentionally simple: it asks the model to return a JSON object. If parsing
    fails, we return the raw text as a single reasoning step.

    Each returned step is a dict with ``title``/``content`` (and usually
    ``expression``, possibly None). ``search_metadata`` is accepted for
    interface parity with the dspy path but is not used here.
    """

    if settings.gemini_api_key is None:
        raise RuntimeError("Gemini API key missing. Set GEMINI_API_KEY in environment.")

    genai.configure(api_key=settings.gemini_api_key)
    model = genai.GenerativeModel(settings.dspy_model)

    # Serialize each retrieval hit as a JSON line for the prompt.
    context_text = []
    for ctx in contexts:
        context_text.append(json.dumps({
            "document_id": getattr(ctx, "document_id", ""),
            "question": getattr(ctx, "question", ""),
            "answer": getattr(ctx, "answer", ""),
            "similarity": getattr(ctx, "similarity", 0.0),
        }, ensure_ascii=False))

    context_section = ""
    if context_text:
        context_section = "\n\nCONTEXT FROM KNOWLEDGE BASE AND WEB SEARCH:\n" + "\n".join(context_text) + "\n"

    prompt = (
        "You are an expert mathematics tutor across ALL math domains (arithmetic, algebra, geometry, trigonometry, calculus, linear algebra, probability, statistics, number theory, optimization, etc.). Provide clear, human-friendly solutions.\n\n"
        "CRITICAL REQUIREMENTS - YOU MUST FOLLOW THESE:\n"
        "1. Start your response with a single line: 'Answer: <final value or statement>'.\n"
        "2. Solve the ENTIRE problem - do NOT just give an overview or stop midway.\n"
        "3. Show EVERY step of your work - no skipped calculations.\n"
        "4. If there are multiple parts, answer ALL parts, labeled (a), (b), etc.\n"
        "5. Provide a clear, explicit FINAL ANSWER at the very end as well.\n"
        "6. Avoid LaTeX markers like $...$ or \\frac{}{}; write math in natural language or simple unicode (e.g., 1/2, ·, ×, →).\n"
        "7. Use concise, readable sentences; keep symbols understandable to non-experts.\n\n"
        "RESPONSE STRUCTURE:\n"
        "1. Brief Overview (1-2 sentences about the goal)\n"
        "2. Step-by-Step Solution (numbered) with clear titles and calculations\n"
        "3. Final Answer Section (repeat the answer clearly)\n\n"
        "STUDENT'S QUESTION:\n" + query + context_section +
        "\n\nNOW SOLVE THIS PROBLEM COMPLETELY.\n"
    )

    # Generate with more tokens to ensure complete answers
    from google.generativeai.types import GenerationConfig

    generation_config = GenerationConfig(
        temperature=0.3,  # Lower temperature for more focused responses
        max_output_tokens=8192,  # Increased for complete solutions
        top_p=0.95,
    )

    resp = model.generate_content(prompt, generation_config=generation_config)
    text = (resp.text or "").strip()
    text = _normalize_math(text)

    if not text:
        logger.warning("dspy.fallback.empty_response", query=query)
        text = "I apologize, but I couldn't generate a response. Please try rephrasing your question."

    try:
        # Parse the natural language response into structured steps.
        # Heuristic line-by-line state machine: lines open steps, look like
        # equations, announce the final answer, or extend the current step.
        lines = text.strip().split("\n")
        steps = []
        current_step = None
        final_answer = ""
        buffer = []
        overview_added = False

        for line in lines:
            line = line.strip()
            if not line:
                # Blank line: flush accumulated prose into the open step.
                if current_step and buffer:
                    current_step["content"] = "\n".join(buffer)
                    buffer = []
                continue

            # Detect step headers (more flexible patterns)
            step_match = False
            if line.lower().startswith("step ") or re.match(r"^\d+[\.\)]\s+", line, re.IGNORECASE):
                step_match = True
                if current_step:
                    current_step["content"] = "\n".join(buffer)
                    steps.append(current_step)
                    buffer = []
                # Extract step number and title
                # NOTE(review): positional maxsplit for re.split is deprecated
                # on newer Pythons — consider re.split(pattern, line, maxsplit=1).
                step_parts = re.split(r"[:\.]", line, 1)
                step_title = step_parts[1].strip() if len(step_parts) > 1 else line
                current_step = {
                    "title": step_title if step_title else f"Step {len(steps) + 1}",
                    "content": "",
                    "expression": None
                }
            # Detect final answer section
            elif any(word in line.lower() for word in ["therefore", "thus", "finally", "hence", "answer:", "the answer is", "final answer"]):
                if current_step:
                    current_step["content"] = "\n".join(buffer)
                    steps.append(current_step)
                    buffer = []
                    current_step = None
                # Extract final answer
                for word in ["therefore", "thus", "finally", "hence", "answer:", "the answer is", "final answer"]:
                    if word in line.lower():
                        final_answer = line.split(":", 1)[1].strip() if ":" in line else line
                        break
                if not final_answer:
                    final_answer = line
            # Handle equations (lines with = or mathematical operators)
            elif current_step and ("=" in line or any(s in line for s in "+-*/^√∫∑∈ℝ")):
                if buffer:  # Save accumulated explanation first
                    current_step["content"] = "\n".join(buffer)
                    buffer = []
                if not current_step["expression"]:
                    current_step["expression"] = _normalize_math(line)
                else:
                    current_step["expression"] += "\n" + _normalize_math(line)
            # Add to current step's content
            elif current_step:
                buffer.append(line)
            # Handle text before first step (overview)
            elif not overview_added and not step_match:
                steps.append({
                    "title": "Overview",
                    "content": line,
                    "expression": None
                })
                overview_added = True

        # Don't forget the last step
        if current_step:
            if buffer:
                current_step["content"] = "\n".join(buffer)
            steps.append(current_step)

        # If no steps were parsed, create steps from the text
        if not steps:
            # Split by paragraphs or double newlines
            paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
            for idx, para in enumerate(paragraphs):
                if idx == 0:
                    steps.append({
                        "title": "Solution",
                        "content": _normalize_math(para),
                        "expression": None
                    })
                else:
                    steps.append({
                        "title": f"Step {idx}",
                        "content": _normalize_math(para),
                        "expression": None
                    })

        # Extract final answer if not found
        if not final_answer:
            # Look for answer patterns in the last few lines
            for line in reversed(lines[-10:]):
                line_lower = line.lower()
                if any(word in line_lower for word in ["answer", "therefore", "thus", "hence", "result"]):
                    final_answer = line
                    break

        # If still no answer, use the last meaningful line or paragraph
        if not final_answer and steps:
            last_content = steps[-1].get("content", "")
            if last_content:
                final_answer = last_content.split("\n")[-1] if "\n" in last_content else last_content

        # Ultimate fallback
        if not final_answer:
            final_answer = "Please refer to the steps above for the complete solution."

        return steps, final_answer
    except Exception:
        logger.warning("dspy.fallback.parse_failed")
        # Even for plain text, maintain consistent structure
        return [{
            "title": "Solution",
            "content": text
        }], text
|
| 475 |
+
|
| 476 |
+
|
app/tools/validator.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Validation helpers for user-provided solutions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import google.generativeai as genai
|
| 6 |
+
|
| 7 |
+
from ..config import settings
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def validate_user_solution(question: str, proposed_solution: str) -> bool:
    """Ask Gemini whether a user-uploaded solution is mathematically correct.

    Returns True when the model's (lowercased, stripped) reply begins with
    "valid"; any other reply — including "invalid" — yields False.

    Raises:
        RuntimeError: when no Gemini API key is configured.
    """
    api_key = settings.gemini_api_key
    if not api_key:
        raise RuntimeError("Gemini API key required for solution validation")

    genai.configure(api_key=api_key)

    grading_prompt = (
        "You are an expert mathematics professor. Validate the student's solution. "
        "Return ONLY 'VALID' if the reasoning is mathematically correct. "
        "Return ONLY 'INVALID' otherwise.\n\n"
        f"Question: {question}\nStudent solution:\n{proposed_solution}"
    )

    grader = genai.GenerativeModel(settings.gemini_model)
    reply = grader.generate_content(grading_prompt)
    verdict = (reply.text or "").strip().lower()
    return verdict.startswith("valid")
|
| 27 |
+
|
| 28 |
+
|
app/tools/vision.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image understanding helpers using Gemini multimodal models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
|
| 7 |
+
import google.generativeai as genai
|
| 8 |
+
|
| 9 |
+
from ..config import settings
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _strip_data_url(data: str) -> str:
|
| 13 |
+
if "," in data and data.startswith("data:"):
|
| 14 |
+
return data.split(",", 1)[1]
|
| 15 |
+
return data
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def extract_text_from_image(image_base64: str, prompt: str = "Extract all mathematics text from this image.") -> str:
    """Use Gemini to extract math text from base64 encoded image.

    Args:
        image_base64: Base64 image bytes, optionally wrapped in a ``data:`` URL.
        prompt: Instruction sent to the model alongside the image.

    Returns:
        The model's text response ("" when the model returns nothing).

    Raises:
        RuntimeError: if no Gemini API key is configured.
    """
    if not settings.gemini_api_key:
        raise RuntimeError("Gemini API key not configured for image understanding")

    genai.configure(api_key=settings.gemini_api_key)
    clean_base64 = _strip_data_url(image_base64)
    image_bytes = base64.b64decode(clean_base64)
    model = genai.GenerativeModel(settings.gemini_model)
    # NOTE(review): mime_type is hard-coded to image/png — confirm callers never
    # send JPEG/WebP payloads, or derive the type from the data URL header.
    response = model.generate_content(
        [prompt, {"mime_type": "image/png", "data": image_bytes}]
    )
    return response.text or ""
|
| 32 |
+
|
| 33 |
+
|
app/tools/web_search.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Web search workflow: MCP → LangChain Tavily → SDK fallback."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
|
| 11 |
+
import aiohttp
|
| 12 |
+
|
| 13 |
+
from ..config import settings
|
| 14 |
+
from ..logger import get_logger
|
| 15 |
+
from .dspy_pipeline import SearchResult, WebDocument
|
| 16 |
+
|
| 17 |
+
logger = get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class MCPDocument:
    """One search hit as returned by a Tavily backend (MCP, LangChain, or SDK)."""

    id: str       # provider id, or positional index when the provider omits one
    title: str    # page title ("Untitled" when missing)
    url: str      # source URL (may be empty)
    snippet: str  # extracted page content
    score: float  # provider relevance score
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
async def _parse_tavily_response(result: dict) -> List[MCPDocument]:
    """Parse a Tavily API response into MCPDocument objects.

    Accepts either the usual ``{"results": [...]}`` envelope or a bare list of
    result items (some MCP servers return the list directly).

    Args:
        result: Decoded JSON response from a Tavily backend.

    Returns:
        One MCPDocument per result item; missing fields get safe defaults.
    """
    documents: List[MCPDocument] = []

    # Handle both direct (bare list) and nested ({"results": [...]}) payloads.
    # Bug fix: the old code called result.get(...) first, which raised
    # AttributeError on a bare list before the isinstance fallback could run.
    if isinstance(result, list):
        results_list = result
    else:
        results_list = result.get("results", [])

    for idx, item in enumerate(results_list):
        documents.append(
            MCPDocument(
                id=item.get("id", str(idx)),
                title=item.get("title", "Untitled"),
                url=item.get("url", ""),
                snippet=item.get("content", ""),
                score=item.get("score", 0.0),
            )
        )
    return documents
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
async def _invoke_mcp_tavily(query: str, max_results: int = 5) -> List[MCPDocument]:
    """Invoke Tavily via MCP server with multiple payload format attempts.

    Tries three JSON-RPC payload shapes in order (``tools/call``, a direct
    ``tavily_search`` method, and ``call_tool``), accepting either JSON or
    SSE responses. Returns on the first format that yields documents.

    Args:
        query: Search query string.
        max_results: Maximum hits requested from Tavily.

    Raises:
        RuntimeError: when no MCP URL is configured, the server is
            unreachable, or every payload format fails.
    """
    mcp_url = settings.mcp_tavily_url or os.getenv("MCP_TAVILY_URL")

    if not mcp_url:
        logger.debug("web_search.mcp_not_configured")
        raise RuntimeError("MCP URL not configured")

    # Try different payload formats that MCP servers commonly use
    payload_formats = [
        # Format 1: MCP protocol with tools/call
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {
                "name": "tavily_search",
                "arguments": {
                    "query": query,
                    "max_results": max_results
                }
            }
        },
        # Format 2: Direct tavily_search method
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tavily_search",
            "params": {
                "query": query,
                "max_results": max_results
            }
        },
        # Format 3: Call tool method
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "call_tool",
            "params": {
                "name": "tavily_search",
                "arguments": {
                    "query": query,
                    "max_results": max_results
                }
            }
        },
    ]

    last_error = None

    async with aiohttp.ClientSession() as session:
        # Quick health check
        try:
            async with session.head(mcp_url, timeout=10) as health_resp:
                if health_resp.status not in (200, 405):  # 405 = Method Not Allowed for HEAD
                    logger.warning("web_search.mcp_health_check_failed",
                                 status=health_resp.status)
        except Exception as health_exc:
            logger.warning("web_search.mcp_unreachable",
                         error=str(health_exc),
                         url=mcp_url)
            raise RuntimeError(f"MCP server unreachable: {health_exc}")

        # Try each payload format
        for idx, payload in enumerate(payload_formats, 1):
            try:
                logger.debug("web_search.mcp_attempting_format",
                           format_number=idx,
                           method=payload.get("method"))

                headers = {
                    "Accept": "text/event-stream, application/json",
                    "Content-Type": "application/json",
                }

                async with session.post(mcp_url, headers=headers, json=payload, timeout=30) as resp:
                    ctype = resp.headers.get("Content-Type", "")
                    status = resp.status

                    if status != 200:
                        response_text = await resp.text()
                        logger.debug("web_search.mcp_format_failed",
                                   format_number=idx,
                                   status=status,
                                   response=response_text[:200])
                        last_error = f"HTTP {status}: {response_text[:100]}"
                        continue

                    # Handle SSE response: keep only the last non-terminal
                    # "data:" event, which carries the final JSON-RPC message.
                    if "text/event-stream" in ctype:
                        last_json = None
                        async for raw in resp.content:
                            line = raw.decode("utf-8", errors="ignore").strip()
                            if line.startswith("data:"):
                                data_str = line[5:].strip()
                                if data_str and data_str != "[DONE]":
                                    last_json = data_str

                        if not last_json:
                            logger.debug("web_search.mcp_empty_sse", format_number=idx)
                            last_error = "Empty SSE stream"
                            continue

                        data = json.loads(last_json)
                    else:
                        # Handle JSON response
                        data = await resp.json()

                    # Check for JSON-RPC error
                    if "error" in data:
                        error_detail = data["error"]
                        logger.debug("web_search.mcp_format_error",
                                   format_number=idx,
                                   error=error_detail)
                        last_error = f"RPC error: {error_detail}"
                        continue

                    # Extract result
                    result = data.get("result", {})

                    # Handle MCP protocol response format: the tool result is
                    # embedded as JSON text inside a content item.
                    if "content" in result and isinstance(result["content"], list):
                        for content_item in result["content"]:
                            if content_item.get("type") == "text":
                                text_data = content_item.get("text", "")
                                try:
                                    tavily_result = json.loads(text_data)
                                    documents = await _parse_tavily_response(tavily_result)
                                    if documents:
                                        logger.info("web_search.mcp_success",
                                                  format_number=idx,
                                                  document_count=len(documents))
                                        return documents
                                except json.JSONDecodeError:
                                    continue

                    # Handle direct Tavily response format
                    if "results" in result:
                        documents = await _parse_tavily_response(result)
                        if documents:
                            logger.info("web_search.mcp_success",
                                      format_number=idx,
                                      document_count=len(documents))
                            return documents

                    last_error = f"No results in response structure: {list(result.keys())}"

            except asyncio.TimeoutError:
                logger.debug("web_search.mcp_timeout", format_number=idx)
                last_error = "Request timeout"
                continue
            except Exception as exc:
                logger.debug("web_search.mcp_format_exception",
                           format_number=idx,
                           error=str(exc),
                           error_type=type(exc).__name__)
                last_error = str(exc)
                continue

    # All formats failed
    logger.warning("web_search.mcp_all_formats_failed",
                 last_error=last_error,
                 formats_tried=len(payload_formats))
    raise RuntimeError(f"MCP failed: {last_error}")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# LangChain Tavily tool (preferred fallback)
|
| 218 |
+
try:
|
| 219 |
+
from langchain_tavily import TavilySearch # type: ignore
|
| 220 |
+
except Exception:
|
| 221 |
+
TavilySearch = None # type: ignore
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
async def _run_tavily_langchain(query: str, max_results: int = 5) -> List[MCPDocument]:
    """Use LangChain Tavily tool - the recommended integration.

    Args:
        query: Search query string.
        max_results: Maximum hits requested.

    Raises:
        RuntimeError: if langchain-tavily is not installed, the API key is
            missing, or the tool invocation fails.
    """
    if TavilySearch is None:
        raise RuntimeError("LangChain Tavily not installed. Run: pip install -U langchain-tavily")

    if not settings.tavily_api_key:
        raise RuntimeError("TAVILY_API_KEY not configured")

    # Ensure API key is in environment for LangChain (the tool reads
    # TAVILY_API_KEY from the environment, not from our settings object).
    os.environ["TAVILY_API_KEY"] = settings.tavily_api_key

    try:
        # Initialize the tool with recommended settings
        tool = TavilySearch(
            max_results=max_results,
            topic="general",
            search_depth="basic",
        )

        # Invoke the tool - it returns a dict with results. The tool is
        # synchronous, so run it off the event loop.
        result = await asyncio.to_thread(tool.invoke, {"query": query})

        documents: List[MCPDocument] = []
        results_list = result.get("results", [])

        for idx, item in enumerate(results_list):
            documents.append(
                MCPDocument(
                    id=item.get("id", str(idx)),
                    title=item.get("title", "Untitled"),
                    url=item.get("url", ""),
                    snippet=item.get("content", ""),
                    score=item.get("score", 0.0),
                )
            )

        return documents

    except Exception as exc:
        logger.error("web_search.langchain_execution_error",
                   error=str(exc),
                   error_type=type(exc).__name__)
        raise RuntimeError(f"LangChain Tavily failed: {exc}")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# Direct Tavily SDK (last resort fallback)
|
| 270 |
+
try:
|
| 271 |
+
from tavily import TavilyClient # type: ignore
|
| 272 |
+
except Exception:
|
| 273 |
+
TavilyClient = None # type: ignore
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
async def _run_tavily_sdk(query: str, max_results: int = 5) -> List[MCPDocument]:
    """Direct Tavily SDK as last resort fallback.

    Raises:
        RuntimeError: if tavily-python is not installed or the API key is missing.
    """
    if TavilyClient is None:
        raise RuntimeError("Tavily SDK not installed. Run: pip install tavily-python")

    if not settings.tavily_api_key:
        raise RuntimeError("TAVILY_API_KEY not configured")

    # The SDK is synchronous; run the search off the event loop.
    client = TavilyClient(api_key=settings.tavily_api_key)
    response = await asyncio.to_thread(
        client.search,
        query=query,
        max_results=max_results,
        include_images=False
    )

    # Map each raw hit onto our document dataclass, defaulting missing fields.
    return [
        MCPDocument(
            id=entry.get("id", str(position)),
            title=entry.get("title", "Untitled"),
            url=entry.get("url", ""),
            snippet=entry.get("content", ""),
            score=entry.get("score", 0.0),
        )
        for position, entry in enumerate(response.get("results", []))
    ]
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _merge_unique_documents(
    existing: List[MCPDocument],
    candidates: List[MCPDocument],
    limit: int = 5,
) -> int:
    """Append candidate documents into *existing*, deduplicating by URL.

    Mutates *existing* in place. Deduplicates against documents already
    collected AND within the candidate batch itself (the previous inline
    version never updated its seen-set, so repeated URLs inside one batch
    were all added). Stops once *existing* holds *limit* documents.

    Returns:
        The number of documents actually appended.
    """
    seen_urls = {doc.url for doc in existing if doc.url}
    added = 0
    for doc in candidates:
        if doc.url in seen_urls:
            continue
        existing.append(doc)
        added += 1
        if doc.url:
            seen_urls.add(doc.url)  # dedupe within the candidate batch too
        if len(existing) >= limit:
            break
    return added


async def run_web_search_with_fallback(query: str) -> Optional[SearchResult]:
    """Run web search with fallback chain: MCP → LangChain Tavily → SDK.

    Priority order:
    1. MCP Tavily (if configured) - Remote MCP server integration
    2. LangChain Tavily - Official LangChain integration (recommended by Tavily)
    3. Tavily SDK - Direct Python SDK as last resort

    Lower tiers are only consulted while fewer than 3 documents have been
    gathered. Returns None when every tier fails or yields no documents.
    """
    all_documents: List[MCPDocument] = []
    sources_used: List[str] = []

    # 1) Try MCP Tavily first (if configured)
    mcp_url = settings.mcp_tavily_url or os.getenv("MCP_TAVILY_URL")
    if mcp_url:
        try:
            tavily_docs = await _invoke_mcp_tavily(query, max_results=5)
            if tavily_docs:
                all_documents.extend(tavily_docs[:5])
                sources_used.append("tavily-mcp")
                logger.info("web_search.mcp_completed", document_count=len(tavily_docs))
        except Exception as exc:
            logger.warning("web_search.mcp_failed", error=str(exc))

    # 2) Fallback to LangChain Tavily (preferred method)
    if len(all_documents) < 3:
        try:
            lc_docs = await _run_tavily_langchain(query, max_results=5)
            added_count = _merge_unique_documents(all_documents, lc_docs)
            if added_count > 0:
                sources_used.append("tavily-langchain")
                logger.info("web_search.langchain_completed",
                            document_count=added_count,
                            total_documents=len(all_documents))
        except Exception as lc_exc:
            logger.warning("web_search.langchain_failed", error=str(lc_exc))

    # 3) Last resort: Direct Tavily SDK
    if len(all_documents) < 3:
        try:
            tavily_sdk_docs = await _run_tavily_sdk(query, max_results=5)
            added_count = _merge_unique_documents(all_documents, tavily_sdk_docs)
            if added_count > 0:
                sources_used.append("tavily-sdk")
                logger.info("web_search.sdk_completed",
                            document_count=added_count,
                            total_documents=len(all_documents))
        except Exception as sdk_exc:
            logger.warning("web_search.sdk_failed", error=str(sdk_exc))

    if not all_documents:
        logger.error("web_search.all_sources_failed", query=query)
        return None

    # Convert the internal MCPDocument records to the public WebDocument shape.
    web_documents = [
        WebDocument(
            id=doc.id,
            title=doc.title,
            url=doc.url,
            snippet=doc.snippet,
            score=doc.score,
        )
        for doc in all_documents[:5]
    ]

    # Provenance label records every tier that contributed, e.g. "tavily-mcp+tavily-sdk".
    source = "+".join(sources_used) if sources_used else "unknown"
    logger.info("web_search.completed",
                source=source,
                total_documents=len(web_documents),
                query=query)

    return SearchResult(query=query, source=source, documents=web_documents)
|
app/workflows/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Workflow graphs for the math agent."""
|
| 2 |
+
|
| 3 |
+
|
app/workflows/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (193 Bytes). View file
|
|
|
app/workflows/__pycache__/langgraph_pipeline.cpython-313.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
app/workflows/langgraph_pipeline.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LangGraph implementation of the math agent workflow."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Literal, TypedDict
|
| 6 |
+
|
| 7 |
+
from langgraph.graph import END, StateGraph
|
| 8 |
+
|
| 9 |
+
from ..config import settings
|
| 10 |
+
from ..guardrails import run_input_guardrails, run_output_guardrails
|
| 11 |
+
from ..logger import get_logger
|
| 12 |
+
from ..schemas import Citation, RetrievalContext
|
| 13 |
+
from ..services.vector_store import VectorStore
|
| 14 |
+
from ..tools.dspy_pipeline import SearchResult, generate_solution_with_cot
|
| 15 |
+
from ..tools.web_search import run_web_search_with_fallback
|
| 16 |
+
|
| 17 |
+
logger = get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AgentGraphState(TypedDict, total=False):
    """Mutable state threaded through the LangGraph math-agent workflow.

    ``total=False`` makes every key optional: each node returns only the
    keys it updates, and LangGraph merges them into the running state.
    """

    query: str  # Raw user query as received
    sanitized_query: str  # Query after input guardrails have run
    gateway_trace: list[str]  # Ordered audit trail of pipeline stages
    knowledge_hits: list[RetrievalContext]  # Combined KB + web for generation
    kb_hits: list[RetrievalContext]  # KB hits only (for UI)
    web_hits: list[RetrievalContext]  # Web hits only (for UI)
    search_result: SearchResult | None  # Raw web-search result, when any tier succeeded
    steps: list[tuple[str, str]]  # Solution steps (advisory type: error paths emit dicts)
    answer: str  # Final answer text shown to the user
    source: str  # Provenance label, e.g. "kb+web" or "tavily-sdk"
    citations: list[Citation]  # Source citations surfaced alongside the answer
    retrieved_from_kb: bool  # True when KB retrieval contributed context
def build_math_agent_graph(vector_store: VectorStore) -> StateGraph:
    """Compile the LangGraph workflow for the math agent.

    Pipeline: input guardrails -> KB retrieval -> (router) -> web search ->
    answer generation -> output guardrails. Each node is an async closure
    over *vector_store* and returns only the state keys it updates.

    NOTE(review): the annotation says ``StateGraph`` but the function returns
    ``graph.compile()`` — a compiled graph object; annotation is loose.
    """

    graph = StateGraph(AgentGraphState)

    async def input_guardrails_node(state: AgentGraphState) -> AgentGraphState:
        # Sanitize the query only when guardrails are enabled in settings.
        sanitized = run_input_guardrails(state["query"]) if settings.enforce_input_guardrails else state["query"]
        trace = state.get("gateway_trace", []) + ["input_guardrails_pass"]
        return {"sanitized_query": sanitized, "gateway_trace": trace}

    async def retrieve_kb_node(state: AgentGraphState) -> AgentGraphState:
        """Search the vector store and wrap hits as RetrievalContext objects."""
        query = state["sanitized_query"]
        kb_results = vector_store.search(query)
        kb_contexts: list[RetrievalContext] = []
        for result in kb_results:
            # kb_results is a list of dicts with: document_id, question, answer, source, similarity
            kb_contexts.append(
                RetrievalContext(
                    document_id=result.get("document_id", ""),
                    question=result.get("question", ""),
                    answer=result.get("answer", ""),
                    # `or 0.0` also coerces an explicit None similarity to 0.0
                    similarity=result.get("similarity", 0.0) or 0.0,
                )
            )
        trace = state.get("gateway_trace", []) + ["kb_search_complete"]
        # Store KB contexts separately so we can combine them later
        return {
            "knowledge_hits": kb_contexts,
            "kb_hits": kb_contexts,  # Keep KB hits separate
            "gateway_trace": trace
        }

    def route_after_retrieval(state: AgentGraphState) -> Literal["kb", "search"]:
        # Always proceed to both KB and web search, then generate with Gemini
        # This ensures we always get comprehensive results
        # NOTE(review): every path returns "search", so the "kb" branch
        # (kb_generation_node) is currently unreachable; this function now
        # exists only to log routing diagnostics.
        contexts = state.get("knowledge_hits", []) or []
        if contexts and len(contexts) > 0:
            top_similarity = contexts[0].similarity if contexts[0].similarity is not None else 0.0
            if top_similarity >= settings.similarity_threshold:
                logger.info("router.has_good_kb_match", similarity=top_similarity, threshold=settings.similarity_threshold)
                # Still go to web search to enhance with additional context
                return "search"
            logger.info("router.kb_below_threshold", top_similarity=top_similarity, threshold=settings.similarity_threshold)
        else:
            logger.info("router.no_kb_results")
        # Always proceed to web search to get additional context
        return "search"

    async def kb_generation_node(state: AgentGraphState) -> AgentGraphState:
        # This node is deprecated - we now always go through web search then generation
        # But keep it for backward compatibility in case routing changes
        contexts = state.get("knowledge_hits", [])
        # Also fetch web search to combine with KB
        query = state["sanitized_query"]
        search_result = await run_web_search_with_fallback(query)
        web_contexts = []
        if search_result:
            # Web documents are adapted into the same RetrievalContext shape as KB hits.
            web_contexts = [
                RetrievalContext(
                    document_id=document.id,
                    question=document.title,
                    answer=document.snippet,
                    similarity=document.score or 0.0,
                )
                for document in search_result.documents
            ]

        # Combine KB and web contexts
        all_contexts = contexts + web_contexts
        steps, answer_text = await generate_solution_with_cot(state["sanitized_query"], all_contexts, search_metadata=search_result)
        trace = state.get("gateway_trace", []) + ["router->kb+web"]
        return {
            "steps": steps,
            "answer": answer_text,
            "source": "kb+web",
            "retrieved_from_kb": True,
            "citations": state.get("citations", []),
            "gateway_trace": trace,
            "search_result": search_result,
        }

    async def web_search_node(state: AgentGraphState) -> AgentGraphState:
        """Run the web-search fallback chain and merge results with KB hits."""
        query = state["sanitized_query"]
        # Get KB contexts from previous node (preserve them separately)
        kb_contexts = state.get("kb_hits", []) or state.get("knowledge_hits", [])

        # Always attempt web search
        search_result = await run_web_search_with_fallback(query)

        # Combine KB contexts with web search contexts for Gemini generation
        all_contexts = list(kb_contexts) if kb_contexts else []
        web_contexts = []

        if search_result and search_result.documents:
            web_contexts = [
                RetrievalContext(
                    document_id=document.id,
                    question=document.title,
                    answer=document.snippet,
                    similarity=document.score or 0.0,
                )
                for document in search_result.documents
            ]
            all_contexts.extend(web_contexts)

            # Only documents that carry a URL become user-visible citations.
            citations = [
                Citation(title=document.title, url=document.url)
                for document in search_result.documents
                if document.url
            ]
            source = search_result.source
        else:
            logger.warning("web_search.no_results", query=query)
            citations = []
            source = "kb-only" if kb_contexts else "direct"

        trace = state.get("gateway_trace", []) + ["web_search_complete", f"source={source}"]
        return {
            "search_result": search_result,
            "knowledge_hits": all_contexts,  # Combined KB + web contexts for generation
            "kb_hits": kb_contexts,  # Keep KB hits separate for UI
            "web_hits": web_contexts,  # Keep web hits separate for UI
            "citations": citations,
            "retrieved_from_kb": len(kb_contexts) > 0,
            "source": source,
            "gateway_trace": trace,
        }

    async def search_generation_node(state: AgentGraphState) -> AgentGraphState:
        """Generate the final answer from all gathered contexts, with fallbacks."""
        # Always use Gemini to generate comprehensive answer from all available contexts
        contexts = state.get("knowledge_hits", [])  # Already combined KB + web contexts
        search_metadata = state.get("search_result")

        logger.info("generation.starting", contexts_count=len(contexts), has_search_metadata=search_metadata is not None)

        try:
            # Always generate with Gemini using all available context
            steps, answer_text = await generate_solution_with_cot(
                state["sanitized_query"],
                contexts,
                search_metadata=search_metadata
            )

            # Ensure answer is comprehensive and not empty
            if not answer_text or not answer_text.strip():
                logger.warning("generation.empty_response", query=state["sanitized_query"])
                # Try to construct answer from steps if available
                if steps:
                    # Get the last step or combine all step content
                    # Steps may arrive as dicts or (title, content) pairs, so
                    # both shapes are handled defensively here.
                    if isinstance(steps[-1], dict):
                        answer_text = steps[-1].get("content", "") or steps[-1].get("explanation", "")
                    elif isinstance(steps[-1], (list, tuple)) and len(steps[-1]) > 1:
                        answer_text = steps[-1][1]
                    else:
                        answer_text = str(steps[-1])

                    if not answer_text:
                        # Combine all step contents
                        answer_parts = []
                        for step in steps:
                            if isinstance(step, dict):
                                content = step.get("content", "") or step.get("explanation", "")
                                if content:
                                    answer_parts.append(content)
                            elif isinstance(step, (list, tuple)) and len(step) > 1:
                                answer_parts.append(step[1])
                        answer_text = "\n\n".join(answer_parts) if answer_parts else "Please see the steps above for the solution."
                else:
                    answer_text = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
                    steps = [{"title": "Error", "content": "Unable to generate solution with available context."}]

            # Log the response length to debug
            logger.info("generation.completed", answer_length=len(answer_text), steps_count=len(steps))
        except Exception as exc:
            logger.exception("generation.failed", error=str(exc))
            answer_text = "I encountered an error while processing your question. Please try again."
            steps = [{"title": "Error", "content": str(exc)}]

        # Always return with all context information - let UI decide what to show
        return {
            "steps": steps,
            "answer": answer_text,
            # Keep all context for UI display
        }

    async def output_guardrails_node(state: AgentGraphState) -> AgentGraphState:
        # Skip entirely when output guardrails are disabled; an empty dict
        # leaves the state unchanged.
        if not settings.enforce_output_guardrails:
            return {}
        filtered_answer, urls = run_output_guardrails(state["answer"])
        citations = list(state.get("citations", []))
        for url in urls:
            # Add guardrail-extracted URLs as citations, avoiding duplicates.
            if not any(citation.url == url for citation in citations):
                citations.append(Citation(title="Source", url=url))

        current_trace = state.get("gateway_trace", [])
        trace = current_trace + ["output_guardrails_pass"]

        return {"answer": filtered_answer, "citations": citations, "gateway_trace": trace}

    # Wire the nodes into the graph.
    graph.add_node("input_guardrails", input_guardrails_node)
    graph.add_node("retrieve_kb", retrieve_kb_node)
    graph.add_node("kb_generate", kb_generation_node)
    graph.add_node("web_search", web_search_node)
    graph.add_node("search_generate", search_generation_node)
    graph.add_node("output_guardrails", output_guardrails_node)

    graph.set_entry_point("input_guardrails")
    graph.add_edge("input_guardrails", "retrieve_kb")
    # route_after_retrieval currently always selects "search" (see note above).
    graph.add_conditional_edges("retrieve_kb", route_after_retrieval, {"kb": "kb_generate", "search": "web_search"})
    graph.add_edge("kb_generate", "output_guardrails")
    graph.add_edge("web_search", "search_generate")
    graph.add_edge("search_generate", "output_guardrails")
    graph.add_edge("output_guardrails", END)

    return graph.compile()
|
| 250 |
+
|
| 251 |
+
|
backend/data/feedback_db.json
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"message_id": "assistant-1762103842320",
|
| 4 |
+
"query": "Find the derivative of the function f(x)=x^3sin(x^2) and simplify your answer.",
|
| 5 |
+
"agent_response": {
|
| 6 |
+
"answer": "```json{ \"overview\": \"To find the derivative of the given function, we apply the product rule and the chain rule.\", \"steps\": [ { \"title\": \"Step 1: Identify the Rules for Differentiation\", \"content\": \"The function $f(x) = x^3 \\sin(x^2)$ is a product of two functions, $u(x) = x^3$ and $v(x) = \\sin(x^2)$. Therefore, we will use the product rule for differentiation, which states that if $f(x) = u(x)v(x)$, then $f'(x) = u'(x)v(x) + u(x)v'(x)$.\\n\\nAdditionally, to find the derivative of $v(x) = \\sin(x^2)$, we will need to apply the chain rule, which states that if $y = g(h(x))$, then $y' = g'(h(x))h'(x)$.\", \"expression\": \"f'(x) = u'(x)v(x) + u(x)v'(x)\" }, { \"title\": \"Step 2: Differentiate Each Part\", \"content\": \"First, let's find the derivative of $u(x) = x^3$:\\n$u'(x) = \\\\frac{d}{dx}(x^3) = 3x^2$\\n\\nNext, let's find the derivative of $v(x) = \\\\sin(x^2)$ using the chain rule. Let $w = x^2$, so $v(x) = \\\\sin(w)$.\\n$\\\\frac{dv}{dw} = \\\\cos(w)$\\n$\\\\frac{dw}{dx} = \\\\frac{d}{dx}(x^2) = 2x$\\n\\nApplying the chain rule, $v'(x) = \\\\frac{dv}{dw} \\\\cdot \\\\frac{dw}{dx} = \\\\cos(x^2) \\\\cdot 2x = 2x \\\\cos(x^2)$.\", \"expression\": \"u'(x) = 3x^2 \\\\quad \\\\text{and} \\\\quad v'(x) = 2x \\\\cos(x^2)\" }, { \"title\": \"Step 3: Apply the Product Rule\", \"content\": \"Now, we substitute $u(x)$, $u'(x)$, $v(x)$, and $v'(x)$ into the product rule formula:\\n$f'(x) = u'(x)v(x) + u(x)v'(x)$\\n$f'(x) = (3x^2)(\\\\sin(x^2)) + (x^3)(2x \\\\cos(x^2))$\\n$f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)$\", \"expression\": \"f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)\" }, { \"title\": \"Step 4: Simplify the Answer\", \"content\": \"We can simplify the expression by factoring out the common term $x^2$ from both terms:\\n$f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\", \"expression\": \"f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))\" } ], \"final_answer\": \"$f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\"}```",
|
| 7 |
+
"steps": [
|
| 8 |
+
{
|
| 9 |
+
"title": "Solution",
|
| 10 |
+
"content": "```json\n{\n \"overview\": \"To find the derivative of the given function, we apply the product rule and the chain rule.\",\n \"steps\": [\n {\n \"title\": \"Step 1: Identify the Rules for Differentiation\",\n \"content\": \"The function $f(x) = x^3 \\sin(x^2)$ is a product of two functions, $u(x) = x^3$ and $v(x) = \\sin(x^2)$. Therefore, we will use the product rule for differentiation, which states that if $f(x) = u(x)v(x)$, then $f'(x) = u'(x)v(x) + u(x)v'(x)$.\\n\\nAdditionally, to find the derivative of $v(x) = \\sin(x^2)$, we will need to apply the chain rule, which states that if $y = g(h(x))$, then $y' = g'(h(x))h'(x)$.\",\n \"expression\": \"f'(x) = u'(x)v(x) + u(x)v'(x)\"\n },\n {\n \"title\": \"Step 2: Differentiate Each Part\",\n \"content\": \"First, let's find the derivative of $u(x) = x^3$:\\n$u'(x) = \\\\frac{d}{dx}(x^3) = 3x^2$\\n\\nNext, let's find the derivative of $v(x) = \\\\sin(x^2)$ using the chain rule. Let $w = x^2$, so $v(x) = \\\\sin(w)$.\\n$\\\\frac{dv}{dw} = \\\\cos(w)$\\n$\\\\frac{dw}{dx} = \\\\frac{d}{dx}(x^2) = 2x$\\n\\nApplying the chain rule, $v'(x) = \\\\frac{dv}{dw} \\\\cdot \\\\frac{dw}{dx} = \\\\cos(x^2) \\\\cdot 2x = 2x \\\\cos(x^2)$.\",\n \"expression\": \"u'(x) = 3x^2 \\\\quad \\\\text{and} \\\\quad v'(x) = 2x \\\\cos(x^2)\"\n },\n {\n \"title\": \"Step 3: Apply the Product Rule\",\n \"content\": \"Now, we substitute $u(x)$, $u'(x)$, $v(x)$, and $v'(x)$ into the product rule formula:\\n$f'(x) = u'(x)v(x) + u(x)v'(x)$\\n$f'(x) = (3x^2)(\\\\sin(x^2)) + (x^3)(2x \\\\cos(x^2))$\\n$f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)$\",\n \"expression\": \"f'(x) = 3x^2 \\\\sin(x^2) + 2x^4 \\\\cos(x^2)\"\n },\n {\n \"title\": \"Step 4: Simplify the Answer\",\n \"content\": \"We can simplify the expression by factoring out the common term $x^2$ from both terms:\\n$f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\",\n \"expression\": \"f'(x) = x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))\"\n }\n ],\n \"final_answer\": \"$f'(x) = 
x^2 (3 \\\\sin(x^2) + 2x^2 \\\\cos(x^2))$\"\n}\n```",
|
| 11 |
+
"expression": null
|
| 12 |
+
}
|
| 13 |
+
],
|
| 14 |
+
"retrieved_from_kb": false,
|
| 15 |
+
"knowledge_hits": [
|
| 16 |
+
{
|
| 17 |
+
"document_id": "0",
|
| 18 |
+
"question": "Find the derivative of f(x)=x^3 sin(x) - MyTutor",
|
| 19 |
+
"answer": "The derivative of f(x)=x^3 sin(x) is x^2(3 sin(x) + x cos(x)), using the product rule.",
|
| 20 |
+
"similarity": 0.79119295
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"document_id": "1",
|
| 24 |
+
"question": "Solved The derivative of f(x) can be expressed in the form",
|
| 25 |
+
"answer": "The derivative of f(x) can be expressed in the form f'(x)=g(x)(3sinx+2)2, where g(x) is some function of x. Find g(x).Answer: g(x)=(b",
|
| 26 |
+
"similarity": 0.78512776
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"document_id": "2",
|
| 30 |
+
"question": "What is the first derivative of the function f(x) =x^3sinx?",
|
| 31 |
+
"answer": "The first derivative is: u'v+uv'= 3x^2 sin x + x^3 cos x, or, x^2 (3 sin x + x cos x). However, if you mean x^(3 sin x),",
|
| 32 |
+
"similarity": 0.7173491
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"document_id": "3",
|
| 36 |
+
"question": "The derivative of the function f is f'(x)=x-3sin x^2. Which ...",
|
| 37 |
+
"answer": "The derivative of the function f is f'(x)=x-3sin x^2. Which interval contains the greatest number of relative minimums of f? A. (-2,2) B. (-3,1) C. (-4",
|
| 38 |
+
"similarity": 0.6929958
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"document_id": "4",
|
| 42 |
+
"question": "Larson Calculus 2.3 #44: Differentiate f(x) = sin(x)/x^3 ...",
|
| 43 |
+
"answer": "Let me show you how I find the derivative with the quotient rule.",
|
| 44 |
+
"similarity": 0.63452125
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"citations": [
|
| 48 |
+
{
|
| 49 |
+
"title": "Find the derivative of f(x)=x^3 sin(x) - MyTutor",
|
| 50 |
+
"url": "https://www.mytutor.co.uk/answers/17399/A-Level/Maths/Find-the-derivative-of-f-x-x-3-sin-x/"
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"title": "Solved The derivative of f(x) can be expressed in the form",
|
| 54 |
+
"url": "https://www.chegg.com/homework-help/questions-and-answers/derivative-f-x-expressed-form-f-x-g-x-3sinx-2-2-g-x-function-x-find-g-x--answer-g-x-b-let--q147428016"
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"title": "What is the first derivative of the function f(x) =x^3sinx?",
|
| 58 |
+
"url": "https://www.quora.com/What-is-the-first-derivative-of-the-function-f-x-x-3sinx"
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"title": "The derivative of the function f is f'(x)=x-3sin x^2. Which ...",
|
| 62 |
+
"url": "https://www.gauthmath.com/solution/1813181098343573/116-Multiple-Choice-The-derivative-of-the-function-is-f-x-x-3sin-x2-Which-interv"
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"title": "Larson Calculus 2.3 #44: Differentiate f(x) = sin(x)/x^3 ...",
|
| 66 |
+
"url": "https://www.youtube.com/watch?v=9qbaO_Jihzg"
|
| 67 |
+
}
|
| 68 |
+
],
|
| 69 |
+
"source": "tavily-sdk",
|
| 70 |
+
"feedback_required": true,
|
| 71 |
+
"gateway_trace": [
|
| 72 |
+
"input_guardrails_pass",
|
| 73 |
+
"router.inspect",
|
| 74 |
+
"router->search",
|
| 75 |
+
"search.source=tavily-sdk",
|
| 76 |
+
"output_guardrails_pass"
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
"feedback": {
|
| 80 |
+
"thumbs_up": true,
|
| 81 |
+
"primary_issue": null,
|
| 82 |
+
"has_better_solution": false,
|
| 83 |
+
"solution_type": null,
|
| 84 |
+
"better_solution_text": null,
|
| 85 |
+
"better_solution_pdf_base64": null,
|
| 86 |
+
"better_solution_image_base64": null
|
| 87 |
+
}
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"message_id": "assistant-1762145260594",
|
| 91 |
+
"query": "Find the derivative of the function: f(x)=3x^4−5x^2+2squareroot(x)−7/x^3",
|
| 92 |
+
"agent_response": {
|
| 93 |
+
"answer": "Thus, the function can be rewritten as:",
|
| 94 |
+
"steps": [
|
| 95 |
+
{
|
| 96 |
+
"title": "Overview",
|
| 97 |
+
"content": "Let's find the derivative of the function $f(x)$ step by step.",
|
| 98 |
+
"expression": null
|
| 99 |
+
}
|
| 100 |
+
],
|
| 101 |
+
"retrieved_from_kb": false,
|
| 102 |
+
"knowledge_hits": [
|
| 103 |
+
{
|
| 104 |
+
"document_id": "0",
|
| 105 |
+
"question": "Derivative Calculator - Mathway",
|
| 106 |
+
"answer": "Upgrade Calculators Help Sign In Sign Up # Derivative Calculator **Step 1:** Enter the function you want to find the derivative of in the editor. > The Derivative Calculator supports solving first, second...., fourth derivatives, as well as implicit differentiation and finding the zeros/roots. You can also get a better visual and understanding of the function by using our graphing tool. > Chain Rule: ddx[f(g(x))]=f'(g(x))g'(x) **Step 2:** Click the blue arrow to submit. Choose **\"Find the Derivative\"** from the topic selector and click to see the result! (2x−9)10 using Amazon.Auth.AccessControlPolicy; Mathway requires javascript and a modern browser. Please ensure that your password is at least 8 characters and contains each of the following: * a special character: @$#!%\\*?&",
|
| 107 |
+
"similarity": 0.98579
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"document_id": "1",
|
| 111 |
+
"question": "Derivative Calculator - Math Portal",
|
| 112 |
+
"answer": "* Calculators Math Calculators, Lessons and Formulas * Calculators * Calculators * Derivative Calculator # Derivative calculator * Fraction Calculator * Percentage Calculator Polynomials * Factoring Polynomials * Simplify Polynomials * Multiplication / Division Solving Equations * Polynomial Equations * Square Calculator * Rectangle Calculator * Circle Calculator * Hexagon Calculator * Rhombus Calculator * Trapezoid Calculator * Distance calculator * Midpoint Calculator * Triangle Calculator * Division * Determinant Calculator Calculus Calculators * Limit Calculator * Derivative Calculator * Integral Calculator Equations * Probability Calculator * T-Test Calculator Financial Calculators * Amortization Calculator * Annuity Calculator Other Calculators Related calculators Limit calculator Integral calculator Derivative calculator – Widget Code **Widget preview:** Derivative calculator Online Math Calculators",
|
| 113 |
+
"similarity": 0.98278
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"document_id": "2",
|
| 117 |
+
"question": "Derivative Calculator - Symbolab",
|
| 118 |
+
"answer": "*Example* : Find the derivative of $f\\left(x\\right)=\\frac{1}{x}$. scientific calculator inverse calculator simplify calculator distance calculator fractions calculator interval notation calculator cross product calculator probability calculator derivative calculator series calculator ratios calculator statistics calculator integral calculator inverse laplace transform calculator rounding calculator gcf calculator algebra calculator tangent line calculator trigonometry calculator log calculator standard deviation calculator linear equation calculator antiderivative calculator laplace transform calculator quadratic equation calculator domain calculator decimals calculator limit calculator equation solver definite integral calculator matrix inverse calculator matrix calculator system of equations calculator calculus calculator slope calculator long division calculator factors calculator polynomial calculator square root calculator implicit differentiation calculator word problem solver differential equation calculator average calculator synthetic division calculator",
|
| 119 |
+
"similarity": 0.97997
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"document_id": "3",
|
| 123 |
+
"question": "Calculus - How to find the derivative of a function using the power rule",
|
| 124 |
+
"answer": "Calculus - How to find the derivative of a function using the power rule\nMySecretMathTutor\n241000 subscribers\n6381 likes\n582542 views\n5 Sep 2013\nThis video shows how to find the derivative of a function using the power rule. Remember that this rule only works on functions of the form x^n where n is not equal to zero. For more videos please visit http://www.mysecretmathtutor.com\n189 comments\n",
|
| 125 |
+
"similarity": 0.97435
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"document_id": "4",
|
| 129 |
+
"question": "Finding the Derivative of a Square Root Function Using Definition of ...",
|
| 130 |
+
"answer": "Finding the Derivative of a Square Root Function Using Definition of a Derivative\nPatrick J\n1400000 subscribers\n1825 likes\n268324 views\n11 Dec 2011\nIn this video we compute the derivative of a square root function, namely f(x) = sqrt (2x + 1) by using the definition of a derivative. We could use the power rule along with the chain rule as well of course, but this what one is often asked in the beginning of a calculus course.\nNote that to help simplify the expression for the derivative, we will make use of the conjugate, a common technique to help simplify these types of expressions with square roots.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n#math, #algebra, #khanacademy, #khan, #geometry, #education, #montessori, #calculus, #homeschooling, #functions, #notation, #organicchemistry, #organicchemistrytutor, #calculus, #derivative, #powerrule, #limits, #limit, #graph, #graphing, #conjugate, #infinity, #patrickjmt, #rational, #shortcut, #trick, #commondenominator, #fraction, #university, #admissions, #test #continuity #onesided #infinite #absolutevalue #tangents #tangentsandnormals #tangentline #definition #hindi #india\n126 comments\n",
|
| 131 |
+
"similarity": 0.96945
|
| 132 |
+
}
|
| 133 |
+
],
|
| 134 |
+
"citations": [
|
| 135 |
+
{
|
| 136 |
+
"title": "Derivative Calculator - Mathway",
|
| 137 |
+
"url": "https://www.mathway.com/Calculator/derivative-calculator"
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"title": "Derivative Calculator - Math Portal",
|
| 141 |
+
"url": "https://www.mathportal.org/calculators/calculus/derivative-calculator.php"
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"title": "Derivative Calculator - Symbolab",
|
| 145 |
+
"url": "https://www.symbolab.com/solver/derivative-calculator"
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"title": "Calculus - How to find the derivative of a function using the power rule",
|
| 149 |
+
"url": "https://www.youtube.com/watch?v=pBc4Udqw330"
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"title": "Finding the Derivative of a Square Root Function Using Definition of ...",
|
| 153 |
+
"url": "https://www.youtube.com/watch?v=3bmw9UQxbRI"
|
| 154 |
+
}
|
| 155 |
+
],
|
| 156 |
+
"source": "tavily-sdk",
|
| 157 |
+
"feedback_required": true,
|
| 158 |
+
"gateway_trace": [
|
| 159 |
+
"input_guardrails_pass",
|
| 160 |
+
"router.inspect",
|
| 161 |
+
"router->search",
|
| 162 |
+
"search.source=tavily-sdk",
|
| 163 |
+
"output_guardrails_pass"
|
| 164 |
+
]
|
| 165 |
+
},
|
| 166 |
+
"feedback": {
|
| 167 |
+
"thumbs_up": false,
|
| 168 |
+
"primary_issue": "unclear",
|
| 169 |
+
"has_better_solution": false,
|
| 170 |
+
"solution_type": null,
|
| 171 |
+
"better_solution_text": null,
|
| 172 |
+
"better_solution_pdf_base64": null,
|
| 173 |
+
"better_solution_image_base64": null
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
]
|
data/knowledge_base.jsonl
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"id": "gsm8k-001", "question": "Solve for x: 2x + 5 = 17.", "answer": "x = 6", "step_by_step": ["Subtract 5 from both sides: 2x = 12", "Divide both sides by 2: x = 6"]}
|
| 2 |
+
{"id": "gsm8k-002", "question": "A rectangle has a length of 12 cm and width 5 cm. What is its perimeter?", "answer": "Perimeter = 34 cm", "step_by_step": ["Perimeter of rectangle = 2*(length + width)", "Compute length + width: 12 + 5 = 17", "Multiply by 2: 2*17 = 34 cm"]}
|
| 3 |
+
{"id": "gsm8k-003", "question": "If 3/4 of a number is 18, what is the number?", "answer": "The number is 24", "step_by_step": ["Let the number be n", "3/4 * n = 18", "Multiply both sides by 4/3: n = 18 * 4/3", "Compute: 18 * 4 = 72", "Divide by 3: 72 / 3 = 24"]}
|
| 4 |
+
{"id": "gsm8k-004", "question": "Simplify: (x^2 - 9)/(x - 3)", "answer": "x + 3", "step_by_step": ["Factor numerator: x^2 - 9 = (x - 3)(x + 3)", "Cancel common factor (x - 3)", "Result: x + 3, for x ≠ 3"]}
|
| 5 |
+
{"id": "gsm8k-005", "question": "Evaluate the derivative d/dx of f(x) = 3x^3 - 4x^2 + x - 7.", "answer": "f'(x) = 9x^2 - 8x + 1", "step_by_step": ["Apply power rule to each term", "Derivative of 3x^3 is 9x^2", "Derivative of -4x^2 is -8x", "Derivative of x is 1", "Derivative of constant -7 is 0", "Combine terms: 9x^2 - 8x + 1"]}
|
| 6 |
+
|
mcp_servers/tavily_server.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal MCP-compatible stdio server wrapping Tavily search."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from tavily import TavilyClient
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _respond(response: dict[str, Any]) -> None:
|
| 13 |
+
sys.stdout.write(json.dumps(response) + "\n")
|
| 14 |
+
sys.stdout.flush()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main() -> None:
    """Serve JSON-RPC-style requests over stdio, one JSON object per line.

    Supported methods:
      * ``initialize``    -- handshake; returns server info and capabilities.
      * ``shutdown``      -- acknowledges the request and stops the loop.
      * ``tavily_search`` -- proxies ``params`` to the Tavily search API.
    Any other method gets a -32601 "Unknown method" error response.

    The Tavily client reads its API key from the environment (TAVILY_API_KEY).
    """
    client = TavilyClient()

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            request = json.loads(line)
        except json.JSONDecodeError:
            # Fix: previously unparseable lines were silently dropped, leaving
            # the peer waiting forever. JSON-RPC 2.0 mandates a -32700 Parse
            # error with a null id when the request cannot be decoded.
            _respond({
                "jsonrpc": "2.0",
                "id": None,
                "error": {"code": -32700, "message": "Parse error"},
            })
            continue

        method = request.get("method")
        req_id = request.get("id")

        if method == "initialize":
            _respond({
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "serverInfo": {"name": "tavily-mcp", "version": "0.1.0"},
                    "capabilities": {"tools": True},
                },
            })
            continue

        if method == "shutdown":
            _respond({"jsonrpc": "2.0", "id": req_id, "result": None})
            break

        if method == "tavily_search":
            # Missing params default to an empty query / 5 results / no images,
            # mirroring the Tavily SDK defaults.
            params = request.get("params", {})
            query = params.get("query", "")
            max_results = params.get("max_results", 5)
            include_images = params.get("include_images", False)
            try:
                result = client.search(query=query, max_results=max_results, include_images=include_images)
                _respond({"jsonrpc": "2.0", "id": req_id, "result": result})
            except Exception as exc:  # pragma: no cover - network/SDK failures
                _respond({"jsonrpc": "2.0", "id": req_id, "error": {"code": -32000, "message": str(exc)}})
            continue

        _respond({"jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": "Unknown method"}})
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Script entry point: run the stdio server loop until shutdown/EOF.
if __name__ == "__main__":
    main()
|
| 64 |
+
|
| 65 |
+
|
poetry.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pyproject.toml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "math-agent-backend"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Agentic RAG math tutoring backend with guardrails, MCP search, and feedback loop"
|
| 5 |
+
authors = ["AI Assistant <assistant@example.com>"]
|
| 6 |
+
packages = [{ include = "app" }]
|
| 7 |
+
|
| 8 |
+
[tool.poetry.dependencies]
|
| 9 |
+
python = "^3.13"
|
| 10 |
+
fastapi = "*"
|
| 11 |
+
uvicorn = { extras = ["standard"], version = "*" }
|
| 12 |
+
pydantic = "*"
|
| 13 |
+
pydantic-settings = "*"
|
| 14 |
+
numpy = "*"
|
| 15 |
+
pandas = "*"
|
| 16 |
+
weaviate-client = "*"
|
| 17 |
+
sentence-transformers = "*"
|
| 18 |
+
scikit-learn = "*"
|
| 19 |
+
httpx = "*"
|
| 20 |
+
aiohttp = "*"
|
| 21 |
+
tavily-python = "*"  # official Tavily SDK; the PyPI distribution is named tavily-python, not tavily
|
| 22 |
+
python-dotenv = "*"
|
| 23 |
+
google-generativeai = "*"
|
| 24 |
+
groq = "*"
|
| 25 |
+
# dspy-ai removed from core dependencies to avoid complex transitive
|
| 26 |
+
# dependency conflicts (fastapi/litellm). If you need DSPy features,
|
| 27 |
+
# add it to an optional extras group after verifying compatibility.
|
| 28 |
+
# dspy-ai = { version = "^2.6.0", python = "*" }
|
| 29 |
+
structlog = "*"
|
| 30 |
+
tenacity = "*"
|
| 31 |
+
langgraph = "*"
|
| 32 |
+
|
| 33 |
+
[tool.poetry.group.dev.dependencies]
|
| 34 |
+
pytest = "*"
|
| 35 |
+
pytest-asyncio = "*"
|
| 36 |
+
pytest-cov = "*"
|
| 37 |
+
ruff = "*"
|
| 38 |
+
black = "*"
|
| 39 |
+
mypy = "*"
|
| 40 |
+
|
| 41 |
+
[build-system]
|
| 42 |
+
requires = ["poetry-core>=1.0.0"]
|
| 43 |
+
build-backend = "poetry.core.masonry.api"
|
| 44 |
+
|
render.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
- type: web
|
| 3 |
+
name: math-agent-backend
|
| 4 |
+
env: python
|
| 5 |
+
region: oregon
|
| 6 |
+
plan: free
|
| 7 |
+
buildCommand: |
|
| 8 |
+
pip install poetry
|
| 9 |
+
poetry install --no-interaction --no-ansi --no-root
|
| 10 |
+
startCommand: |
|
| 11 |
+
poetry run uvicorn app.main:app --host 0.0.0.0 --port $PORT
|
| 12 |
+
autoDeploy: true
|
| 13 |
+
envVars:
|
| 14 |
+
- key: WEAVIATE_URL
|
| 15 |
+
sync: false
|
| 16 |
+
- key: WEAVIATE_API_KEY
|
| 17 |
+
sync: false
|
| 18 |
+
- key: TAVILY_API_KEY
|
| 19 |
+
sync: false
|
| 20 |
+
- key: MCP_TAVILY_URL
|
| 21 |
+
sync: false
|
| 22 |
+
- key: GEMINI_API_KEY
|
| 23 |
+
sync: false
|
| 24 |
+
- key: SIMILARITY_THRESHOLD
|
| 25 |
+
value: "0.80"
|
| 26 |
+
- key: ENFORCE_INPUT_GUARDRAILS
|
| 27 |
+
value: "true"
|
| 28 |
+
- key: ENFORCE_OUTPUT_GUARDRAILS
|
| 29 |
+
value: "true"
|
| 30 |
+
|
scripts/build_kb.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Script to build FAISS vector store from knowledge base dataset."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import faiss
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from sentence_transformers import SentenceTransformer
|
| 13 |
+
|
| 14 |
+
from app.config import settings
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_dataset(path: Path) -> list[dict]:
|
| 18 |
+
data: list[dict] = []
|
| 19 |
+
for line in path.read_text(encoding="utf-8").splitlines():
|
| 20 |
+
if line.strip():
|
| 21 |
+
data.append(json.loads(line))
|
| 22 |
+
return data
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_index(records: list[dict]) -> tuple[faiss.Index, np.ndarray]:
    """Embed question+answer text and build a cosine-similarity FAISS index.

    Embeddings are L2-normalized in place, so inner-product search on the
    returned IndexFlatIP behaves as cosine similarity; the returned array is
    the normalized embedding matrix.
    """
    model = SentenceTransformer(settings.embedding_model_name)
    texts = ["\n".join([rec["question"], rec.get("answer", "")]) for rec in records]
    vectors = np.array(model.encode(texts, show_progress_bar=True)).astype("float32")
    faiss.normalize_L2(vectors)
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index, vectors
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(dataset_path: Path, output_index: Path, output_metadata: Path) -> None:
    """Build the FAISS index plus metadata table and persist both to disk."""
    records = load_dataset(dataset_path)
    index, _ = build_index(records)

    # Ensure both output directories exist before writing.
    for target in (output_index, output_metadata):
        target.parent.mkdir(parents=True, exist_ok=True)

    faiss.write_index(index, str(output_index))
    pd.DataFrame(records).to_parquet(output_metadata, index=False)

    print(f"Vector store written to {output_index}")
    print(f"Metadata written to {output_metadata}")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
    # CLI entry point; index/metadata paths default to the configured
    # vector-store locations from app.config.settings.
    parser = argparse.ArgumentParser(description="Build FAISS vector store for math agent")
    # NOTE(review): default points at backend/data/knowledge_base.jsonl, but this
    # commit adds data/knowledge_base.jsonl -- confirm which path is intended.
    parser.add_argument("--dataset", type=Path, default=Path("backend/data/knowledge_base.jsonl"))
    parser.add_argument("--index", type=Path, default=Path(settings.vector_store_path))
    parser.add_argument("--metadata", type=Path, default=Path(settings.vector_store_metadata_path))
    args = parser.parse_args()

    main(args.dataset, args.index, args.metadata)
|
| 61 |
+
|
| 62 |
+
|
scripts/build_weaviate_kb.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Script to build Weaviate vector store from knowledge base dataset."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import weaviate
|
| 12 |
+
from sentence_transformers import SentenceTransformer
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from app.config import settings
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_dataset(path: Path) -> List[Dict[str, Any]]:
    """Parse a JSONL file into a list of record dicts, ignoring blank lines."""
    records: List[Dict[str, Any]] = []
    with path.open(encoding="utf-8") as handle:
        for raw in handle:
            if raw.strip():
                records.append(json.loads(raw))
    return records
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_weaviate_client() -> weaviate.Client:
    """Initialize Weaviate client with authentication.

    Connects to the cluster at ``settings.weaviate_url`` using the API key
    from ``settings.weaviate_api_key``.

    NOTE(review): this is the weaviate-client v3 API (``weaviate.Client``);
    pyproject pins ``weaviate-client = "*"`` which may resolve to v4, where
    this constructor is presumably unavailable -- confirm the installed
    version or pin a compatible major version.
    """
    auth_config = weaviate.auth.AuthApiKey(api_key=settings.weaviate_api_key)
    client = weaviate.Client(
        url=settings.weaviate_url,
        auth_client_secret=auth_config,
    )
    return client
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def setup_schema(client: weaviate.Client) -> None:
    """Create the knowledge-base class in Weaviate if it doesn't already exist.

    The class is configured with ``vectorizer: none`` because embeddings are
    computed locally and supplied explicitly at import time.

    NOTE(review): ``client.schema`` is the weaviate-client v3 schema API --
    confirm the installed client version supports it.
    """
    if not client.schema.exists(settings.weaviate_class_name):
        class_obj = {
            "class": settings.weaviate_class_name,
            "vectorizer": "none",  # We'll provide our own vectors
            "properties": [
                {"name": "question", "dataType": ["text"]},
                {"name": "answer", "dataType": ["text"]},
                {"name": "source", "dataType": ["text"]},
            ]
        }
        client.schema.create_class(class_obj)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_index(client: weaviate.Client, records: List[Dict[str, Any]], batch_size: int = 100) -> None:
    """Build the vector index by importing records in batches.

    Each record is embedded from its question + answer text with the
    configured SentenceTransformer model, L2-normalized, and uploaded with an
    explicit vector (the class is created with vectorizer "none").

    Args:
        client: connected Weaviate client.
        records: list of dicts with at least a "question" key; "answer" and
            "source" default to empty strings when absent.
        batch_size: objects per upload batch.

    NOTE(review): ``tqdm`` is imported at module level but is not declared in
    pyproject.toml -- confirm it is available or add it as a dependency.
    NOTE(review): records are encoded one at a time; passing the full list to
    ``encoder.encode`` would batch on the model side and is much faster.
    """
    encoder = SentenceTransformer(settings.embedding_model_name)

    with client.batch as batch:
        batch.batch_size = batch_size

        for record in tqdm(records, desc="Importing records"):
            # Generate embedding from question and answer
            text = record["question"] + "\n" + record.get("answer", "")
            embedding = encoder.encode(text)

            # Normalize the embedding so dot-product search behaves as cosine
            # similarity (assumes a non-zero vector -- true for any real text).
            embedding = embedding / np.linalg.norm(embedding)

            # Prepare properties
            properties = {
                "question": record["question"],
                "answer": record.get("answer", ""),
                "source": record.get("source", "")
            }

            # Import the object with its vector
            batch.add_data_object(
                data_object=properties,
                class_name=settings.weaviate_class_name,
                vector=embedding.tolist()
            )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def main(dataset_path: Path) -> None:
    """Drive the full import: load records, connect, ensure schema, index."""
    print(f"Loading dataset from {dataset_path}")
    records = load_dataset(dataset_path)

    print("Initializing Weaviate client")
    client = get_weaviate_client()

    print("Setting up schema")
    setup_schema(client)

    print("Building index")
    build_index(client, records)

    record_count = len(records)
    print(f"Successfully imported {record_count} records to Weaviate")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
    # CLI entry point; only the dataset path is configurable.
    parser = argparse.ArgumentParser(description="Build Weaviate vector store for math agent")
    # NOTE(review): default points at backend/data/knowledge_base.jsonl, but this
    # commit adds data/knowledge_base.jsonl -- confirm which path is intended.
    parser.add_argument("--dataset", type=Path, default=Path("backend/data/knowledge_base.jsonl"))
    args = parser.parse_args()

    main(args.dataset)
|