diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..9d055daf4c8e532c5e823c2a913dedb0cc751c9f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,47 @@ +# Same runtime as API; runs health endpoint + Celery worker (see worker_health.py) +FROM python:3.11-slim-bookworm + +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_ROOT_USER_ACTION=ignore \ + NO_ALBUMENTATIONS_UPDATE=1 \ + OMP_NUM_THREADS=1 \ + MKL_NUM_THREADS=1 \ + OPENBLAS_NUM_THREADS=1 + +WORKDIR /app +ENV PYTHONPATH=/app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ffmpeg \ + pkg-config \ + cmake \ + libcairo2 \ + libcairo2-dev \ + libpango-1.0-0 \ + libpango1.0-dev \ + libpangocairo-1.0-0 \ + libgdk-pixbuf-2.0-0 \ + libffi-dev \ + python3-dev \ + texlive-latex-base \ + texlive-fonts-recommended \ + texlive-latex-extra \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.worker-render.txt . +RUN pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.worker-render.txt + +COPY . . + +RUN python scripts/prewarm_render_worker.py + +ENV PORT=7860 \ + CELERY_WORKER_QUEUES=render +EXPOSE 7860 + +ENTRYPOINT [] +CMD ["sh", "-c", "exec python3 -u worker_health.py"] diff --git a/Dockerfile.worker.ocr b/Dockerfile.worker.ocr new file mode 100644 index 0000000000000000000000000000000000000000..d1d3e74edbcffa40b9c70a09d1e324b397b5e53a --- /dev/null +++ b/Dockerfile.worker.ocr @@ -0,0 +1,42 @@ +# Celery worker: OCR queue only (no Manim / LaTeX / Cairo stack). +FROM python:3.11-slim-bookworm + +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_ROOT_USER_ACTION=ignore \ + NO_ALBUMENTATIONS_UPDATE=1 \ + OMP_NUM_THREADS=1 \ + MKL_NUM_THREADS=1 \ + OPENBLAS_NUM_THREADS=1 + +WORKDIR /app +ENV PYTHONPATH=/app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + pkg-config \ + python3-dev \ + libglib2.0-0 \ + libgomp1 \ + libgl1 \ + libsm6 \ + libxext6 \ + libxrender1 \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.worker-ocr.txt . +RUN pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.worker-ocr.txt + +COPY . . + +RUN python scripts/prewarm_ocr_worker.py + +ENV PORT=7860 \ + CELERY_WORKER_QUEUES=ocr +EXPOSE 7860 + +ENTRYPOINT [] +CMD ["sh", "-c", "exec python3 -u worker_health.py"] diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c708b2b775ba7d59b1e306d597d763b138d0830e --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +--- +title: Math Solver Render Worker +emoji: 👷 +colorFrom: blue +colorTo: indigo +sdk: docker +app_port: 7860 +--- + +# Math Solver — Render worker (Manim) + +This Space runs **Celery** via `worker_health.py` and consumes **only** queue **`render`** (`render_geometry_video`). Image sets `CELERY_WORKER_QUEUES=render` by default (`Dockerfile.worker`). + +**Solve** (orchestrator, agents, OCR-in-request when `OCR_USE_CELERY` is off) runs on the **API** Space, not on this worker. + +## OCR offload (separate Space) + +Queue **`ocr`** is handled by a **dedicated OCR worker** (`Dockerfile.worker.ocr`, `README_HF_WORKER_OCR.md`, workflow `deploy-worker-ocr.yml`). On the API, set `OCR_USE_CELERY=true` and deploy an OCR Space that listens on `ocr`. + +## Secrets + +Same broker as the API: `REDIS_URL` / `CELERY_BROKER_URL`, Supabase, OpenRouter (renderer may use LLM paths), etc. + +**GitHub Actions:** repository secrets `HF_TOKEN` and `HF_WORKER_REPO` (`owner/space-name`) for this workflow (`deploy-worker.yml`). diff --git a/README_HF_WORKER_OCR.md b/README_HF_WORKER_OCR.md new file mode 100644 index 0000000000000000000000000000000000000000..5fae3ec6ad5fc8298f2c39417830c41bf8d3c710 --- /dev/null +++ b/README_HF_WORKER_OCR.md @@ -0,0 +1,27 @@ +--- +title: Math Solver OCR Worker +emoji: 👁️ +colorFrom: gray +colorTo: blue +sdk: docker +app_port: 7860 +--- + +# Math Solver — OCR-only worker + +This Space runs **Celery** (`worker_health.py`) consuming **only** the `ocr` queue. + +Set environment: + +- `CELERY_WORKER_QUEUES=ocr` (default in `Dockerfile.worker.ocr`) +- Same `REDIS_URL` / `CELERY_BROKER_URL` / `CELERY_RESULT_BACKEND` as the API + +This Space runs **raw OCR only** (YOLO, PaddleOCR, Pix2Tex). **OpenRouter / LLM tinh chỉnh** không chạy ở đây; API Space gọi `refine_with_llm` sau khi nhận kết quả từ queue `ocr`. + +On the **API** Space, set `OCR_USE_CELERY=true` so `run_ocr_from_url` tasks are sent to this worker instead of running Paddle/Pix2Tex on the API process. + +Optional: `OCR_CELERY_TIMEOUT_SEC` (default `180`). + +**Manim / video** uses a different Celery queue (`render`) and Space — see `README_HF_WORKER.md` and workflow `deploy-worker.yml`. + +GitHub Actions: repository secrets `HF_TOKEN` and `HF_OCR_WORKER_REPO` (`owner/space-name`) enable workflow `deploy-worker-ocr.yml`. diff --git a/agents/geometry_agent.py b/agents/geometry_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..79a4b022fddbb74251cf9ec2588cfc10a35bfce8 --- /dev/null +++ b/agents/geometry_agent.py @@ -0,0 +1,120 @@ +import os +import json +import logging +from openai import AsyncOpenAI +from typing import Dict, Any +from dotenv import load_dotenv + +load_dotenv() +logger = logging.getLogger(__name__) + +from app.url_utils import openai_compatible_api_key, sanitize_env +from app.llm_client import get_llm_client + + +class GeometryAgent: + def __init__(self): + self.llm = get_llm_client() + + async def generate_dsl(self, semantic_data: Dict[str, Any], previous_dsl: str = None) -> str: + logger.info("==[GeometryAgent] Generating DSL from semantic data==") + if previous_dsl: + logger.info(f"[GeometryAgent] Using previous DSL context (len={len(previous_dsl)})") + + system_prompt = """ +You are a Geometry DSL Generator. Convert semantic geometry data into a precise Geometry DSL program. + +=== MULTI-TURN CONTEXT === +If a PREVIOUS DSL is provided, your job is to UPDATE or EXTEND it. +1. DO NOT remove existing points unless the user explicitly asks to "redefine" or "move" them. +2. Ensure new segments/points connect correctly to existing ones. +3. Your output should be the ENTIRE updated DSL, not just the changes. + +=== DSL COMMANDS === +POINT(A) — declare a point +POINT(A, x, y, z) — declare a point with explicit coordinates +LENGTH(AB, 5) — distance between A and B is 5 (2D/3D) +ANGLE(A, 90) — interior angle at vertex A is 90° (2D/3D) +PARALLEL(AB, CD) — segment AB is parallel to CD (2D/3D) +PERPENDICULAR(AB, CD) — segment AB is perpendicular to CD (2D/3D) +MIDPOINT(M, AB) — M is the midpoint of segment AB +SECTION(E, A, C, k) — E satisfies vector AE = k * vector AC (k is decimal) +LINE(A, B) — infinite line passing through A and B +RAY(A, B) — ray starting at A and passing through B +CIRCLE(O, 5) — circle with center O and radius 5 (2D) +SPHERE(O, 5) — sphere with center O and radius 5 (3D) +SEGMENT(M, N) — auxiliary segment MN to be drawn +POLYGON_ORDER(A, B, C, D) — the order in which vertices form the polygon boundary +TRIANGLE(ABC) — equilateral/arbitrary triangle +PYRAMID(S_ABCD) — pyramid with apex S and base ABCD +PRISM(ABC_DEF) — triangular prism + +=== RULES === +1. 3D Coordinates: Use POINT(A, x, y, z) if specific coordinates are given in the problem. +2. Space Geometry: For pyramids/prisms, use the specialized commands. +3. Primary Vertices: Always declare the main vertices of the shape (e.g., A, B, C, D) using POINT(X). +4. POLYGON_ORDER: Always emit POLYGON_ORDER(...) for the main shape using ONLY these primary vertices. +5. All Points: EVERY point mentioned (A, B, C, H, M, etc.) MUST be declared with POINT(Name) first. +6. Altitudes/Perpendiculars: For an altitude AH to BC, use POINT(H) + PERPENDICULAR(AH, BC). +7. Format: Output ONLY DSL lines — NO explanation, NO markdown, NO code blocks. + +=== SHAPE EXAMPLES === + +--- Case: Square Pyramid S.ABCD with side 10, height 15 --- +PYRAMID(S_ABCD) +POINT(A, 0, 0, 0) +POINT(B, 10, 0, 0) +POINT(C, 10, 10, 0) +POINT(D, 0, 10, 0) +POINT(S) +POINT(O) +SECTION(O, A, C, 0.5) +LENGTH(SO, 15) +PERPENDICULAR(SO, AC) +PERPENDICULAR(SO, AB) +POLYGON_ORDER(A, B, C, D) + +--- Case: Right Triangle ABC at A, AB=3, AC=4, altitude AH --- +POLYGON_ORDER(A, B, C) +POINT(A) +POINT(B) +POINT(C) +POINT(H) +LENGTH(AB, 3) +LENGTH(AC, 4) +ANGLE(A, 90) +PERPENDICULAR(AH, BC) +SEGMENT(A, H) + +--- Case: Rectangle ABCD with AB=5, AD=10 --- +POLYGON_ORDER(A, B, C, D) +POINT(A) +POINT(B) +POINT(C) +POINT(D) +LENGTH(AB, 5) +LENGTH(AD, 10) +PERPENDICULAR(AB, AD) +PARALLEL(AB, CD) +PARALLEL(AD, BC) + +[Circle with center O radius 7] +POINT(O) +CIRCLE(O, 7) +""" + + user_content = f"Semantic Data: {json.dumps(semantic_data, ensure_ascii=False)}" + if previous_dsl: + user_content = f"PREVIOUS DSL:\n{previous_dsl}\n\nUPDATE WITH NEW DATA: {json.dumps(semantic_data, ensure_ascii=False)}" + + logger.debug("[GeometryAgent] Calling LLM (Multi-Layer)...") + content = await self.llm.chat_completions_create( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content} + ] + ) + dsl = content.strip() if content else "" + logger.info(f"[GeometryAgent] DSL generated ({len(dsl.splitlines())} lines).") + logger.debug(f"[GeometryAgent] DSL output:\n{dsl}") + return dsl diff --git a/agents/knowledge_agent.py b/agents/knowledge_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0a97c7ed78a849eee05a4892fd88127c32f32e --- /dev/null +++ b/agents/knowledge_agent.py @@ -0,0 +1,135 @@ +import logging +from typing import Dict, Any + +logger = logging.getLogger(__name__) + + +# ─── Shape rule registry ──────────────────────────────────────────────────── +# Each entry: keyword list → augmentation function +# Augmentation receives (values: dict, text: str) and returns updated values dict. + +class KnowledgeAgent: + """Knowledge Agent: Stores geometric theorems and common patterns to augment Parser output.""" + + def augment_semantic_data(self, semantic_data: Dict[str, Any]) -> Dict[str, Any]: + logger.info("==[KnowledgeAgent] Augmenting semantic data==") + text = str(semantic_data.get("input_text", "")).lower() + logger.debug(f"[KnowledgeAgent] Input text for matching: '{text[:200]}'") + + shape_type = self._detect_shape(text, semantic_data.get("type", "")) + if shape_type: + semantic_data["type"] = shape_type + values = semantic_data.get("values", {}) + values = self._augment_values(shape_type, values, text) + semantic_data["values"] = values + else: + logger.info("[KnowledgeAgent] No special rule matched. Returning data unchanged.") + + logger.debug(f"[KnowledgeAgent] Output semantic data: {semantic_data}") + return semantic_data + + # ─── Shape detection ──────────────────────────────────────────────────── + def _detect_shape(self, text: str, llm_type: str) -> str | None: + """Detect shape from text keywords. LLM type provides a hint.""" + checks = [ + (["hình vuông", "square"], "square"), + (["hình chữ nhật", "rectangle"], "rectangle"), + (["hình thoi", "rhombus"], "rhombus"), + (["hình bình hành", "parallelogram"], "parallelogram"), + (["hình thang vuông"], "right_trapezoid"), + (["hình thang", "trapezoid", "trapezium"], "trapezoid"), + (["tam giác vuông", "right triangle"], "right_triangle"), + (["tam giác đều", "equilateral triangle", "equilateral"], "equilateral_triangle"), + (["tam giác cân", "isosceles"], "isosceles_triangle"), + (["tam giác", "triangle"], "triangle"), + (["đường tròn", "circle"], "circle"), + ] + for keywords, shape in checks: + if any(kw in text for kw in keywords): + logger.info(f"[KnowledgeAgent] Rule MATCH: '{shape}' detected (keyword match).") + return shape + + # Fallback: trust LLM-detected type if it's a known type + known = { + "rectangle", "square", "rhombus", "parallelogram", + "trapezoid", "right_trapezoid", "triangle", "right_triangle", + "equilateral_triangle", "isosceles_triangle", "circle", + } + if llm_type in known: + logger.info(f"[KnowledgeAgent] Using LLM-detected type '{llm_type}'.") + return llm_type + + return None + + # ─── Value augmentation ────────────────────────────────────────────────── + def _augment_values(self, shape: str, values: dict, text: str) -> dict: + ab = values.get("AB") + ad = values.get("AD") + bc = values.get("BC") + cd = values.get("CD") + + if shape == "rectangle": + if ab and ad: + values.setdefault("CD", ab) + values.setdefault("BC", ad) + values.setdefault("angle_A", 90) + logger.info(f"[KnowledgeAgent] Rectangle: AB=CD={ab}, AD=BC={ad}, angle_A=90°") + else: + values.setdefault("angle_A", 90) + + elif shape == "square": + side = ab or ad or bc or cd or values.get("side") + if side: + values.update({"AB": side, "AD": side, "angle_A": 90}) + logger.info(f"[KnowledgeAgent] Square: side={side}, angle_A=90°") + else: + values.setdefault("angle_A", 90) + + elif shape == "rhombus": + side = ab or values.get("side") + if side: + values.update({"AB": side, "BC": side, "CD": side, "DA": side}) + logger.info(f"[KnowledgeAgent] Rhombus: all sides={side}") + + elif shape == "parallelogram": + if ab: + values.setdefault("CD", ab) + if ad: + values.setdefault("BC", ad) + logger.info(f"[KnowledgeAgent] Parallelogram: AB||CD, AD||BC") + + elif shape == "trapezoid": + logger.info("[KnowledgeAgent] Trapezoid: AB||CD (bottom||top)") + + elif shape == "right_trapezoid": + logger.info("[KnowledgeAgent] Right trapezoid: AB||CD, AD⊥AB") + values.setdefault("angle_A", 90) + + elif shape == "equilateral_triangle": + side = ab or values.get("side") + if side: + values.update({"AB": side, "BC": side, "CA": side, "angle_A": 60}) + logger.info(f"[KnowledgeAgent] Equilateral triangle: all sides={side}, angle_A=60°") + + elif shape == "right_triangle": + # Try to infer which vertex is the right angle + rt_vertex = _detect_right_angle_vertex(text) + values.setdefault(f"angle_{rt_vertex}", 90) + logger.info(f"[KnowledgeAgent] Right triangle: angle_{rt_vertex}=90°") + + elif shape == "isosceles_triangle": + logger.info("[KnowledgeAgent] Isosceles triangle: AB=AC (default, LLM may override)") + + elif shape == "circle": + logger.info("[KnowledgeAgent] Circle detected — no side augmentation needed.") + + return values + + +def _detect_right_angle_vertex(text: str) -> str: + """Heuristic: detect which vertex is right angle from text.""" + for vertex in ["A", "B", "C", "D"]: + patterns = [f"vuông tại {vertex}", f"góc {vertex} vuông", f"right angle at {vertex}"] + if any(p.lower() in text for p in patterns): + return vertex + return "A" # default diff --git a/agents/ocr_agent.py b/agents/ocr_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec13448c44f688b0de5ffe517d733810f157a1f --- /dev/null +++ b/agents/ocr_agent.py @@ -0,0 +1,112 @@ +import asyncio +import logging + +from vision_ocr.pipeline import OcrVisionPipeline + +logger = logging.getLogger(__name__) + + +class ImprovedOCRAgent: + """ + API-facing OCR: composes ``OcrVisionPipeline`` (vision only) with optional LLM refinement. + Celery OCR workers should import ``OcrVisionPipeline`` directly from ``vision_ocr``. + """ + + def __init__(self, skip_llm_refinement: bool = False): + self._skip_llm_refinement = bool(skip_llm_refinement) + self._vision = OcrVisionPipeline() + logger.info( + "[ImprovedOCRAgent] Vision pipeline ready (skip_llm_refinement=%s)...", + self._skip_llm_refinement, + ) + + if self._skip_llm_refinement: + self.llm = None + logger.info("[ImprovedOCRAgent] LLM client skipped (raw OCR only).") + else: + from app.llm_client import get_llm_client + + self.llm = get_llm_client() + logger.info("[ImprovedOCRAgent] Multi-Layer LLM Client initialized.") + + async def process_image(self, image_path: str) -> str: + combined_text = await self._vision.process_image(image_path) + + if not combined_text.strip(): + return combined_text + + if self._skip_llm_refinement or self.llm is None: + logger.info("[ImprovedOCRAgent] Skipping MegaLLM refinement (raw OCR output).") + return combined_text + + try: + logger.info("[ImprovedOCRAgent] Sending to MegaLLM for refinement...") + refined_text = await asyncio.wait_for( + self.refine_with_llm(combined_text), timeout=30.0 + ) + return refined_text + except asyncio.TimeoutError: + logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.") + return combined_text + except Exception as e: + logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e) + return combined_text + + async def refine_with_llm(self, text: str) -> str: + if not text.strip(): + return "" + if self.llm is None: + logger.warning("[ImprovedOCRAgent] refine_with_llm: no LLM client; returning raw text.") + return text + + prompt = f"""Bạn là một chuyên gia số hóa tài liệu toán học. +Dưới đây là kết quả OCR thô từ một trang sách toán Tiếng Việt. +Kết quả này có thể chứa lỗi chính tả, lỗi định dạng mã LaTeX, hoặc bị ngắt quãng không logic. + +Nhiệm vụ của bạn: +1. Sửa lỗi chính tả tiếng Việt. +2. Đảm bảo các công thức toán học được viết đúng định dạng LaTeX và nằm trong cặp dấu $...$. +3. Giữ nguyên cấu trúc logic của bài toán. +4. Trả về nội dung đã được làm sạch dưới dạng Markdown. + +Nội dung OCR thô: +--- +{text} +--- + +Kết quả làm sạch:""" + + try: + refined = await self.llm.chat_completions_create( + messages=[{"role": "user", "content": prompt}], + temperature=0.1, + ) + logger.info("[ImprovedOCRAgent] LLM refinement complete.") + return refined + except Exception as e: + logger.error("[ImprovedOCRAgent] LLM refinement failed: %s", e) + return text + + async def process_url(self, url: str) -> str: + combined_text = await self._vision.process_url(url) + + if not combined_text.strip() or combined_text.lstrip().startswith("Error:"): + return combined_text + + if self._skip_llm_refinement or self.llm is None: + return combined_text + + try: + return await asyncio.wait_for(self.refine_with_llm(combined_text), timeout=30.0) + except asyncio.TimeoutError: + logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.") + return combined_text + except Exception as e: + logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e) + return combined_text + + +class OCRAgent(ImprovedOCRAgent): + """Alias for compatibility with existing code.""" + + pass diff --git a/agents/orchestrator.py b/agents/orchestrator.py new file mode 100644 index 0000000000000000000000000000000000000000..2448ced01c95ad42a93092ea90f5d048d7845f0d --- /dev/null +++ b/agents/orchestrator.py @@ -0,0 +1,223 @@ +import json +import logging +from typing import Any, Dict + +from agents.geometry_agent import GeometryAgent +from agents.knowledge_agent import KnowledgeAgent +from agents.ocr_agent import OCRAgent +from agents.parser_agent import ParserAgent +from agents.solver_agent import SolverAgent +from app.logutil import log_step +from app.ocr_celery import ocr_from_image_url +from solver.dsl_parser import DSLParser +from solver.engine import GeometryEngine + +logger = logging.getLogger(__name__) + +_CLIP = 2000 + + +def _clip(val: Any, n: int = _CLIP) -> str | None: + if val is None: + return None + if isinstance(val, str): + s = val + else: + s = json.dumps(val, ensure_ascii=False, default=str) + return s if len(s) <= n else s[:n] + "…" + + +def _step_io(step: str, input_val: Any = None, output_val: Any = None) -> None: + """Debug: chỉ input/output (đã cắt), tránh dump dài dòng không cần thiết.""" + log_step(step, input=_clip(input_val), output=_clip(output_val)) + + +class Orchestrator: + def __init__(self): + self.parser_agent = ParserAgent() + self.geometry_agent = GeometryAgent() + self.ocr_agent = OCRAgent() + self.knowledge_agent = KnowledgeAgent() + self.solver_agent = SolverAgent() + self.solver_engine = GeometryEngine() + self.dsl_parser = DSLParser() + + def _generate_step_description(self, semantic_json: Dict[str, Any], engine_result: Dict[str, Any]) -> str: + """Tạo mô tả từng bước vẽ dựa trên kết quả của engine.""" + analysis = semantic_json.get("analysis", "") + if not analysis: + analysis = f"Giải bài toán về {semantic_json.get('type', 'hình học')}." + + steps = ["\n\n**Các bước dựng hình:**"] + drawing_phases = engine_result.get("drawing_phases", []) + + for phase in drawing_phases: + label = phase.get("label", f"Giai đoạn {phase['phase']}") + points = ", ".join(phase.get("points", [])) + segments = ", ".join([f"{s[0]}{s[1]}" for s in phase.get("segments", [])]) + + step_text = f"- **{label}**:" + if points: + step_text += f" Xác định các điểm {points}." + if segments: + step_text += f" Vẽ các đoạn thẳng {segments}." + steps.append(step_text) + + circles = engine_result.get("circles", []) + for c in circles: + steps.append(f"- **Đường tròn**: Vẽ đường tròn tâm {c['center']} bán kính {c['radius']}.") + + return analysis + "\n".join(steps) + + async def run( + self, + text: str, + image_url: str = None, + job_id: str = None, + session_id: str = None, + status_callback=None, + history: list = None, + ) -> Dict[str, Any]: + """ + Run the full pipeline. Optional history allows context-aware solving. + """ + _step_io( + "orchestrate_start", + input_val={ + "job_id": job_id, + "text_len": len(text or ""), + "image_url": image_url, + "history_len": len(history or []), + }, + output_val=None, + ) + + if status_callback: + await status_callback("processing") + + # 1. Extract context from history (if any) + previous_context = None + if history: + # Look for the last assistant message with geometry data + for msg in reversed(history): + if msg.get("role") == "assistant" and msg.get("metadata", {}).get("geometry_dsl"): + previous_context = { + "geometry_dsl": msg["metadata"]["geometry_dsl"], + "coordinates": msg["metadata"].get("coordinates", {}), + "analysis": msg.get("content", ""), + } + break + + if previous_context: + _step_io("context_found", input_val=None, output_val={"dsl_len": len(previous_context["geometry_dsl"])}) + + # 2. Gather input text (OCR or direct) + input_text = text + if image_url: + input_text = await ocr_from_image_url(image_url, self.ocr_agent) + _step_io("step1_ocr", input_val=image_url, output_val=input_text) + else: + _step_io("step1_ocr", input_val="(no image)", output_val=text) + + feedback = None + MAX_RETRIES = 2 + + for attempt in range(MAX_RETRIES + 1): + _step_io( + "attempt", + input_val=f"{attempt + 1}/{MAX_RETRIES + 1}", + output_val=None, + ) + if status_callback: + await status_callback("solving") + + # Parser with context + _step_io("step2_parse", input_val=f"{input_text[:50]}...", output_val=None) + semantic_json = await self.parser_agent.process(input_text, feedback=feedback, context=previous_context) + semantic_json["input_text"] = input_text + _step_io("step2_parse", input_val=None, output_val=semantic_json) + + # Knowledge augmentation + _step_io("step3_knowledge", input_val=semantic_json, output_val=None) + semantic_json = self.knowledge_agent.augment_semantic_data(semantic_json) + _step_io("step3_knowledge", input_val=None, output_val=semantic_json) + + # Geometry DSL with context (passing previous DSL to guide generation) + _step_io("step4_geometry_dsl", input_val=semantic_json, output_val=None) + dsl_code = await self.geometry_agent.generate_dsl( + semantic_json, + previous_dsl=previous_context["geometry_dsl"] if previous_context else None + ) + _step_io("step4_geometry_dsl", input_val=None, output_val=dsl_code) + + _step_io("step5_dsl_parse", input_val=dsl_code, output_val=None) + points, constraints, is_3d = self.dsl_parser.parse(dsl_code) + _step_io( + "step5_dsl_parse", + input_val=None, + output_val={ + "points": len(points), + "constraints": len(constraints), + "is_3d": is_3d, + }, + ) + + _step_io("step6_solve", input_val=f"{len(points)} pts / {len(constraints)} cons (is_3d={is_3d})", output_val=None) + import anyio + engine_result = await anyio.to_thread.run_sync(self.solver_engine.solve, points, constraints, is_3d) + + if engine_result: + coordinates = engine_result.get("coordinates") + _step_io("step6_solve", input_val=None, output_val=coordinates) + logger.info( + "[Orchestrator] geometry solved job_id=%s is_3d=%s n_coords=%d", + job_id, + is_3d, + len(coordinates) if isinstance(coordinates, dict) else 0, + ) + break + + feedback = "Geometry solver failed to find a valid solution for the given constraints. Parallelism or lengths might be inconsistent." + _step_io( + "step6_solve", + input_val=f"attempt {attempt + 1}", + output_val=feedback, + ) + if attempt == MAX_RETRIES: + _step_io( + "orchestrate_abort", + input_val=None, + output_val="solver_exhausted_retries", + ) + return { + "error": "Solver failed after multiple attempts.", + "last_dsl": dsl_code, + } + + _step_io("orchestrate_done", input_val=job_id, output_val="success") + + # 8. Solution calculation (New in v5.1) + solution = None + if engine_result: + _step_io("step8_solve_math", input_val=semantic_json.get("target_question"), output_val=None) + solution = await self.solver_agent.solve(semantic_json, engine_result) + _step_io("step8_solve_math", input_val=None, output_val=solution.get("answer")) + + final_analysis = self._generate_step_description(semantic_json, engine_result) + + status = "success" + return { + "status": status, + "job_id": job_id, + "geometry_dsl": dsl_code, + "coordinates": coordinates, + "polygon_order": engine_result.get("polygon_order", []), + "circles": engine_result.get("circles", []), + "lines": engine_result.get("lines", []), + "rays": engine_result.get("rays", []), + "drawing_phases": engine_result.get("drawing_phases", []), + "semantic": semantic_json, + "semantic_analysis": final_analysis, + "solution": solution, + "is_3d": is_3d, + } diff --git a/agents/parser_agent.py b/agents/parser_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1f3ffb48de8f685935ce9389d1486d94010d4c --- /dev/null +++ b/agents/parser_agent.py @@ -0,0 +1,106 @@ +import os +import json +import logging +from openai import AsyncOpenAI +from typing import Dict, Any +from dotenv import load_dotenv + +load_dotenv() +logger = logging.getLogger(__name__) + +from app.url_utils import openai_compatible_api_key, sanitize_env + + +from app.llm_client import get_llm_client + + +class ParserAgent: + def __init__(self): + self.llm = get_llm_client() + + async def process(self, text: str, feedback: str = None, context: Dict[str, Any] = None) -> Dict[str, Any]: + logger.info(f"==[ParserAgent] Processing input (len={len(text)})==") + if feedback: + logger.warning(f"[ParserAgent] Feedback from previous attempt: {feedback}") + if context: + logger.info(f"[ParserAgent] Using previous context (dsl_len={len(context.get('geometry_dsl', ''))})") + + system_prompt = """ + You are a Geometry Parser Agent. Extract geometric entities and constraints from Vietnamese/LaTeX math problem text. + + === CONTEXT AWARENESS === + If previous context is provided, it means this is a follow-up request. + - Combine old entities with new ones. + - Update 'analysis' to reflect the entire problem state. + + Output ONLY a JSON object with this EXACT structure (no extra keys, no markdown): + { + "entities": ["Point A", "Point B", ...], + "type": "pyramid|prism|sphere|rectangle|triangle|circle|parallelogram|trapezoid|square|rhombus|general", + "values": {"AB": 5, "SO": 15, "radius": 3}, + "target_question": "Câu hỏi cụ thể cần giải (ví dụ: 'Tính diện tích tam giác ABC'). NẾU KHÔNG CÓ CÂU HỎI THÌ ĐỂ null.", + "analysis": "Tóm tắt ngắn gọn toàn bộ bài toán sau khi đã cập nhật các yêu cầu mới bằng tiếng Việt." + } + Rules: + - "analysis" MUST be a meaningful and UP-TO-DATE summary of the problem in Vietnamese. + - "target_question" must be concise. + - Include midpoints, auxiliary points in "entities" if mentioned. + - If feedback is provided, correct your previous output accordingly. + """ + + user_content = f"Text: {text}" + if context: + user_content = f"PREVIOUS ANALYSIS: {context.get('analysis')}\nNEW REQUEST: {text}" + + if feedback: + user_content += f"\nFeedback from previous attempt: {feedback}. Please correct the constraints." + + logger.debug("[ParserAgent] Calling LLM (Multi-Layer)...") + raw = await self.llm.chat_completions_create( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content} + ], + response_format={"type": "json_object"} + ) + + # Pre-process raw string: extract the JSON block if present + import re + clean_raw = raw.strip() + # Handle potential markdown code blocks + if clean_raw.startswith("```"): + import re + match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL) + if match: + clean_raw = match.group(1).strip() + + try: + result = json.loads(clean_raw) + except json.JSONDecodeError as e: + logger.error(f"[ParserAgent] JSON Parse Error: {e}. Attempting regex fallback...") + import re + json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL) + if json_match: + try: + # Handle single quotes if present (common LLM failure) + json_str = json_match.group(1) + if "'" in json_str and '"' not in json_str: + json_str = json_str.replace("'", '"') + result = json.loads(json_str) + except: + result = None + else: + result = None + + if not result: + # Fallback for critical failure + result = { + "entities": [], + "type": "general", + "values": {}, + "target_question": None, + "analysis": text + } + logger.info(f"[ParserAgent] LLM response received.") + logger.debug(f"[ParserAgent] Parsed JSON: {json.dumps(result, ensure_ascii=False, indent=2)}") + return result diff --git a/agents/renderer_agent.py b/agents/renderer_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..b083c13e224ebcab9f1a70de2c71ef4408f79690 --- /dev/null +++ b/agents/renderer_agent.py @@ -0,0 +1,5 @@ +"""Shim: geometry rendering lives in ``geometry_render`` (worker-safe package).""" + +from geometry_render.renderer import RendererAgent + +__all__ = ["RendererAgent"] diff --git a/agents/solver_agent.py b/agents/solver_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..98cac7dd633bd73223d781a6e4ea780df680e28d --- /dev/null +++ b/agents/solver_agent.py @@ -0,0 +1,107 @@ +import json +import logging +import sympy as sp +from typing import Dict, Any, List +from app.llm_client import get_llm_client + +logger = logging.getLogger(__name__) + +class SolverAgent: + def __init__(self): + self.llm = get_llm_client() + + async def solve(self, semantic_data: Dict[str, Any], engine_result: Dict[str, Any]) -> Dict[str, Any]: + """ + Solves the geometric problem based on coordinates and the target question. + Returns a 'solution' dictionary with answer, steps, and symbolic_expression. + """ + target_question = semantic_data.get("target_question") + if not target_question: + # If no question, just return an empty solution structure + return { + "answer": None, + "steps": [], + "symbolic_expression": None + } + + logger.info(f"==[SolverAgent] Solving for: '{target_question}'==") + + input_text = semantic_data.get("input_text", "") + coordinates = engine_result.get("coordinates", {}) + + # We provide the coordinates and semantic context to the LLM to help it reason. + # The LLM is tasked with generating the solution structure directly. + + system_prompt = """ + You are a Geometry Solver Agent. Your goal is to provide a step-by-step solution for a specific geometric question. + + === DATA PROVIDED === + 1. Target Question: The specific question to answer. + 2. Geometry Data: Entities and values extracted from the problem. + 3. Coordinates: Calculated coordinates for all points. + + === REQUIREMENTS === + - Provide the solution in the SAME LANGUAGE as the user's input. + - Use SymPy concepts if appropriate. + - Steps should be clear, concise, and logical. + - The final answer should be numerically or symbolically accurate based on the coordinates and geometric properties. + - For geometric proofs (e.g., "Is AB perpendicular to AC?"), explain the reasoning based on the data. + + Output ONLY a JSON object with this structure: + { + "answer": "Chuỗi văn bản kết quả cuối cùng (kèm đơn vị nếu có)", + "steps": [ + "Bước 1: ...", + "Bước 2: ...", + ... + ], + "symbolic_expression": "Biểu thức toán học rút gọn (LaTeX format optional)" + } + """ + + user_content = f""" + INPUT_TEXT: {input_text} + TARGET_QUESTION: {target_question} + SEMANTIC_DATA: {json.dumps(semantic_data, ensure_ascii=False)} + COORDINATES: {json.dumps(coordinates)} + """ + + logger.debug("[SolverAgent] Requesting solution from LLM...") + try: + raw = await self.llm.chat_completions_create( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content} + ], + response_format={"type": "json_object"} + ) + + clean_raw = raw.strip() + # Handle potential markdown code blocks if the response_format wasn't strictly honored + if clean_raw.startswith("```"): + import re + match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL) + if match: + clean_raw = match.group(1).strip() + + try: + solution = json.loads(clean_raw) + except json.JSONDecodeError: + # Last resort: try to find anything between { and } + import re + json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL) + if json_match: + solution = json.loads(json_match.group(1)) + else: + raise + + logger.info("[SolverAgent] Solution generated successfully.") + return solution + except Exception as e: + logger.error(f"[SolverAgent] Error generating solution: {e}") + logger.debug(f"[SolverAgent] Raw LLM output was: \n{raw if 'raw' in locals() else 'N/A'}") + return { + "answer": "Không thể tính toán lời giải tại thời điểm này.", + "steps": ["Đã xảy ra lỗi trong quá trình xử lý lời giải."], + "symbolic_expression": None + } diff --git a/agents/torch_ultralytics_compat.py b/agents/torch_ultralytics_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..f98a320ce08f03cf720d27f232d07568aa45ecbe --- /dev/null +++ b/agents/torch_ultralytics_compat.py @@ -0,0 +1,5 @@ +"""Shim: moved to ``vision_ocr.compat`` for OCR worker isolation.""" + +from vision_ocr.compat import allow_ultralytics_weights + +__all__ = ["allow_ultralytics_weights"] diff --git a/app/chat_image_upload.py b/app/chat_image_upload.py new file mode 100644 index 0000000000000000000000000000000000000000..beb6965fdf2ba2774ae49cf122cd6dbf7ddc576f --- /dev/null +++ b/app/chat_image_upload.py @@ -0,0 +1,206 @@ +"""Validate and upload chat/solve attachment images to Supabase Storage (image bucket).""" + +from __future__ import annotations + +import logging +import os +import uuid +from typing import Any, Dict, Tuple + +from fastapi import HTTPException + +logger = logging.getLogger(__name__) + + +def _get_next_image_version(session_id: str) -> int: + """Same logic as worker.asset_manager.get_next_version for asset_type image.""" + from app.supabase_client import get_supabase + + supabase = get_supabase() + try: + res = ( + supabase.table("session_assets") + .select("version") + .eq("session_id", session_id) + .eq("asset_type", "image") + .order("version", desc=True) + .limit(1) + .execute() + ) + if res.data: + return res.data[0]["version"] + 1 + return 1 + except Exception as e: + logger.error("Error fetching image version: %s", e) + return 1 + +_MAX_BYTES_DEFAULT = 10 * 1024 * 1024 + +_EXT_TO_MIME: dict[str, str] = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".webp": "image/webp", + ".gif": "image/gif", + ".bmp": "image/bmp", +} + + +def _max_bytes() -> int: + raw = os.getenv("CHAT_IMAGE_MAX_BYTES") + if raw and raw.isdigit(): + return min(int(raw), 50 * 1024 * 1024) + return _MAX_BYTES_DEFAULT + + +def _magic_ok(ext: str, body: bytes) -> bool: + if len(body) < 12: + return False + if ext == ".png": + return body.startswith(b"\x89PNG\r\n\x1a\n") + if ext in (".jpg", ".jpeg"): + return body.startswith(b"\xff\xd8\xff") + if ext == ".webp": + return body.startswith(b"RIFF") and body[8:12] == b"WEBP" + if ext == ".gif": + return body.startswith(b"GIF87a") or body.startswith(b"GIF89a") + if ext == ".bmp": + return body.startswith(b"BM") + return False + + +def validate_chat_image_bytes( + filename: str | None, + body: bytes, + declared_content_type: str | None, +) -> Tuple[str, str]: + """ + Validate size, extension, and magic bytes. + Returns (extension_with_dot, content_type). + """ + max_b = _max_bytes() + if not body: + raise HTTPException(status_code=400, detail="Empty file.") + if len(body) > max_b: + raise HTTPException( + status_code=413, + detail=f"Image too large (max {max_b // (1024 * 1024)} MB).", + ) + + ext = os.path.splitext(filename or "")[1].lower() + if not ext: + ext = ".png" + if ext not in _EXT_TO_MIME: + raise HTTPException( + status_code=400, + detail=f"Unsupported image type: {ext}. Allowed: {', '.join(sorted(_EXT_TO_MIME))}", + ) + + if not _magic_ok(ext, body): + raise HTTPException( + status_code=400, + detail="File content does not match declared image type.", + ) + + mime = _EXT_TO_MIME[ext] + if declared_content_type: + decl = declared_content_type.split(";")[0].strip().lower() + if decl and decl not in ("application/octet-stream", mime) and decl != mime: + logger.warning( + "Content-Type mismatch (declared=%s, inferred=%s); using inferred.", + declared_content_type, + mime, + ) + return ext, mime + + +def upload_session_chat_image( + session_id: str, + job_id: str, + file_bytes: bytes, + ext_with_dot: str, + content_type: str, +) -> Dict[str, Any]: + """ + Upload to SUPABASE_IMAGE_BUCKET (default: image), insert session_assets row. + Returns dict with public_url, storage_path, version, session_asset_id (if returned). + """ + from app.supabase_client import get_supabase + + supabase = get_supabase() + bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image") + raw_ext = ext_with_dot.lstrip(".").lower() + version = _get_next_image_version(session_id) + file_name = f"image_v{version}_{job_id}.{raw_ext}" + storage_path = f"sessions/{session_id}/{file_name}" + + supabase.storage.from_(bucket_name).upload( + path=storage_path, + file=file_bytes, + file_options={"content-type": content_type}, + ) + public_url = supabase.storage.from_(bucket_name).get_public_url(storage_path) + if isinstance(public_url, dict): + public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url) + + row = { + "session_id": session_id, + "job_id": job_id, + "asset_type": "image", + "storage_path": storage_path, + "public_url": public_url, + "version": version, + } + ins = supabase.table("session_assets").insert(row).select("id").execute() + asset_id = None + if ins.data and len(ins.data) > 0: + asset_id = ins.data[0].get("id") + + log_data = { + "public_url": public_url, + "storage_path": storage_path, + "version": version, + "session_asset_id": str(asset_id) if asset_id else None, + } + logger.info("Uploaded chat image: %s", log_data) + return { + "public_url": public_url, + "storage_path": storage_path, + "version": version, + "session_asset_id": str(asset_id) if asset_id else None, + } + + +def upload_ephemeral_ocr_blob( + file_bytes: bytes, + ext_with_dot: str, + content_type: str, +) -> Tuple[str, str]: + """ + Upload bytes to image bucket under _ocr_temp/ for worker-only OCR (no session_assets row). + Returns (storage_path, public_url). Caller must delete_storage_object when done. + """ + from app.supabase_client import get_supabase + + bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image") + raw_ext = ext_with_dot.lstrip(".").lower() or "png" + name = f"_ocr_temp/{uuid.uuid4().hex}.{raw_ext}" + supabase = get_supabase() + supabase.storage.from_(bucket_name).upload( + path=name, + file=file_bytes, + file_options={"content-type": content_type}, + ) + public_url = supabase.storage.from_(bucket_name).get_public_url(name) + if isinstance(public_url, dict): + public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url) + return name, public_url + + +def delete_storage_object(bucket_name: str, storage_path: str) -> None: + try: + from app.supabase_client import get_supabase + + get_supabase().storage.from_(bucket_name).remove([storage_path]) + except Exception as e: + logger.warning("delete_storage_object failed path=%s: %s", storage_path, e) diff --git a/app/dependencies.py b/app/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8c3c227fc6819301045767a69eedfaba987ed1 --- /dev/null +++ b/app/dependencies.py @@ -0,0 +1,69 @@ +from fastapi import HTTPException, Header + +from app.supabase_client import get_supabase, get_supabase_for_user_jwt + + +async def get_current_user_id(authorization: str | None = Header(None)): + """ + Authenticate user using Supabase JWT. + Expected Header: Authorization: Bearer + """ + import os + + if not authorization: + raise HTTPException( + status_code=401, + detail="Authorization header missing or invalid. Use 'Bearer '", + ) + + if os.getenv("ALLOW_TEST_BYPASS") == "true" and authorization.startswith("Test "): + return authorization.split(" ")[1] + + if not authorization.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail="Authorization header missing or invalid. Use 'Bearer '", + ) + + token = authorization.split(" ")[1] + supabase = get_supabase() + + try: + user_response = supabase.auth.get_user(token) + if not user_response or not user_response.user: + raise HTTPException(status_code=401, detail="Invalid session or token.") + + return user_response.user.id + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}") + + +async def get_authenticated_supabase(authorization: str = Header(...)): + """ + Supabase client that carries the user's JWT (anon key + Authorization header). + Use for routes that should respect Row Level Security; pair with app logic as needed. + """ + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail="Authorization header missing or invalid. Use 'Bearer '", + ) + + token = authorization.split(" ")[1] + supabase = get_supabase() + + try: + user_response = supabase.auth.get_user(token) + if not user_response or not user_response.user: + raise HTTPException(status_code=401, detail="Invalid session or token.") + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}") + + try: + return get_supabase_for_user_jwt(token) + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) diff --git a/app/errors.py b/app/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd951dd3cc5b3fa11cdff345edd4784ce5acf98 --- /dev/null +++ b/app/errors.py @@ -0,0 +1,59 @@ +"""Map exceptions to short, user-visible messages (avoid leaking HTML bodies from 404 proxies).""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + + +def _looks_like_html(text: str) -> bool: + t = text.lstrip()[:500].lower() + return t.startswith(" str: + """ + Produce a safe message for chat/UI. Full detail stays in server logs via logger.exception. + """ + # httpx: wrong URL often returns 404 HTML; don't show body + try: + import httpx + + if isinstance(exc, httpx.HTTPStatusError): + req = exc.request + code = exc.response.status_code + url_hint = "" + try: + url_hint = str(req.url.host) if req and req.url else "" + except Exception: + pass + logger.warning( + "HTTPStatusError %s for %s (response not shown to user)", + code, + url_hint or "?", + ) + return ( + "Kiểm tra URL API, khóa bí mật và biến môi trường (OpenRouter/Supabase/Redis)." + ) + + if isinstance(exc, httpx.RequestError): + return "Không kết nối được tới dịch vụ ngoài (mạng hoặc URL sai)." + except ImportError: + pass + + raw = str(exc).strip() + if not raw: + return "Đã xảy ra lỗi không xác định." + + if _looks_like_html(raw): + logger.warning("Suppressed HTML error body from user-facing message") + return ( + "Dịch vụ trả về trang lỗi (thường là URL API sai hoặc endpoint không tồn tại — HTTP 404). " + "Kiểm tra OPENROUTER_MODEL và khóa API trên server." + ) + + if len(raw) > 800: + return raw[:800] + "…" + + return raw diff --git a/app/job_poll.py b/app/job_poll.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c906c3af893e21c55eadb4a505e4d7601ac8c8 --- /dev/null +++ b/app/job_poll.py @@ -0,0 +1,47 @@ +"""Normalize Supabase `jobs` rows for polling / WebSocket clients (stable `job_id` + JSON `result`).""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def _coerce_result(value: Any) -> Any: + if value is None: + return None + if isinstance(value, (dict, list)): + return value + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + logger.warning("job_poll: result is non-JSON string, returning raw") + return {"raw": value} + return value + + +def normalize_job_row_for_client(row: dict[str, Any]) -> dict[str, Any]: + """ + Build a JSON-serializable dict that always includes: + - ``job_id`` (alias of DB ``id``) for clients that expect it on poll bodies + - ``status`` as str + - ``result`` as object/array when stored as JSON string + All other columns are passed through (UUID/datetime become JSON-safe via FastAPI encoder). + """ + out = dict(row) + jid = out.get("id") + if jid is not None: + out["job_id"] = str(jid) + st = out.get("status") + if st is not None: + out["status"] = str(st) + if "result" in out: + out["result"] = _coerce_result(out.get("result")) + if out.get("user_id") is not None: + out["user_id"] = str(out["user_id"]) + if out.get("session_id") is not None: + out["session_id"] = str(out["session_id"]) + return out diff --git a/app/llm_client.py b/app/llm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1aba7a5806e8a0e56824e22f31a271b2189dd2 --- /dev/null +++ b/app/llm_client.py @@ -0,0 +1,100 @@ +import os +import json +import asyncio +import logging +from openai import AsyncOpenAI +from typing import List, Dict, Any, Optional +from app.url_utils import openai_compatible_api_key, sanitize_env + +logger = logging.getLogger(__name__) + +class MultiLayerLLMClient: + def __init__(self): + # 1. Models sequence loading + self.models = [] + for i in range(1, 4): + model = os.getenv(f"OPENROUTER_MODEL_{i}") + if model: + self.models.append(model) + + # Fallback to legacy OPENROUTER_MODEL if no numbered models found + if not self.models: + legacy_model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-001") + self.models = [legacy_model] + + # 2. Key selection (No rotation, always use the first available key) + api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY") + + if not api_key: + logger.error("[LLM] No OpenRouter API key found.") + self.client = None + else: + self.client = AsyncOpenAI( + api_key=openai_compatible_api_key(api_key), + base_url="https://openrouter.ai/api/v1", + timeout=60.0, + default_headers={ + "HTTP-Referer": "https://mathsolver.ai", + "X-Title": "MathSolver Backend", + } + ) + + async def chat_completions_create( + self, + messages: List[Dict[str, str]], + response_format: Optional[Dict[str, str]] = None, + **kwargs + ) -> str: + """ + Implements Model Fallback Sequence: Model 1 -> Model 2 -> Model 3. + Always starts from Model 1 for every new call. + """ + if not self.client: + raise ValueError("No API client configured. Check your API keys.") + + MAX_ATTEMPTS = len(self.models) + RETRY_DELAY = 1.0 # second + + for attempt_idx in range(MAX_ATTEMPTS): + current_model = self.models[attempt_idx] + attempt_num = attempt_idx + 1 + + try: + logger.info(f"[LLM] Attempt {attempt_num}/{MAX_ATTEMPTS} using Model: {current_model}...") + + response = await self.client.chat.completions.create( + model=current_model, + messages=messages, + response_format=response_format, + **kwargs + ) + + if not response or not getattr(response, "choices", None): + raise ValueError(f"Invalid response structure from model {current_model}") + + content = response.choices[0].message.content + if content: + logger.info(f"[LLM] SUCCESS on attempt {attempt_num} ({current_model}).") + return content + + raise ValueError(f"Empty content from model {current_model}") + + except Exception as e: + err_msg = f"{type(e).__name__}: {str(e)}" + logger.warning(f"[LLM] FAILED on attempt {attempt_num} ({current_model}): {err_msg}") + + if attempt_num < MAX_ATTEMPTS: + logger.info(f"[LLM] Retrying next model in {RETRY_DELAY}s...") + await asyncio.sleep(RETRY_DELAY) + else: + logger.error(f"[LLM] FINAL FAILURE after {attempt_num} models.") + raise e + +# Global instance for easy reuse (singleton-ish) +_llm_client = None + +def get_llm_client() -> MultiLayerLLMClient: + global _llm_client + if _llm_client is None: + _llm_client = MultiLayerLLMClient() + return _llm_client diff --git a/app/logging_setup.py b/app/logging_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd89d4815d29b38a606fb3bf2b22a08751eb71a --- /dev/null +++ b/app/logging_setup.py @@ -0,0 +1,112 @@ +"""Logging theo một biến LOG_LEVEL: debug | info | warning | error.""" + +from __future__ import annotations + +import logging +import os +from typing import Final + +_SETUP_DONE = False + +PIPELINE_LOGGER_NAME: Final = "app.pipeline" +CACHE_LOGGER_NAME: Final = "app.cache" +STEPS_LOGGER_NAME: Final = "app.steps" +ACCESS_LOGGER_NAME: Final = "app.access" + + +def _normalize_level() -> str: + raw = os.getenv("LOG_LEVEL", "info").strip().lower() + if raw in ("debug", "info", "warning", "error"): + return raw + return "info" + + +def setup_application_logging() -> None: + """Idempotent; gọi khi khởi động process (uvicorn, celery, worker_health).""" + global _SETUP_DONE + if _SETUP_DONE: + return + _SETUP_DONE = True + + mode = _normalize_level() + + level_map = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + } + root_level = level_map[mode] + + fmt_named = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s" + fmt_short = "%(asctime)s | %(levelname)-8s | %(message)s" + + logging.basicConfig( + level=root_level, + format=fmt_named if mode == "debug" else fmt_short, + datefmt="%H:%M:%S", + force=True, + ) + + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("httpcore").setLevel(logging.WARNING) + logging.getLogger("openai").setLevel(logging.WARNING) + logging.getLogger("uvicorn.access").setLevel(logging.WARNING) + logging.getLogger("uvicorn.error").setLevel(logging.INFO) + # HTTP/2 stack (httpx/httpcore) — khi LOG_LEVEL=debug root=DEBUG sẽ tràn log hpack; không cần cho debug app + for _name in ("hpack", "h2", "hyperframe", "urllib3"): + logging.getLogger(_name).setLevel(logging.WARNING) + + if mode == "debug": + logging.getLogger("agents").setLevel(logging.DEBUG) + logging.getLogger("solver").setLevel(logging.DEBUG) + logging.getLogger("app").setLevel(logging.DEBUG) + logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.DEBUG) + logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.DEBUG) + logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.INFO) + logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO) + logging.getLogger("app.main").setLevel(logging.INFO) + logging.getLogger("worker").setLevel(logging.INFO) + elif mode == "info": + # Chỉ HTTP access (app.access) + startup; ẩn chi tiết agents/orchestrator/pipeline SUCCESS + logging.getLogger("agents").setLevel(logging.INFO) + logging.getLogger("solver").setLevel(logging.WARNING) + logging.getLogger("app").setLevel(logging.INFO) + logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING) + logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING) + logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING) + logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO) + logging.getLogger("app.main").setLevel(logging.INFO) + logging.getLogger("worker").setLevel(logging.WARNING) + elif mode == "warning": + logging.getLogger("agents").setLevel(logging.WARNING) + logging.getLogger("solver").setLevel(logging.WARNING) + logging.getLogger("app.routers").setLevel(logging.WARNING) + logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING) + logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING) + logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING) + logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.WARNING) + logging.getLogger("app.main").setLevel(logging.WARNING) + logging.getLogger("worker").setLevel(logging.WARNING) + else: # error + logging.getLogger("agents").setLevel(logging.ERROR) + logging.getLogger("solver").setLevel(logging.ERROR) + logging.getLogger("app.routers").setLevel(logging.ERROR) + logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.ERROR) + logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.ERROR) + logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.ERROR) + logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.ERROR) + logging.getLogger("app.main").setLevel(logging.ERROR) + logging.getLogger("worker").setLevel(logging.ERROR) + + logging.getLogger(__name__).debug( + "LOG_LEVEL=%s root=%s", mode, logging.getLevelName(root_level) + ) + + +def get_log_level() -> str: + return _normalize_level() + + +def is_debug_level() -> bool: + return _normalize_level() == "debug" diff --git a/app/logutil.py b/app/logutil.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfb9efda8f632ce9175abb187fb6631b4d12a01 --- /dev/null +++ b/app/logutil.py @@ -0,0 +1,67 @@ +"""log_step (debug), pipeline (debug), access log ở middleware.""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any + +from app.logging_setup import PIPELINE_LOGGER_NAME, STEPS_LOGGER_NAME + +_pipeline = logging.getLogger(PIPELINE_LOGGER_NAME) +_steps = logging.getLogger(STEPS_LOGGER_NAME) + + +def is_debug_mode() -> bool: + """Chi tiết từng bước chỉ khi LOG_LEVEL=debug.""" + return os.getenv("LOG_LEVEL", "info").strip().lower() == "debug" + + +def _truncate(val: Any, max_len: int = 2000) -> Any: + if val is None: + return None + if isinstance(val, (int, float, bool)): + return val + s = str(val) + if len(s) > max_len: + return s[:max_len] + f"... (+{len(s) - max_len} chars)" + return s + + +def log_step(step: str, **fields: Any) -> None: + """Chỉ khi LOG_LEVEL=debug: DB / cache / orchestrator.""" + if not is_debug_mode(): + return + safe = {k: _truncate(v) for k, v in fields.items()} + try: + payload = json.dumps(safe, ensure_ascii=False, default=str) + except Exception: + payload = str(safe) + _steps.debug("[step:%s] %s", step, payload) + + +def log_pipeline_success(operation: str, **fields: Any) -> None: + """Chỉ hiện khi debug (pipeline SUCCESS không dùng ở info — đã có app.access).""" + if not is_debug_mode(): + return + safe = {k: _truncate(v, 500) for k, v in fields.items()} + _pipeline.info( + "SUCCESS %s %s", + operation, + json.dumps(safe, ensure_ascii=False, default=str), + ) + + +def log_pipeline_failure(operation: str, error: str | None = None, **fields: Any) -> None: + """Thất bại pipeline: luôn dùng WARNING để vẫn thấy khi LOG_LEVEL=warning.""" + if is_debug_mode(): + safe = {k: _truncate(v, 500) for k, v in fields.items()} + _pipeline.warning( + "FAIL %s err=%s %s", + operation, + _truncate(error, 300), + json.dumps(safe, ensure_ascii=False, default=str), + ) + else: + _pipeline.warning("FAIL %s", operation) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0bf6ea24b8da248054d5687ecfd7e0eca78726 --- /dev/null +++ b/app/main.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import logging +import os +import time +import uuid +import warnings + +from dotenv import load_dotenv +from fastapi import Depends, FastAPI, File, HTTPException, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from starlette.requests import Request + +load_dotenv() + +from app.runtime_env import apply_runtime_env_defaults + +apply_runtime_env_defaults() + +os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" +warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") +warnings.filterwarnings("ignore", category=UserWarning, module="albumentations") + +from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging + +setup_application_logging() + +# Routers (after logging) +from app.dependencies import get_current_user_id +from app.ocr_local_file import ocr_from_local_image_path +from app.routers import auth, sessions, solve +from agents.ocr_agent import OCRAgent +from app.routers.solve import get_orchestrator +from app.job_poll import normalize_job_row_for_client +from app.supabase_client import get_supabase +from app.websocket_manager import register_websocket_routes + +logger = logging.getLogger("app.main") +_access = logging.getLogger(ACCESS_LOGGER_NAME) + +app = FastAPI(title="Visual Math Solver API v5.1") + + +@app.middleware("http") +async def access_log_middleware(request: Request, call_next): + """LOG_LEVEL=info/debug: mọi request; warning: chỉ 4xx/5xx; error: chỉ 4xx/5xx ở mức error.""" + start = time.perf_counter() + response = await call_next(request) + ms = (time.perf_counter() - start) * 1000 + mode = get_log_level() + method = request.method + path = request.url.path + status = response.status_code + + if mode in ("debug", "info"): + _access.info("%s %s -> %s (%.0fms)", method, path, status, ms) + elif mode == "warning": + if status >= 500: + _access.error("%s %s -> %s (%.0fms)", method, path, status, ms) + elif status >= 400: + _access.warning("%s %s -> %s (%.0fms)", method, path, status, ms) + elif mode == "error": + if status >= 400: + _access.error("%s %s -> %s", method, path, status) + + return response + + +from worker.celery_app import BROKER_URL + +_broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL +if get_log_level() in ("debug", "info"): + logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail) +else: + logger.warning( + "App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail + ) + +app.add_middleware( + CORSMiddleware, + allow_origins=[ + "http://localhost:3000", + "http://127.0.0.1:3000", + "http://localhost:3005", + ], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(auth.router) +app.include_router(sessions.router) +app.include_router(solve.router) + +register_websocket_routes(app) + + +def get_ocr_agent() -> OCRAgent: + """Same OCR instance as the solve pipeline (no duplicate model load).""" + return get_orchestrator().ocr_agent + + +supabase_client = get_supabase() + + +@app.get("/") +def read_root(): + return {"message": "Visual Math Solver API v5.1 is running", "version": "5.1"} + + +@app.post("/api/v1/ocr") +async def upload_ocr( + file: UploadFile = File(...), + _user_id=Depends(get_current_user_id), +): + """OCR upload: requires authenticated user.""" + temp_path = f"temp_{uuid.uuid4()}.png" + with open(temp_path, "wb") as buffer: + buffer.write(await file.read()) + + try: + text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent()) + return {"text": text} + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +@app.get("/api/v1/solve/{job_id}") +async def get_job_status( + job_id: str, + user_id=Depends(get_current_user_id), +): + """Retrieve job status (can be used for polling if WS fails). Owner-only.""" + response = supabase_client.table("jobs").select("*").eq("id", job_id).execute() + if not response.data: + raise HTTPException(status_code=404, detail="Job not found") + job = response.data[0] + if job.get("user_id") is not None and str(job["user_id"]) != str(user_id): + raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.") + # Stable contract for FE poll (job_id alias, parsed result JSON, string UUIDs) + return normalize_job_row_for_client(job) diff --git a/app/models/schemas.py b/app/models/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..9f636732f4b38b08084028a98ab91629445479c6 --- /dev/null +++ b/app/models/schemas.py @@ -0,0 +1,80 @@ +from pydantic import BaseModel, EmailStr, field_validator +from typing import Optional, List, Any, Dict +from datetime import datetime +import uuid + +from app.url_utils import sanitize_url + +# --- Auth Schemas --- +class UserProfile(BaseModel): + id: uuid.UUID + display_name: Optional[str] = None + avatar_url: Optional[str] = None + created_at: datetime + +class User(BaseModel): + id: uuid.UUID + email: EmailStr + +# --- Session Schemas --- +class SessionBase(BaseModel): + title: str = "Bài toán mới" + +class SessionCreate(SessionBase): + pass + +class Session(SessionBase): + id: uuid.UUID + user_id: uuid.UUID + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + +# --- Message Schemas --- +class MessageBase(BaseModel): + role: str + type: str = "text" + content: str + metadata: Dict[str, Any] = {} + +class MessageCreate(MessageBase): + session_id: uuid.UUID + +class Message(MessageBase): + id: uuid.UUID + session_id: uuid.UUID + created_at: datetime + + class Config: + from_attributes = True + +# --- Solve Job Schemas --- +class SolveRequest(BaseModel): + text: str + image_url: Optional[str] = None + + @field_validator("image_url", mode="before") + @classmethod + def _clean_image_url(cls, v): + return sanitize_url(v) if v is not None else None + +class SolveResponse(BaseModel): + job_id: str + status: str + +class RenderVideoRequest(BaseModel): + job_id: Optional[str] = None + +class RenderVideoResponse(BaseModel): + job_id: str + status: str + + +class OcrPreviewResponse(BaseModel): + """Stateless OCR preview before POST .../solve (no DB writes, no job).""" + + ocr_text: str + user_message: str = "" + combined_draft: str diff --git a/app/ocr_celery.py b/app/ocr_celery.py new file mode 100644 index 0000000000000000000000000000000000000000..120f33cb4d32eb4fa722edb20dd3c89ff65b700f --- /dev/null +++ b/app/ocr_celery.py @@ -0,0 +1,54 @@ +"""Run OCR on a remote worker via Celery (queue `ocr`) when OCR_USE_CELERY is enabled.""" + +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING + +import anyio + +if TYPE_CHECKING: + from agents.ocr_agent import OCRAgent + +logger = logging.getLogger(__name__) + + +def ocr_celery_enabled() -> bool: + return os.getenv("OCR_USE_CELERY", "").strip().lower() in ("1", "true", "yes", "on") + + +def _ocr_timeout_sec() -> float: + raw = os.getenv("OCR_CELERY_TIMEOUT_SEC", "180") + try: + return max(30.0, float(raw)) + except ValueError: + return 180.0 + + +def _run_ocr_celery_sync(image_url: str) -> str: + from worker.ocr_tasks import run_ocr_from_url + + async_result = run_ocr_from_url.apply_async(args=[image_url]) + return async_result.get(timeout=_ocr_timeout_sec()) + + +def _is_ocr_error_response(text: str) -> bool: + s = (text or "").lstrip() + return s.startswith("Error:") + + +async def ocr_from_image_url(image_url: str, fallback_agent: "OCRAgent") -> str: + """ + If OCR_USE_CELERY: delegate to Celery task `run_ocr_from_url` (worker queue `ocr`, raw OCR only), + then run ``refine_with_llm`` on the API process. + Else: use fallback_agent.process_url (in-process full pipeline). + """ + if not ocr_celery_enabled(): + return await fallback_agent.process_url(image_url) + logger.info("OCR_USE_CELERY: delegating OCR to Celery queue=ocr (LLM refine on API)") + raw = await anyio.to_thread.run_sync(_run_ocr_celery_sync, image_url) + raw = raw if raw is not None else "" + if not raw.strip() or _is_ocr_error_response(raw): + return raw + return await fallback_agent.refine_with_llm(raw) diff --git a/app/ocr_local_file.py b/app/ocr_local_file.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3c78af737d2455eb091b96131a2aff6fe6c5bc --- /dev/null +++ b/app/ocr_local_file.py @@ -0,0 +1,43 @@ +"""OCR from a local file path, optionally via Celery worker (upload temp blob first).""" + +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING + +from app.chat_image_upload import ( + delete_storage_object, + upload_ephemeral_ocr_blob, + validate_chat_image_bytes, +) +from app.ocr_celery import ocr_celery_enabled, ocr_from_image_url + +if TYPE_CHECKING: + from agents.ocr_agent import OCRAgent + +logger = logging.getLogger(__name__) + + +async def ocr_from_local_image_path( + local_path: str, + original_filename: str | None, + fallback_agent: "OCRAgent", +) -> str: + """ + Run OCR on a file on local disk. If OCR_USE_Celery, upload to ephemeral storage URL + then delegate to worker; otherwise process_image in-process. + """ + if not ocr_celery_enabled(): + return await fallback_agent.process_image(local_path) + + with open(local_path, "rb") as f: + body = f.read() + ext = os.path.splitext(original_filename or local_path)[1].lower() or ".png" + _, content_type = validate_chat_image_bytes(original_filename or local_path, body, None) + bucket = os.getenv("SUPABASE_IMAGE_BUCKET", "image") + path, url = upload_ephemeral_ocr_blob(body, ext, content_type) + try: + return await ocr_from_image_url(url, fallback_agent) + finally: + delete_storage_object(bucket, path) diff --git a/app/ocr_text_merge.py b/app/ocr_text_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..b5cc1346f7570174eb2d204c061ec01fb844d4c4 --- /dev/null +++ b/app/ocr_text_merge.py @@ -0,0 +1,14 @@ +"""Helpers for OCR preview combined draft (no Pydantic email deps).""" + +from __future__ import annotations + +from typing import Optional + + +def build_combined_ocr_preview_draft(user_message: Optional[str], ocr_text: str) -> str: + """Merge user caption and OCR text for confirm step (user message first, then OCR).""" + u = (user_message or "").strip() + o = (ocr_text or "").strip() + if u and o: + return f"{u}\n\n{o}" + return u or o diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45e499f6564756ee3155638e98e3727b71c03759 --- /dev/null +++ b/app/routers/__init__.py @@ -0,0 +1 @@ +from . import auth, sessions, solve diff --git a/app/routers/auth.py b/app/routers/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbcc505fa7701fd463cc23f619a71c900bfe7d3 --- /dev/null +++ b/app/routers/auth.py @@ -0,0 +1,23 @@ +from fastapi import APIRouter, Depends, HTTPException +from app.dependencies import get_current_user_id +from app.supabase_client import get_supabase +from app.models.schemas import UserProfile +import uuid + +router = APIRouter(prefix="/api/v1/auth", tags=["Auth"]) + +@router.get("/me") +async def get_me(user_id=Depends(get_current_user_id)): + """获取当前登录用户的信息 (Retrieve current user profile)""" + supabase = get_supabase() + res = supabase.table("profiles").select("*").eq("id", user_id).execute() + if not res.data: + raise HTTPException(status_code=404, detail="Profile not found.") + return res.data[0] + +@router.patch("/me") +async def update_me(data: dict, user_id=Depends(get_current_user_id)): + """Cập nhật profile hiện tại (Update current profile)""" + supabase = get_supabase() + res = supabase.table("profiles").update(data).eq("id", user_id).execute() + return res.data[0] diff --git a/app/routers/sessions.py b/app/routers/sessions.py new file mode 100644 index 0000000000000000000000000000000000000000..edf3eb88d994546610cc255483853757cf8939f4 --- /dev/null +++ b/app/routers/sessions.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import logging +import time +from typing import List + +from fastapi import APIRouter, Depends, HTTPException + +from app.dependencies import get_current_user_id +from app.logutil import log_step +from app.session_cache import ( + get_sessions_list_cached, + invalidate_for_user, + invalidate_session_owner, + session_owned_by_user, +) +from app.supabase_client import get_supabase + +router = APIRouter(prefix="/api/v1/sessions", tags=["Sessions"]) +logger = logging.getLogger(__name__) + + +@router.get("", response_model=List[dict]) +async def list_sessions(user_id=Depends(get_current_user_id)): + """Danh sách các phiên chat của người dùng (List user's chat sessions)""" + supabase = get_supabase() + t0 = time.perf_counter() + + def fetch() -> list: + res = ( + supabase.table("sessions") + .select("id, user_id, title, created_at, updated_at") + .eq("user_id", user_id) + .order("updated_at", desc=True) + .execute() + ) + log_step("db_select", table="sessions", op="list", user_id=str(user_id)) + return res.data + + out = get_sessions_list_cached(str(user_id), fetch) + logger.info( + "sessions.list user=%s count=%d %.1fms", + user_id, + len(out), + (time.perf_counter() - t0) * 1000, + ) + return out + + +@router.post("", response_model=dict) +async def create_session(user_id=Depends(get_current_user_id)): + """Tạo một phiên chat mới (Create a new chat session)""" + supabase = get_supabase() + t0 = time.perf_counter() + res = supabase.table("sessions").insert( + {"user_id": user_id, "title": "Bài toán mới"} + ).execute() + log_step("db_insert", table="sessions", op="create") + invalidate_for_user(str(user_id)) + row = res.data[0] + logger.info( + "sessions.create user=%s id=%s %.1fms", + user_id, + row.get("id"), + (time.perf_counter() - t0) * 1000, + ) + return row + + +@router.get("/{session_id}/messages", response_model=List[dict]) +async def get_session_messages(session_id: str, user_id=Depends(get_current_user_id)): + """Lấy toàn bộ lịch sử tin nhắn của một phiên (Get chat history for a session)""" + supabase = get_supabase() + + def owns() -> bool: + res = ( + supabase.table("sessions") + .select("id") + .eq("id", session_id) + .eq("user_id", user_id) + .execute() + ) + log_step("db_select", table="sessions", op="owner_check", session_id=session_id) + return bool(res.data) + + if not session_owned_by_user(session_id, str(user_id), owns): + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this session." + ) + + res = ( + supabase.table("messages") + .select("*") + .eq("session_id", session_id) + .order("created_at", desc=False) + .execute() + ) + log_step("db_select", table="messages", op="list", session_id=session_id) + return res.data + + +@router.delete("/{session_id}") +async def delete_session(session_id: str, user_id=Depends(get_current_user_id)): + """Xóa một phiên chat (Delete a chat session)""" + supabase = get_supabase() + + def owns() -> bool: + res = ( + supabase.table("sessions") + .select("id") + .eq("id", session_id) + .eq("user_id", user_id) + .execute() + ) + return bool(res.data) + + if not session_owned_by_user(session_id, str(user_id), owns): + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this session." + ) + + # jobs.session_id FK must be cleared before sessions row + supabase.table("jobs").delete().eq("session_id", session_id).eq("user_id", user_id).execute() + log_step("db_delete", table="jobs", op="by_session", session_id=session_id) + supabase.table("messages").delete().eq("session_id", session_id).execute() + log_step("db_delete", table="messages", op="by_session", session_id=session_id) + res = ( + supabase.table("sessions") + .delete() + .eq("id", session_id) + .eq("user_id", user_id) + .execute() + ) + log_step("db_delete", table="sessions", session_id=session_id) + invalidate_for_user(str(user_id)) + invalidate_session_owner(session_id, str(user_id)) + return {"status": "ok", "deleted_id": session_id} + + +@router.patch("/{session_id}/title") +async def update_session_title(title: str, session_id: str, user_id=Depends(get_current_user_id)): + """Cập nhật tiêu đề phiên chat (Rename a chat session)""" + supabase = get_supabase() + res = ( + supabase.table("sessions") + .update({"title": title}) + .eq("id", session_id) + .eq("user_id", user_id) + .execute() + ) + log_step("db_update", table="sessions", op="title", session_id=session_id) + invalidate_for_user(str(user_id)) + return res.data[0] + + +@router.get("/{session_id}/assets", response_model=List[dict]) +async def get_session_assets(session_id: str, user_id=Depends(get_current_user_id)): + """Lấy danh sách video đã render trong session (Get versioned assets for a session)""" + supabase = get_supabase() + + def owns() -> bool: + res = ( + supabase.table("sessions") + .select("id") + .eq("id", session_id) + .eq("user_id", user_id) + .execute() + ) + return bool(res.data) + + if not session_owned_by_user(session_id, str(user_id), owns): + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this session." + ) + + res = ( + supabase.table("session_assets") + .select("*") + .eq("session_id", session_id) + .order("version", desc=True) + .execute() + ) + log_step("db_select", table="session_assets", op="list", session_id=session_id) + return res.data diff --git a/app/routers/solve.py b/app/routers/solve.py new file mode 100644 index 0000000000000000000000000000000000000000..76e44103c5ea06cd530670d868973e9c036f7c30 --- /dev/null +++ b/app/routers/solve.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +import logging +import os +import uuid + +from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile + +from agents.orchestrator import Orchestrator +from app.chat_image_upload import upload_session_chat_image, validate_chat_image_bytes +from app.ocr_celery import ocr_celery_enabled +from app.ocr_local_file import ocr_from_local_image_path +from app.dependencies import get_current_user_id +from app.errors import format_error_for_user +from app.logutil import log_pipeline_failure, log_pipeline_success, log_step +from app.models.schemas import ( + OcrPreviewResponse, + RenderVideoRequest, + RenderVideoResponse, + SolveRequest, + SolveResponse, +) +from app.ocr_text_merge import build_combined_ocr_preview_draft +from app.session_cache import invalidate_for_user, session_owned_by_user +from app.supabase_client import get_supabase + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"]) + +# Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py). +ORCHESTRATOR = Orchestrator() + + +def get_orchestrator() -> Orchestrator: + return ORCHESTRATOR + + +_OCR_PREVIEW_MAX_BYTES = 10 * 1024 * 1024 + + +def _assert_session_owner(supabase, session_id: str, user_id, uid: str, op: str) -> None: + def owns() -> bool: + res = ( + supabase.table("sessions") + .select("id") + .eq("id", session_id) + .eq("user_id", user_id) + .execute() + ) + log_step("db_select", table="sessions", op=op, session_id=session_id) + return bool(res.data) + + if not session_owned_by_user(session_id, uid, owns): + log_pipeline_failure("solve_request", error="forbidden", session_id=session_id) + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this session." + ) + + +def _enqueue_solve_common( + supabase, + background_tasks: BackgroundTasks, + session_id: str, + user_id, + uid: str, + request: SolveRequest, + message_metadata: dict, + job_id: str, +) -> SolveResponse: + """Insert user message, job row, enqueue pipeline; update title when first message.""" + supabase.table("messages").insert( + { + "session_id": session_id, + "role": "user", + "type": "text", + "content": request.text, + "metadata": message_metadata, + } + ).execute() + log_step("db_insert", table="messages", op="user_message", session_id=session_id) + + supabase.table("jobs").insert( + { + "id": job_id, + "user_id": user_id, + "session_id": session_id, + "status": "processing", + "input_text": request.text, + } + ).execute() + log_step("db_insert", table="jobs", job_id=job_id) + + background_tasks.add_task(process_session_job, job_id, session_id, request, str(user_id)) + + title_check = supabase.table("sessions").select("title").eq("id", session_id).execute() + if title_check.data and title_check.data[0]["title"] == "Bài toán mới": + new_title = request.text[:50] + ("..." if len(request.text) > 50 else "") + supabase.table("sessions").update({"title": new_title}).eq("id", session_id).execute() + log_step("db_update", table="sessions", op="title_from_first_message") + invalidate_for_user(uid) + + log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id) + return SolveResponse(job_id=job_id, status="processing") + + +@router.post("/{session_id}/ocr_preview", response_model=OcrPreviewResponse) +async def ocr_preview( + session_id: str, + user_id=Depends(get_current_user_id), + file: UploadFile = File(...), + user_message: str | None = Form(None), +): + """ + Run OCR on an uploaded image and merge with optional user_message into combined_draft. + Does not insert messages or start a solve job. After user confirms, call POST .../solve + with text=combined_draft (edited) and omit image_url to avoid double OCR. + """ + supabase = get_supabase() + uid = str(user_id) + _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_ocr_preview") + + body = await file.read() + if len(body) > _OCR_PREVIEW_MAX_BYTES: + raise HTTPException( + status_code=413, + detail=f"Image too large (max {_OCR_PREVIEW_MAX_BYTES // (1024 * 1024)} MB).", + ) + if not body: + raise HTTPException(status_code=400, detail="Empty file.") + + if ocr_celery_enabled(): + validate_chat_image_bytes(file.filename, body, file.content_type) + + suffix = os.path.splitext(file.filename or "")[1].lower() + if suffix not in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ""): + suffix = ".png" + temp_path = f"temp_ocr_preview_{uuid.uuid4()}{suffix or '.png'}" + try: + with open(temp_path, "wb") as f: + f.write(body) + ocr_text = await ocr_from_local_image_path( + temp_path, file.filename, get_orchestrator().ocr_agent + ) + if ocr_text is None: + ocr_text = "" + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + um = (user_message or "").strip() + combined = build_combined_ocr_preview_draft(user_message, ocr_text) + log_step("ocr_preview_done", session_id=session_id, ocr_len=len(ocr_text), user_len=len(um)) + return OcrPreviewResponse( + ocr_text=ocr_text, + user_message=um, + combined_draft=combined, + ) + + +@router.post("/{session_id}/solve", response_model=SolveResponse) +async def solve_problem( + session_id: str, + request: SolveRequest, + background_tasks: BackgroundTasks, + user_id=Depends(get_current_user_id), +): + """ + Gửi câu hỏi giải toán trong một session (Submit geometry problem in a session). + Lưu câu hỏi vào history và bắt đầu tiến trình giải (chỉ giải toán và tạo hình tĩnh). + """ + supabase = get_supabase() + uid = str(user_id) + _assert_session_owner(supabase, session_id, user_id, uid, "owner_check") + + message_metadata = {"image_url": request.image_url} if request.image_url else {} + job_id = str(uuid.uuid4()) + return _enqueue_solve_common( + supabase, + background_tasks, + session_id, + user_id, + uid, + request, + message_metadata, + job_id, + ) + + +@router.post("/{session_id}/solve_multipart", response_model=SolveResponse) +async def solve_multipart( + session_id: str, + background_tasks: BackgroundTasks, + user_id=Depends(get_current_user_id), + text: str = Form(...), + file: UploadFile = File(...), +): + """ + Gửi text + file ảnh trong một request multipart: validate, upload bucket `image`, + ghi session_assets, lưu message kèm metadata (URL, size, type), rồi enqueue solve + (image_url trỏ public URL để orchestrator OCR). + """ + supabase = get_supabase() + uid = str(user_id) + _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_solve_multipart") + + t = (text or "").strip() + if not t: + raise HTTPException(status_code=400, detail="text must not be empty.") + + body = await file.read() + ext, content_type = validate_chat_image_bytes(file.filename, body, file.content_type) + + job_id = str(uuid.uuid4()) + up = upload_session_chat_image(session_id, job_id, body, ext, content_type) + public_url = up["public_url"] + + message_metadata = { + "image_url": public_url, + "attachment": { + "public_url": public_url, + "storage_path": up["storage_path"], + "size_bytes": len(body), + "content_type": content_type, + "original_filename": file.filename or "", + "session_asset_id": up.get("session_asset_id"), + }, + } + request = SolveRequest(text=t, image_url=public_url) + return _enqueue_solve_common( + supabase, + background_tasks, + session_id, + user_id, + uid, + request, + message_metadata, + job_id, + ) + + +@router.post("/{session_id}/render_video", response_model=RenderVideoResponse) +async def render_video( + session_id: str, + request: RenderVideoRequest, + background_tasks: BackgroundTasks, + user_id=Depends(get_current_user_id), +): + """ + Yêu cầu tạo video Manim từ trạng thái hình ảnh mới nhất của session. + """ + supabase = get_supabase() + + # 1. Kiểm tra quyền sở hữu + res = supabase.table("sessions").select("id").eq("id", session_id).eq("user_id", user_id).execute() + if not res.data: + raise HTTPException(status_code=403, detail="Forbidden: You do not own this session.") + + # 2. Tìm tin nhắn assistant có metadata hình học (cụ thể job_id hoặc mới nhất trong 10 tin nhắn gần nhất) + msg_res = ( + supabase.table("messages") + .select("metadata") + .eq("session_id", session_id) + .eq("role", "assistant") + .order("created_at", desc=True) + .limit(10) + .execute() + ) + + latest_geometry = None + if msg_res.data: + for msg in msg_res.data: + meta = msg.get("metadata", {}) + # Nếu có yêu cầu job_id cụ thể, phải khớp job_id + if request.job_id and meta.get("job_id") != request.job_id: + continue + + # Phải có dữ liệu hình học + if meta.get("geometry_dsl") and meta.get("coordinates"): + latest_geometry = meta + break + + if not latest_geometry: + raise HTTPException(status_code=404, detail="Không tìm thấy dữ liệu hình học để render video.") + + # 3. Tạo Job rendering + job_id = str(uuid.uuid4()) + supabase.table("jobs").insert({ + "id": job_id, + "user_id": user_id, + "session_id": session_id, + "status": "rendering_queued", + "input_text": f"Render video requested at {job_id}", + }).execute() + + # 4. Dispatch background task + background_tasks.add_task(process_render_job, job_id, session_id, latest_geometry) + + return RenderVideoResponse(job_id=job_id, status="rendering_queued") + + +async def process_session_job( + job_id: str, session_id: str, request: SolveRequest, user_id: str +): + """Tiến trình giải toán ngầm, tạo hình ảnh tĩnh.""" + from app.websocket_manager import notify_status + + async def status_update(status: str): + await notify_status(job_id, {"status": status, "job_id": job_id}) + + supabase = get_supabase() + try: + history_res = ( + supabase.table("messages") + .select("*") + .eq("session_id", session_id) + .order("created_at", desc=False) + .execute() + ) + history = history_res.data if history_res.data else [] + + result = await get_orchestrator().run( + request.text, + request.image_url, + job_id=job_id, + session_id=session_id, + status_callback=status_update, + history=history, + ) + + status = result.get("status", "error") if "error" not in result else "error" + + supabase.table("jobs").update({"status": status, "result": result}).eq( + "id", job_id + ).execute() + + supabase.table("messages").insert( + { + "session_id": session_id, + "role": "assistant", + "type": "analysis" if "error" not in result else "error", + "content": ( + result.get("semantic_analysis", "Đã có lỗi xảy ra.") + if "error" not in result + else result["error"] + ), + "metadata": { + "job_id": job_id, + "coordinates": result.get("coordinates"), + "geometry_dsl": result.get("geometry_dsl"), + "polygon_order": result.get("polygon_order", []), + "drawing_phases": result.get("drawing_phases", []), + "circles": result.get("circles", []), + "lines": result.get("lines", []), + "rays": result.get("rays", []), + "solution": result.get("solution"), + "is_3d": result.get("is_3d", False), + }, + } + ).execute() + + await notify_status(job_id, {"status": status, "job_id": job_id, "result": result}) + + except Exception as e: + logger.exception("Error processing session job %s", job_id) + error_msg = format_error_for_user(e) + supabase = get_supabase() + supabase.table("jobs").update( + {"status": "error", "result": {"error": str(e)}} + ).eq("id", job_id).execute() + supabase.table("messages").insert( + { + "session_id": session_id, + "role": "assistant", + "type": "error", + "content": error_msg, + "metadata": {"job_id": job_id}, + } + ).execute() + await notify_status(job_id, {"status": "error", "job_id": job_id, "error": error_msg}) + +async def process_render_job(job_id: str, session_id: str, geometry_data: dict): + """Tiến trình render video từ metadata có sẵn.""" + from app.websocket_manager import notify_status + from worker.tasks import render_geometry_video + + await notify_status(job_id, {"status": "rendering_queued", "job_id": job_id}) + + # Prepare payload for Celery (similar to what orchestrator used to do) + result_payload = { + "geometry_dsl": geometry_data.get("geometry_dsl"), + "coordinates": geometry_data.get("coordinates"), + "polygon_order": geometry_data.get("polygon_order", []), + "drawing_phases": geometry_data.get("drawing_phases", []), + "circles": geometry_data.get("circles", []), + "lines": geometry_data.get("lines", []), + "rays": geometry_data.get("rays", []), + "semantic": geometry_data.get("semantic", {}), + "semantic_analysis": geometry_data.get("semantic_analysis", "🎬 Video minh họa dựng từ trạng thái gần nhất."), + "session_id": session_id, + } + + try: + logger.info(f"[RenderJob] Attempting to dispatch Celery task for job {job_id}...") + render_geometry_video.delay(job_id, result_payload) + logger.info(f"[RenderJob] SUCCESS: Dispatched Celery task for job {job_id}") + except Exception as e: + logger.exception(f"[RenderJob] FAILED to dispatch Celery task: {e}") + supabase = get_supabase() + supabase.table("jobs").update({"status": "error", "result": {"error": f"Task dispatch failed: {str(e)}"}}).eq("id", job_id).execute() + await notify_status(job_id, {"status": "error", "job_id": job_id, "error": str(e)}) diff --git a/app/runtime_env.py b/app/runtime_env.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6087960acbfb3e0a25987103dec8179598e0f2 --- /dev/null +++ b/app/runtime_env.py @@ -0,0 +1,12 @@ +"""Default process env vars (Paddle/OpenMP). Call as early as possible after load_dotenv.""" + +from __future__ import annotations + +import os + + +def apply_runtime_env_defaults() -> None: + # Paddle respects OMP_NUM_THREADS at import; setdefault loses if platform already set 2+ + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + os.environ["OPENBLAS_NUM_THREADS"] = "1" diff --git a/app/session_cache.py b/app/session_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9a5faf70b8d44d0928f7720cc5aed84bc285bc --- /dev/null +++ b/app/session_cache.py @@ -0,0 +1,48 @@ +"""TTL in-memory cache để giảm truy vấn Supabase lặp lại (list session, quyền sở hữu session).""" + +from __future__ import annotations + +from typing import Any, Callable + +from cachetools import TTLCache + +from app.logutil import log_step + +_session_list: TTLCache[str, list[Any]] = TTLCache(maxsize=512, ttl=45) +_session_owner: TTLCache[tuple[str, str], bool] = TTLCache(maxsize=4096, ttl=45) + + +def invalidate_for_user(user_id: str) -> None: + """Xoá cache list session của user (sau create / delete / rename / solve đổi title).""" + _session_list.pop(user_id, None) + log_step("cache_invalidate", target="session_list", user_id=user_id) + + +def invalidate_session_owner(session_id: str, user_id: str) -> None: + _session_owner.pop((session_id, user_id), None) + log_step("cache_invalidate", target="session_owner", session_id=session_id, user_id=user_id) + + +def get_sessions_list_cached(user_id: str, fetch: Callable[[], list[Any]]) -> list[Any]: + if user_id in _session_list: + log_step("cache_hit", kind="session_list", user_id=user_id) + return _session_list[user_id] + log_step("cache_miss", kind="session_list", user_id=user_id) + data = fetch() + _session_list[user_id] = data + return data + + +def session_owned_by_user( + session_id: str, + user_id: str, + fetch: Callable[[], bool], +) -> bool: + key = (session_id, user_id) + if key in _session_owner: + log_step("cache_hit", kind="session_owner", session_id=session_id) + return _session_owner[key] + log_step("cache_miss", kind="session_owner", session_id=session_id) + ok = fetch() + _session_owner[key] = ok + return ok diff --git a/app/supabase_client.py b/app/supabase_client.py new file mode 100644 index 0000000000000000000000000000000000000000..da3eab3c4e47c7c9f0d842a6607ac04d38c40901 --- /dev/null +++ b/app/supabase_client.py @@ -0,0 +1,37 @@ +import os +from supabase import Client, ClientOptions, create_client +from supabase_auth import SyncMemoryStorage +from dotenv import load_dotenv + +load_dotenv() + +from app.url_utils import sanitize_env + + +def get_supabase() -> Client: + """Service-role client for server-side operations (bypasses RLS when policies expect service role).""" + url = sanitize_env(os.getenv("SUPABASE_URL")) + key = sanitize_env(os.getenv("SUPABASE_SERVICE_ROLE_KEY") or os.getenv("SUPABASE_KEY")) + if not url or not key: + raise RuntimeError( + "SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY (or SUPABASE_KEY) must be set" + ) + return create_client(url, key) + + +def get_supabase_for_user_jwt(access_token: str) -> Client: + """ + Client scoped to the logged-in user: PostgREST sends the user's JWT so RLS applies. + Use SUPABASE_ANON_KEY (publishable), not the service role key. + """ + url = sanitize_env(os.getenv("SUPABASE_URL")) + anon = sanitize_env(os.getenv("SUPABASE_ANON_KEY") or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY")) + if not url or not anon: + raise RuntimeError( + "SUPABASE_URL and SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY) must be set " + "for user-scoped Supabase access" + ) + base_opts = ClientOptions(storage=SyncMemoryStorage()) + merged_headers = {**dict(base_opts.headers), "Authorization": f"Bearer {access_token}"} + opts = ClientOptions(storage=SyncMemoryStorage(), headers=merged_headers) + return create_client(url, anon, opts) diff --git a/app/url_utils.py b/app/url_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a37dde695e90efaa585bfeb16c74a369d7f7564 --- /dev/null +++ b/app/url_utils.py @@ -0,0 +1,23 @@ +"""Normalize URLs / env strings (HF secrets and copy-paste often include trailing newlines).""" + + +def sanitize_url(value: str | None) -> str | None: + if value is None: + return None + s = value.strip().replace("\r", "").replace("\n", "").replace("\t", "") + return s or None + + +def sanitize_env(value: str | None) -> str | None: + """Strip whitespace and line breaks from environment-backed strings.""" + return sanitize_url(value) + + +# OpenAI SDK (>=1.x) requires a non-empty api_key at client construction (Docker build / prewarm has no secrets). +_OPENAI_API_KEY_BUILD_PLACEHOLDER = "build-placeholder-openrouter-not-for-production" + + +def openai_compatible_api_key(raw: str | None) -> str: + """Return sanitized API key, or a placeholder so AsyncOpenAI() can be constructed without env at build time.""" + k = sanitize_env(raw) + return k if k else _OPENAI_API_KEY_BUILD_PLACEHOLDER diff --git a/app/websocket_manager.py b/app/websocket_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..56dd602d027ec4fec9566508168217ee4827362e --- /dev/null +++ b/app/websocket_manager.py @@ -0,0 +1,40 @@ +"""WebSocket connection registry and job status notifications (avoid circular imports with main).""" + +from __future__ import annotations + +import logging +from typing import Dict, List + +from fastapi import WebSocket, WebSocketDisconnect + +logger = logging.getLogger(__name__) + +active_connections: Dict[str, List[WebSocket]] = {} + + +async def notify_status(job_id: str, data: dict) -> None: + if job_id not in active_connections: + return + for connection in list(active_connections[job_id]): + try: + await connection.send_json(data) + except Exception as e: + logger.error("WS error sending to %s: %s", job_id, e) + + +def register_websocket_routes(app) -> None: + """Attach websocket endpoint to the FastAPI app.""" + + @app.websocket("/ws/{job_id}") + async def websocket_endpoint(websocket: WebSocket, job_id: str) -> None: + await websocket.accept() + if job_id not in active_connections: + active_connections[job_id] = [] + active_connections[job_id].append(websocket) + try: + while True: + await websocket.receive_text() + except WebSocketDisconnect: + active_connections[job_id].remove(websocket) + if not active_connections[job_id]: + del active_connections[job_id] diff --git a/clean_ports.sh b/clean_ports.sh new file mode 100755 index 0000000000000000000000000000000000000000..b31e620274c7f7e20ff86c6326ae06093ba2ada3 --- /dev/null +++ b/clean_ports.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Script to kill all project-related processes for a clean restart + +echo "🧹 Cleaning up project processes..." + +# Kill things on ports 8000 (Backend) and 3000 (Frontend) +PORTS="8000 3000 11020" +for PORT in $PORTS; do + PIDS=$(lsof -ti :$PORT) + if [ ! -z "$PIDS" ]; then + echo "Killing processes on port $PORT: $PIDS" + kill -9 $PIDS 2>/dev/null + fi +done + +# Kill by process name +echo "Killing any remaining Celery, Uvicorn, or Manim processes..." +pkill -9 -f "celery" 2>/dev/null +pkill -9 -f "uvicorn" 2>/dev/null +pkill -9 -f "manim" 2>/dev/null + +echo "✅ Done. You can now restart your Backend, Worker, and Frontend." diff --git a/dump.rdb b/dump.rdb new file mode 100644 index 0000000000000000000000000000000000000000..e00e303a4ee6a585ba063b400a5b957ac00fb3e3 Binary files /dev/null and b/dump.rdb differ diff --git a/geometry_render/__init__.py b/geometry_render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e8a738f3506634d1227ac4b358ffe10def98c7 --- /dev/null +++ b/geometry_render/__init__.py @@ -0,0 +1,5 @@ +"""Manim geometry script generation and rendering (worker-safe, no LLM agents).""" + +from .renderer import RendererAgent + +__all__ = ["RendererAgent"] diff --git a/geometry_render/renderer.py b/geometry_render/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..03a03dfab0f2f0870d11787553e027aa28c73459 --- /dev/null +++ b/geometry_render/renderer.py @@ -0,0 +1,265 @@ +import os +import subprocess +import glob +import string +import logging +from typing import Dict, Any, List + +logger = logging.getLogger(__name__) + + +class RendererAgent: + """ + Renderer — generates Manim scripts from geometry data. + + Drawing happens in phases: + Phase 1: Main polygon (base shape with correct vertex order) + Phase 2: Auxiliary points and segments (midpoints, derived segments) + Phase 3: Labels for all points + """ + + def generate_manim_script(self, data: Dict[str, Any]) -> str: + coords: Dict[str, List[float]] = data.get("coordinates", {}) + polygon_order: List[str] = data.get("polygon_order", []) + circles_meta: List[Dict] = data.get("circles", []) + lines_meta: List[List[str]] = data.get("lines", []) + rays_meta: List[List[str]] = data.get("rays", []) + drawing_phases: List[Dict] = data.get("drawing_phases", []) + semantic: Dict[str, Any] = data.get("semantic", {}) + shape_type = semantic.get("type", "").lower() + + # ── Detect 3D Context ──────────────────────────────────────────────── + is_3d = False + for pos in coords.values(): + if len(pos) >= 3 and abs(pos[2]) > 0.001: + is_3d = True + break + if shape_type in ["pyramid", "prism", "sphere"]: + is_3d = True + + # ── Fallback: infer polygon_order from coords keys (alphabetical uppercase) ── + if not polygon_order: + base = sorted( + [pid for pid in coords if pid in string.ascii_uppercase], + key=lambda p: string.ascii_uppercase.index(p) + ) + polygon_order = base + + # Separate base points from derived (multi-char or lowercase) + base_ids = [pid for pid in polygon_order if pid in coords] + derived_ids = [pid for pid in coords if pid not in polygon_order] + + scene_base = "ThreeDScene" if is_3d else "MovingCameraScene" + lines = [ + "from manim import *", + "", + f"class GeometryScene({scene_base}):", + " def construct(self):", + ] + + if is_3d: + lines.append(" # 3D Setup") + lines.append(" self.set_camera_orientation(phi=75*DEGREES, theta=-45*DEGREES)") + lines.append(" axes = ThreeDAxes(axis_config={'stroke_width': 1})") + lines.append(" axes.set_opacity(0.3)") + lines.append(" self.add(axes)") + lines.append(" self.begin_ambient_camera_rotation(rate=0.1)") + lines.append("") + + # ── Declare all dots and labels ─────────────────────────────────────── + for pid, pos in coords.items(): + x, y, z = 0, 0, 0 + if len(pos) >= 1: x = round(pos[0], 4) + if len(pos) >= 2: y = round(pos[1], 4) + if len(pos) >= 3: z = round(pos[2], 4) + + dot_class = "Dot3D" if is_3d else "Dot" + lines.append(f" p_{pid} = {dot_class}(point=[{x}, {y}, {z}], color=WHITE, radius=0.08)") + + if is_3d: + lines.append( + f" l_{pid} = Text('{pid}', font_size=20, color=WHITE)" + f".move_to(p_{pid}.get_center() + [0.2, 0.2, 0.2])" + ) + # Ensure labels follow camera in 3D (fixed orientation) + lines.append(f" self.add_fixed_orientation_mobjects(l_{pid})") + else: + lines.append( + f" l_{pid} = Text('{pid}', font_size=22, color=WHITE)" + f".next_to(p_{pid}, UR, buff=0.15)" + ) + + # ── 3D Shape Special: Pyramid/Prism Faces ──────────────────────────── + if is_3d and shape_type == "pyramid" and len(base_ids) >= 3: + # Find apex (usually 'S') + apex_id = "S" if "S" in coords else derived_ids[0] if derived_ids else None + if apex_id: + # Draw base face + base_pts = ", ".join([f"p_{pid}.get_center()" for pid in base_ids]) + lines.append(f" base_face = Polygon({base_pts}, color=BLUE, fill_opacity=0.1)") + lines.append(" self.play(Create(base_face), run_time=1.0)") + + # Draw side faces + for i in range(len(base_ids)): + p1 = base_ids[i] + p2 = base_ids[(i + 1) % len(base_ids)] + face_pts = f"p_{apex_id}.get_center(), p_{p1}.get_center(), p_{p2}.get_center()" + lines.append( + f" side_{i} = Polygon({face_pts}, color=BLUE, stroke_width=1, fill_opacity=0.05)" + ) + lines.append(f" self.play(Create(side_{i}), run_time=0.5)") + + # ── Circles ────────────────────────────────────────────────────────── + for i, c in enumerate(circles_meta): + center = c["center"] + r = c["radius"] + if center in coords: + cx, cy, cz = 0, 0, 0 + pos = coords[center] + if len(pos) >= 1: cx = round(pos[0], 4) + if len(pos) >= 2: cy = round(pos[1], 4) + if len(pos) >= 3: cz = round(pos[2], 4) + lines.append( + f" circle_{i} = Circle(radius={r}, color=BLUE)" + f".move_to([{cx}, {cy}, {cz}])" + ) + + # ── Infinite Lines & Rays ──────────────────────────────────────────── + # (Standard Line works for 3D coordinates in Manim) + for i, (p1, p2) in enumerate(lines_meta): + if p1 in coords and p2 in coords: + lines.append( + f" line_ext_{i} = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=GRAY_D, stroke_width=2)" + f".scale(20)" + ) + + for i, (p1, p2) in enumerate(rays_meta): + if p1 in coords and p2 in coords: + lines.append( + f" ray_{i} = Line(p_{p1}.get_center(), p_{p1}.get_center() + 15 * (p_{p2}.get_center() - p_{p1}.get_center())," + f" color=GRAY_C, stroke_width=2)" + ) + + # ── Camera auto-fit group (Only for 2D) ────────────────────────────── + if not is_3d: + all_dot_names = [f"p_{pid}" for pid in coords] + all_names_str = ", ".join(all_dot_names) + lines.append(f" _all = VGroup({all_names_str})") + lines.append(" self.camera.frame.set_width(max(_all.width * 2.0, 8))") + lines.append(" self.camera.frame.move_to(_all)") + lines.append("") + + # ── Phase 1: Base polygon ───────────────────────────────────────────── + if len(base_ids) >= 3: + pts_str = ", ".join([f"p_{pid}.get_center()" for pid in base_ids]) + lines.append(f" poly = Polygon({pts_str}, color=BLUE, fill_color=BLUE, fill_opacity=0.15)") + lines.append(" self.play(Create(poly), run_time=1.5)") + elif len(base_ids) == 2: + p1, p2 = base_ids + lines.append(f" base_line = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=BLUE)") + lines.append(" self.play(Create(base_line), run_time=1.0)") + + # Draw base points + if base_ids: + base_dots_str = ", ".join([f"p_{pid}" for pid in base_ids]) + lines.append(f" self.play(FadeIn(VGroup({base_dots_str})), run_time=0.5)") + lines.append(" self.wait(0.5)") + + # ── Phase 2: Auxiliary points and segments ──────────────────────────── + if derived_ids: + derived_dots_str = ", ".join([f"p_{pid}" for pid in derived_ids]) + lines.append(f" self.play(FadeIn(VGroup({derived_dots_str})), run_time=0.8)") + + # Segments from drawing_phases + segment_lines = [] + for phase in drawing_phases: + if phase.get("phase") == 2: + for seg in phase.get("segments", []): + if len(seg) == 2 and seg[0] in coords and seg[1] in coords: + p1, p2 = seg[0], seg[1] + seg_var = f"seg_{p1}_{p2}" + lines.append( + f" {seg_var} = Line(p_{p1}.get_center(), p_{p2}.get_center()," + f" color=YELLOW)" + ) + segment_lines.append(seg_var) + + if segment_lines: + segs_str = ", ".join([f"Create({sv})" for sv in segment_lines]) + lines.append(f" self.play({segs_str}, run_time=1.2)") + + if derived_ids or segment_lines: + lines.append(" self.wait(0.5)") + + # ── Phase 3: All labels ─────────────────────────────────────────────── + all_labels_str = ", ".join([f"l_{pid}" for pid in coords]) + lines.append(f" self.play(FadeIn(VGroup({all_labels_str})), run_time=0.8)") + + # ── Circles phase ───────────────────────────────────────────────────── + for i in range(len(circles_meta)): + lines.append(f" self.play(Create(circle_{i}), run_time=1.5)") + + # ── Lines & Rays phase ──────────────────────────────────────────────── + if lines_meta or rays_meta: + lr_anims = [] + for i in range(len(lines_meta)): + lr_anims.append(f"Create(line_ext_{i})") + for i in range(len(rays_meta)): + lr_anims.append(f"Create(ray_{i})") + lines.append(f" self.play({', '.join(lr_anims)}, run_time=1.5)") + + lines.append(" self.wait(2)") + + return "\n".join(lines) + + def run_manim(self, script_content: str, job_id: str) -> str: + script_file = f"{job_id}.py" + with open(script_file, "w") as f: + f.write(script_content) + + try: + if os.getenv("MOCK_VIDEO") == "true": + logger.info(f"MOCK_VIDEO is true. Skipping Manim for job {job_id}") + # Create a dummy file if needed, or just return a path that exists + dummy_path = f"videos/{job_id}.mp4" + os.makedirs("videos", exist_ok=True) + with open(dummy_path, "wb") as f: + f.write(b"dummy video content") + return dummy_path + + # Determine manim executable path + manim_exe = "manim" + venv_manim = os.path.join(os.getcwd(), "venv", "bin", "manim") + if os.path.exists(venv_manim): + manim_exe = venv_manim + + # Prepare environment with homebrew paths + custom_env = os.environ.copy() + brew_path = "/opt/homebrew/bin:/usr/local/bin" + custom_env["PATH"] = f"{brew_path}:{custom_env.get('PATH', '')}" + + logger.info(f"Running {manim_exe} for job {job_id}...") + result = subprocess.run( + [manim_exe, "-ql", "--media_dir", ".", "-o", f"{job_id}.mp4", script_file, "GeometryScene"], + capture_output=True, + text=True, + env=custom_env, + ) + logger.info(f"Manim STDOUT: {result.stdout}") + if result.returncode != 0: + logger.error(f"Manim STDERR: {result.stderr}") + + for pattern in [f"**/videos/**/{job_id}.mp4", f"**/{job_id}*.mp4"]: + found = glob.glob(pattern, recursive=True) + if found: + logger.info(f"Manim Success: Found {found[0]}") + return found[0] + + logger.error(f"Manim file not found for job {job_id}. Return code: {result.returncode}") + return "" + except Exception as e: + logger.exception(f"Manim Execution Error: {e}") + return "" + finally: + if os.path.exists(script_file): + os.remove(script_file) diff --git a/migrations/add_image_bucket_storage.sql b/migrations/add_image_bucket_storage.sql new file mode 100644 index 0000000000000000000000000000000000000000..2b51fdd44cc1bcabb9d3ac3f5db1b4f89ec97d5d --- /dev/null +++ b/migrations/add_image_bucket_storage.sql @@ -0,0 +1,35 @@ +-- ============================================================ +-- MathSolver: Supabase Storage bucket `image` (chat / OCR attachments) +-- Run after session_assets and storage.video policies exist. +-- ============================================================ + +INSERT INTO storage.buckets (id, name, public) +VALUES ('image', 'image', true) +ON CONFLICT (id) DO UPDATE SET public = true; + +-- Service role: upload/delete/list for API + workers +DROP POLICY IF EXISTS "Service Role manage images" ON storage.objects; +CREATE POLICY "Service Role manage images" ON storage.objects + FOR ALL + TO service_role + USING (bucket_id = 'image') + WITH CHECK (bucket_id = 'image'); + +-- Authenticated: read only objects under sessions they own (path sessions/{session_id}/...) +DROP POLICY IF EXISTS "Users view session images" ON storage.objects; +CREATE POLICY "Users view session images" ON storage.objects + FOR SELECT + TO authenticated + USING ( + bucket_id = 'image' + AND (storage.foldername(name))[2] IN ( + SELECT id::text FROM public.sessions WHERE user_id = auth.uid() + ) + ); + +-- Public read for get_public_url / FE img tags (same model as video bucket) +DROP POLICY IF EXISTS "Public read images" ON storage.objects; +CREATE POLICY "Public read images" ON storage.objects + FOR SELECT + TO public + USING (bucket_id = 'image'); diff --git a/migrations/fix_rls_assets.sql b/migrations/fix_rls_assets.sql new file mode 100644 index 0000000000000000000000000000000000000000..3cb09b6427c84d7b28f705bfac05b1815db26328 --- /dev/null +++ b/migrations/fix_rls_assets.sql @@ -0,0 +1,96 @@ +-- ============================================================ +-- FIX RLS & SESSION ASSETS (MathSolver v5.1 Worker Fix) +-- ============================================================ + +-- 1. Ensure session_assets table exists +CREATE TABLE IF NOT EXISTS public.session_assets ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE, + job_id UUID NOT NULL, + asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')), + storage_path TEXT NOT NULL, + public_url TEXT NOT NULL, + version INTEGER NOT NULL DEFAULT 1, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Index for session_assets +CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id); +CREATE INDEX IF NOT EXISTS idx_session_assets_type ON public.session_assets(session_id, asset_type); + +-- 2. Enable RLS for all tables +ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY; + + +-- 3. Fix Table Policies to allow SERVICE ROLE +-- In Supabase, service_role usually bypasses RLS, but we add explicit policies for safety +-- especially for path-based checks or when SECURITY DEFINER functions are used. + +-- [Session Assets] +DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets; +CREATE POLICY "Users view own assets" ON public.session_assets + FOR SELECT USING ( + session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid()) + ); + +DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets; +CREATE POLICY "Service role manages assets" ON public.session_assets + FOR ALL USING (true) + WITH CHECK (true); + + +-- [Messages] - Allow Worker to insert assistant messages +DROP POLICY IF EXISTS "Users manage own messages" ON public.messages; +CREATE POLICY "Users manage own messages" ON public.messages + FOR ALL USING ( + session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid()) + OR + (auth.jwt() ->> 'role' = 'service_role') + ); + + +-- [Jobs] - Allow Worker to update job status +DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs; +CREATE POLICY "Users manage own jobs" ON public.jobs + FOR ALL USING ( + auth.uid() = user_id + OR user_id IS NULL + OR (auth.jwt() ->> 'role' = 'service_role') + ); + + +-- 4. Storage Policies (Bucket: video) +-- Ensure 'video' bucket exists +INSERT INTO storage.buckets (id, name, public) +VALUES ('video', 'video', true) +ON CONFLICT (id) DO UPDATE SET public = true; + +-- [Storage: Worker / Service Role] - Allow all in video bucket +DROP POLICY IF EXISTS "Service Role manage videos" ON storage.objects; +CREATE POLICY "Service Role manage videos" ON storage.objects + FOR ALL + TO service_role + USING (bucket_id = 'video'); + +-- [Storage: Users] - Allow users to view their session videos +DROP POLICY IF EXISTS "Users view session videos" ON storage.objects; +CREATE POLICY "Users view session videos" ON storage.objects + FOR SELECT + TO authenticated + USING ( + bucket_id = 'video' + AND (storage.foldername(name))[2] IN ( + SELECT id::text FROM public.sessions WHERE user_id = auth.uid() + ) + ); + +-- [Storage: Public] - Allow public read access to videos +DROP POLICY IF EXISTS "Public read videos" ON storage.objects; +CREATE POLICY "Public read videos" ON storage.objects + FOR SELECT + TO public + USING (bucket_id = 'video'); diff --git a/migrations/v4_migration.sql b/migrations/v4_migration.sql new file mode 100644 index 0000000000000000000000000000000000000000..e3c4c7ab8689de5952fd4dba04210873e2659042 --- /dev/null +++ b/migrations/v4_migration.sql @@ -0,0 +1,131 @@ +-- ============================================================ +-- MATHSOLVER v4.0 - Migration Script (Multi-Session & History) +-- ============================================================ + +-- 1. Profiles Table (Extends Supabase Auth) +CREATE TABLE IF NOT EXISTS public.profiles ( + id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE, + display_name TEXT, + avatar_url TEXT, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Function to handle new user signup and auto-create profile +CREATE OR REPLACE FUNCTION public.handle_new_user() +RETURNS TRIGGER AS $$ +BEGIN + INSERT INTO public.profiles (id, display_name, avatar_url) + VALUES ( + NEW.id, + COALESCE(NEW.raw_user_meta_data->>'full_name', NEW.email), + NEW.raw_user_meta_data->>'avatar_url' + ); + RETURN NEW; +END; +$$ LANGUAGE plpgsql SECURITY DEFINER; + +-- Trigger for profile creation +DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users; +CREATE TRIGGER on_auth_user_created + AFTER INSERT ON auth.users + FOR EACH ROW EXECUTE FUNCTION public.handle_new_user(); + +-- 2. Sessions Table +CREATE TABLE IF NOT EXISTS public.sessions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE, + title TEXT DEFAULT 'Bài toán mới', + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Index for sessions +CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON public.sessions(user_id); +CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON public.sessions(updated_at DESC); + +-- 3. Messages Table +CREATE TABLE IF NOT EXISTS public.messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE, + role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')), + type TEXT NOT NULL DEFAULT 'text', + content TEXT NOT NULL, + metadata JSONB DEFAULT '{}'::jsonb, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Index for messages +CREATE INDEX IF NOT EXISTS idx_messages_session_id ON public.messages(session_id); +CREATE INDEX IF NOT EXISTS idx_messages_created_at ON public.messages(session_id, created_at); + +-- 4. Session Assets Table (v5.1 Versioning) +CREATE TABLE IF NOT EXISTS public.session_assets ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE, + job_id UUID NOT NULL, + asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')), + storage_path TEXT NOT NULL, + public_url TEXT NOT NULL, + version INTEGER NOT NULL DEFAULT 1, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Index for session_assets +CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id); + +-- 5. Update Jobs Table +ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS user_id UUID REFERENCES auth.users(id); +ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS session_id UUID REFERENCES public.sessions(id); + +-- 6. Row Level Security (RLS) +ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY; +ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY; + +-- Polices for public.profiles +DROP POLICY IF EXISTS "Users view own profile" ON public.profiles; +CREATE POLICY "Users view own profile" ON public.profiles FOR SELECT USING (auth.uid() = id); +DROP POLICY IF EXISTS "Users update own profile" ON public.profiles; +CREATE POLICY "Users update own profile" ON public.profiles FOR UPDATE USING (auth.uid() = id); + +-- Policies for public.sessions +DROP POLICY IF EXISTS "Users manage own sessions" ON public.sessions; +CREATE POLICY "Users manage own sessions" ON public.sessions FOR ALL USING (auth.uid() = user_id); + +-- Policies for public.messages +DROP POLICY IF EXISTS "Users manage own messages" ON public.messages; +CREATE POLICY "Users manage own messages" ON public.messages FOR ALL USING ( + session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid()) + OR (auth.jwt() ->> 'role' = 'service_role') +); + +-- Policies for public.session_assets +DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets; +CREATE POLICY "Users view own assets" ON public.session_assets FOR SELECT USING ( + session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid()) +); +DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets; +CREATE POLICY "Service role manages assets" ON public.session_assets FOR ALL USING (true); + +-- Policies for public.jobs +DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs; +CREATE POLICY "Users manage own jobs" ON public.jobs FOR ALL USING ( + auth.uid() = user_id OR user_id IS NULL OR (auth.jwt() ->> 'role' = 'service_role') +); + +-- 7. Storage Policies (Bucket: video) +-- (Run this in Supabase Dashboard if not allowed in migration) +-- INSERT INTO storage.buckets (id, name, public) VALUES ('video', 'video', true) ON CONFLICT (id) DO NOTHING; +-- CREATE POLICY "Service Role manage videos" ON storage.objects FOR ALL TO service_role USING (bucket_id = 'video'); +-- CREATE POLICY "Public read videos" ON storage.objects FOR SELECT TO public USING (bucket_id = 'video'); + +-- Grant permissions to public/authenticated +GRANT ALL ON public.profiles TO authenticated; +GRANT ALL ON public.sessions TO authenticated; +GRANT ALL ON public.messages TO authenticated; +GRANT ALL ON public.jobs TO authenticated; +GRANT ALL ON public.session_assets TO authenticated; +GRANT ALL ON public.session_assets TO service_role; diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..cff9e594f74c96e369757cad9ee634bcb50a6ea9 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,18 @@ +[pytest] +asyncio_mode = auto +testpaths = tests +pythonpath = . +filterwarnings = + ignore::DeprecationWarning + +markers = + real_api: HTTP tests need running backend and TEST_USER_ID / TEST_SESSION_ID. + real_worker_ocr: OCR Celery task or full OCR stack (heavy). + real_worker_manim: Real Manim render and Supabase video upload. + real_agents: Live LLM / orchestrator agent calls. + slow: Large suite or long polling timeouts. + smoke: Fast API health + one solve job. + orchestrator_local: In-process Orchestrator without HTTP server. + +# Default: skip integration tests that need services, keys, or long runs. +addopts = -m "not real_api and not real_worker_ocr and not real_worker_manim and not real_agents and not slow and not orchestrator_local" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6a519f87ed293cefe850664bf486ad47ade8bdfd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,38 @@ +# Target: Python 3.11 (see Dockerfile). Used by: FastAPI API, Celery worker, Manim render, OCR/vision stack. +# Install: pip install -r requirements.txt + +# --- Dev / test --- +pytest>=8.0 +pytest-asyncio>=0.24 + +# --- HTTP API --- +cachetools>=5.3 +fastapi>=0.115,<1 +uvicorn[standard]>=0.30 +python-multipart>=0.0.9 +python-dotenv>=1.0 +pydantic[email]>=2.4 +email-validator>=2 + +# --- Auth / data / queue --- +openai>=1.40 +supabase>=2.0 +celery>=5.3 +redis>=5 +httpx>=0.27 +websockets>=12 + +# --- Math & symbolic solver --- +sympy>=1.12 +numpy>=1.26,<2 +scipy>=1.11 +opencv-python-headless>=4.8,<4.10 + +# --- Video (GeometryScene via CLI) --- +manim>=0.18,<0.20 + +# --- OCR & vision (orchestrator / legacy /ocr) --- +pix2tex>=0.1.4 +paddleocr==2.7.3 +paddlepaddle==2.6.2 +ultralytics==8.2.2 diff --git a/requirements.worker-ocr.txt b/requirements.worker-ocr.txt new file mode 100644 index 0000000000000000000000000000000000000000..baba45613164185fffb7c93e35e2bcbe4eafe064 --- /dev/null +++ b/requirements.worker-ocr.txt @@ -0,0 +1,23 @@ +# OCR-only Celery worker: YOLO + PaddleOCR + Pix2Tex (no OpenRouter / no Manim). +# Install: pip install -r requirements.worker-ocr.txt + +cachetools>=5.3 +fastapi>=0.115,<1 +uvicorn[standard]>=0.30 +python-multipart>=0.0.9 +python-dotenv>=1.0 +pydantic[email]>=2.4 +email-validator>=2 + +celery>=5.3 +redis>=5 +httpx>=0.27 +websockets>=12 + +numpy>=1.26,<2 +opencv-python-headless>=4.8,<4.10 + +pix2tex>=0.1.4 +paddleocr==2.7.3 +paddlepaddle==2.6.2 +ultralytics==8.2.2 diff --git a/requirements.worker-render.txt b/requirements.worker-render.txt new file mode 100644 index 0000000000000000000000000000000000000000..2c3be39bc404e3185dc6a79995eb8836298c2d90 --- /dev/null +++ b/requirements.worker-render.txt @@ -0,0 +1,21 @@ +# Celery render worker: Manim + Supabase (no OpenAI / SymPy / OCR vision stack). +# Install: pip install -r requirements.worker-render.txt +# Includes FastAPI/uvicorn for worker_health.py (HF Spaces). + +cachetools>=5.3 +fastapi>=0.115,<1 +uvicorn[standard]>=0.30 +python-multipart>=0.0.9 +python-dotenv>=1.0 +pydantic[email]>=2.4 +email-validator>=2 + +celery>=5.3 +redis>=5 +httpx>=0.27 +websockets>=12 + +supabase>=2.0 + +numpy>=1.26,<2 +manim>=0.18,<0.20 diff --git a/run_api_test.sh b/run_api_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..2d83a73fde7f106e2610ed6c556b2fb916274868 --- /dev/null +++ b/run_api_test.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +LOG_FILE="api_test_results.log" +echo "=== Starting API E2E Test Suite ($(date)) ===" > $LOG_FILE + +# 1. Start BE Server in background +echo "[INFO] Starting Backend Server..." | tee -a $LOG_FILE +export ALLOW_TEST_BYPASS=true +export LOG_LEVEL=info +export CELERY_TASK_ALWAYS_EAGER=true +export CELERY_RESULT_BACKEND=rpc:// +export MOCK_VIDEO=true +PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 & +SERVER_PID=$! + +# 2. Wait for server to be ready +echo "[INFO] Waiting for server (PID: $SERVER_PID) on port 8000..." | tee -a $LOG_FILE +MAX_RETRIES=15 +READY=0 +for i in $(seq 1 $MAX_RETRIES); do + if curl -s http://localhost:8000/ > /dev/null; then + READY=1 + break + fi + sleep 2 +done + +if [ $READY -eq 0 ]; then + echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a $LOG_FILE + kill $SERVER_PID + exit 1 +fi +echo "[INFO] Server is READY." | tee -a $LOG_FILE + +# 3. Prepare Test Data +echo "[INFO] Preparing fresh test data..." | tee -a $LOG_FILE +PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py) +echo "$PREP_OUTPUT" >> $LOG_FILE + +export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2) +export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2) + +if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then + echo "[ERROR] Failed to prepare test data." | tee -a $LOG_FILE + kill $SERVER_PID + exit 1 +fi + +echo "[INFO] Test Data: User=$TEST_USER_ID, Session=$TEST_SESSION_ID" | tee -a $LOG_FILE + +# 4. Run Pytest +echo "[INFO] Running API E2E Tests..." | tee -a $LOG_FILE +PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py -m "smoke and real_api" -s \ + --junitxml=pytest_smoke.xml >> $LOG_FILE 2>&1 +TEST_EXIT_CODE=$? + +# 5. Cleanup +echo "[INFO] Shutting down Server..." | tee -a $LOG_FILE +kill $SERVER_PID + +echo "==========================================" | tee -a $LOG_FILE +if [ $TEST_EXIT_CODE -eq 0 ]; then + echo "FINAL RESULT: ✅ ALL API TESTS PASSED" | tee -a $LOG_FILE +else + echo "FINAL RESULT: ❌ SOME API TESTS FAILED (Code: $TEST_EXIT_CODE)" | tee -a $LOG_FILE +fi +echo "==========================================" | tee -a $LOG_FILE + +exit $TEST_EXIT_CODE diff --git a/run_full_api_test.sh b/run_full_api_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..b0b14659f3880aef45215d749cdfd8ca097f743b --- /dev/null +++ b/run_full_api_test.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Full API integration (CI-style): eager Celery + mock video + full HTTP suite. +LOG_FILE="full_api_suite.log" +REPORT_FILE="full_api_test_report.md" +JSON_RESULTS="temp_suite_results.json" +JUNIT="pytest_api_suite.xml" + +echo "=== Starting Full API Suite Test ($(date)) ===" >"$LOG_FILE" + +trap 'echo "[INFO] Cleaning up processes..."; kill $SERVER_PID 2>/dev/null; sleep 1' EXIT + +echo "[INFO] Starting Backend Server (EAGER + MOCK_VIDEO)..." | tee -a "$LOG_FILE" +export ALLOW_TEST_BYPASS=true +export LOG_LEVEL=info +export CELERY_TASK_ALWAYS_EAGER=true +export CELERY_RESULT_BACKEND=rpc:// +export MOCK_VIDEO=true +PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 >server_debug.log 2>&1 & +SERVER_PID=$! + +echo "[INFO] Waiting for server (PID: $SERVER_PID)..." | tee -a "$LOG_FILE" +for i in {1..20}; do + if curl -s http://localhost:8000/ >/dev/null; then + echo "[INFO] Server is READY." | tee -a "$LOG_FILE" + break + fi + sleep 2 +done + +echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE" +PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py) +export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2) +export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2) + +if [ -z "$TEST_USER_ID" ]; then + echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE" + exit 1 +fi + +echo "[INFO] Executing API tests (smoke + full suite)..." | tee -a "$LOG_FILE" +PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \ + -m "real_api" -s --tb=short --junitxml="$JUNIT" >>"$LOG_FILE" 2>&1 +TEST_EXIT_CODE=$? + +echo "[INFO] Generating Markdown Report..." | tee -a "$LOG_FILE" +if [ -f "$JSON_RESULTS" ]; then + PYTHONPATH=. venv/bin/python scripts/generate_report.py "$JSON_RESULTS" "$REPORT_FILE" "$JUNIT" +else + echo "[WARN] $JSON_RESULTS not found" | tee -a "$LOG_FILE" +fi + +echo "==========================================" | tee -a "$LOG_FILE" +echo "DONE. Check $REPORT_FILE for results." | tee -a "$LOG_FILE" +echo "==========================================" | tee -a "$LOG_FILE" + +exit $TEST_EXIT_CODE diff --git a/scripts/benchmark_openrouter.py b/scripts/benchmark_openrouter.py new file mode 100644 index 0000000000000000000000000000000000000000..7660d48b9e2997bd5bbfffd49b25f0404b4aa809 --- /dev/null +++ b/scripts/benchmark_openrouter.py @@ -0,0 +1,77 @@ +"""Benchmark several OpenRouter models (manual tool; not part of pytest).""" + +from __future__ import annotations + +import json +import os +import time + +import httpx +from dotenv import load_dotenv + +_BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +load_dotenv(os.path.join(_BACKEND_ROOT, ".env")) + +MODELS = [ + "nvidia/nemotron-3-super-120b-a12b:free", + "meta-llama/llama-3.3-70b-instruct:free", + "openai/gpt-oss-120b:free", + "z-ai/glm-4.5-air:free", + "minimax/minimax-m2.5:free", + "google/gemma-4-26b-a4b-it:free", + "google/gemma-4-31b-it:free", +] + +PROMPT = ( + "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10. Gọi E là điểm nằm trong đoạn CD sao cho CE = 2ED. " + "Vẽ đoạn thẳng AE. Vẽ thêm P là điểm nằm trên đường thẳng BC sao cho BP = 2PC, tính chu vi tam giác PEA" +) + + +def main() -> None: + api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY") + base_url = "https://openrouter.ai/api/v1/chat/completions" + + if not api_key: + print("Missing OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY in .env") + return + + print("Benchmark OpenRouter models\nPrompt:", PROMPT, "\n") + results = [] + + for model in MODELS: + print(f"Calling {model}...", end="", flush=True) + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "HTTP-Referer": "https://mathsolver.io", + "X-Title": "MathSolver Benchmark Tool", + } + payload = {"model": model, "messages": [{"role": "user", "content": PROMPT}]} + start = time.time() + try: + with httpx.Client(timeout=120.0) as client: + r = client.post(base_url, headers=headers, json=payload) + r.raise_for_status() + data = r.json() + answer = data["choices"][0]["message"]["content"] + duration = time.time() - start + results.append( + {"model": model, "duration": duration, "answer": answer, "status": "success"} + ) + print(f" OK ({duration:.2f}s)") + except Exception as e: + duration = time.time() - start + results.append( + {"model": model, "duration": duration, "error": str(e), "status": "error"} + ) + print(f" FAIL ({duration:.2f}s) {e}") + + print("\n" + "=" * 80) + for res in results: + print(json.dumps(res, ensure_ascii=False, indent=2)[:2000]) + print("-" * 40) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_report.py b/scripts/generate_report.py new file mode 100644 index 0000000000000000000000000000000000000000..979f2ae12b71dea413a9ded9c182ddeee21b0d1f --- /dev/null +++ b/scripts/generate_report.py @@ -0,0 +1,115 @@ +import json +import os +import sys +import xml.etree.ElementTree as ET +from datetime import datetime + + +def _parse_junit_xml(path: str) -> dict: + """Summarize pytest junitxml (JUnit) file.""" + out = {"tests": 0, "failures": 0, "errors": 0, "skipped": 0, "time": 0.0} + try: + tree = ET.parse(path) + root = tree.getroot() + nodes = [root] if root.tag == "testsuite" else list(root.iter("testsuite")) + for ts in nodes: + if ts.tag != "testsuite": + continue + out["tests"] += int(ts.attrib.get("tests", 0) or 0) + out["failures"] += int(ts.attrib.get("failures", 0) or 0) + out["errors"] += int(ts.attrib.get("errors", 0) or 0) + out["skipped"] += int(ts.attrib.get("skipped", 0) or 0) + out["time"] += float(ts.attrib.get("time", 0) or 0) + except Exception as e: + out["parse_error"] = str(e) + return out + + +def generate_report(json_path: str, report_path: str, junit_path: str | None = None) -> None: + try: + with open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + + junit_summary = None + if junit_path and os.path.isfile(junit_path): + junit_summary = _parse_junit_xml(junit_path) + + with open(report_path, "w", encoding="utf-8") as f: + f.write("# Báo cáo Kiểm thử tích hợp Backend (Integration Report)\n\n") + f.write(f"**Thời gian chạy:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + suite_ok = all(r.get("success", False) for r in data) if isinstance(data, list) else False + f.write(f"**API suite (JSON):** {'PASS' if suite_ok else 'FAIL'}\n") + + if junit_summary and "parse_error" not in junit_summary: + j_ok = junit_summary["failures"] == 0 and junit_summary["errors"] == 0 + f.write( + f"**Pytest (JUnit):** {'PASS' if j_ok else 'FAIL'} — " + f"tests={junit_summary['tests']}, failures={junit_summary['failures']}, " + f"errors={junit_summary['errors']}, skipped={junit_summary['skipped']}, " + f"time_s={junit_summary['time']:.2f}\n" + ) + elif junit_summary and "parse_error" in junit_summary: + f.write(f"**Pytest (JUnit):** (could not parse: {junit_summary['parse_error']})\n") + + f.write("\n") + + f.write("| ID | Câu hỏi (Query) | Trạng thái | Thời gian (s) | Kết quả / Lỗi |\n") + f.write("| :--- | :--- | :--- | :--- | :--- |\n") + for r in data: + status = "PASS" if r.get("success") else "FAIL" + elapsed = f"{float(r.get('elapsed', 0) or 0):.2f}" + query = r.get("query", "-") + + res = r.get("result", {}) + if not isinstance(res, dict): + res = {} + + analysis = res.get("semantic_analysis", "-") + if not r.get("success"): + analysis = f"**Lỗi:** {r.get('error', '-')}" + + short_analysis = (analysis[:100] + "...") if len(str(analysis)) > 100 else analysis + + f.write(f"| {r['id']} | {query} | {status} | {elapsed} | {short_analysis} |\n") + + f.write("\n---\n**Chi tiết Output (DSL & Analysis):**\n") + for r in data: + if not r.get("success"): + continue + res = r.get("result", {}) + if not isinstance(res, dict): + continue + + f.write(f"\n### Case {r['id']}: {r.get('query')}\n") + f.write(f"**Semantic Analysis:**\n{res.get('semantic_analysis', '-')}\n\n") + f.write(f"**Geometry DSL:**\n```\n{res.get('geometry_dsl', '-')}\n```\n") + + sol = res.get("solution") + if sol and isinstance(sol, dict): + f.write("**Solution (v5.1):**\n") + f.write(f"- **Answer:** {sol.get('answer', 'N/A')}\n") + f.write("- **Steps:**\n") + steps = sol.get("steps", []) + if steps: + for step in steps: + f.write(f" - {step}\n") + else: + f.write(" - (Không có bước giải cụ thể)\n") + + if sol.get("symbolic_expression"): + f.write(f"- **Symbolic:** `{sol.get('symbolic_expression')}`\n") + f.write("\n") + + print(f"Report generated: {report_path}") + except Exception as e: + print(f"Error generating report: {e}") + + +if __name__ == "__main__": + if len(sys.argv) < 3: + print( + "Usage: python generate_report.py [junit_xml_optional]" + ) + sys.exit(1) + junit = sys.argv[3] if len(sys.argv) > 3 else None + generate_report(sys.argv[1], sys.argv[2], junit) diff --git a/scripts/prepare_api_test.py b/scripts/prepare_api_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a4f91758fd68440e9fe7ec10204abf81e09357 --- /dev/null +++ b/scripts/prepare_api_test.py @@ -0,0 +1,38 @@ +import os +import sys +import uuid + +from dotenv import load_dotenv + +# Add parent dir to path to import app modules +_BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(_BACKEND_ROOT) +load_dotenv(os.path.join(_BACKEND_ROOT, ".env")) + +from app.supabase_client import get_supabase + +# Default UUID matches historical dev DB; override with TEST_SUPABASE_USER_ID in .env +_DEFAULT_TEST_USER = "8cd3adb0-7964-4575-949c-d0cadcd8b679" + + +def prepare(): + supabase = get_supabase() + user_id = os.environ.get("TEST_SUPABASE_USER_ID", _DEFAULT_TEST_USER).strip() + session_id = str(uuid.uuid4()) + + print(f"Using test user (TEST_SUPABASE_USER_ID or default): {user_id}") + + print(f"Creating fresh test session: {session_id}") + # Insert session + supabase.table("sessions").insert({ + "id": session_id, + "user_id": user_id, + "title": f"Fresh API Test {session_id[:8]}" + }).execute() + + # Return IDs for the test script + print(f"RESULT:USER_ID={user_id}") + print(f"RESULT:SESSION_ID={session_id}") + +if __name__ == "__main__": + prepare() diff --git a/scripts/prewarm_models.py b/scripts/prewarm_models.py new file mode 100644 index 0000000000000000000000000000000000000000..4e54cad3c3f768ebe1bd4b9467cfef2fb3b67d59 --- /dev/null +++ b/scripts/prewarm_models.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +""" +Download and load all heavy models during Docker build (YOLO, PaddleOCR, Pix2Tex, agents). +Fails the image build if initialization fails. +""" + +from __future__ import annotations + +import logging +import os +import sys + +# Ensure imports work when run as `python scripts/prewarm_models.py` from WORKDIR +_APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _APP_ROOT not in sys.path: + sys.path.insert(0, _APP_ROOT) + +os.chdir(_APP_ROOT) + +from dotenv import load_dotenv + +load_dotenv() + +from app.runtime_env import apply_runtime_env_defaults + +apply_runtime_env_defaults() + +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s") + +logger = logging.getLogger("prewarm") + + +def main() -> None: + from agents.orchestrator import Orchestrator + + logger.info("Constructing Orchestrator (full agent + model load)...") + Orchestrator() + logger.info("Prewarm finished successfully.") + + +if __name__ == "__main__": + main() diff --git a/scripts/prewarm_ocr_worker.py b/scripts/prewarm_ocr_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..ca1001344e8009b95dc92525ef0ded5794e24c59 --- /dev/null +++ b/scripts/prewarm_ocr_worker.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +"""Docker build: load OCR models only (no Orchestrator / no LLM). Used by Dockerfile.worker.ocr.""" + +from __future__ import annotations + +import logging +import os +import sys + +_APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _APP_ROOT not in sys.path: + sys.path.insert(0, _APP_ROOT) + +os.chdir(_APP_ROOT) + +from dotenv import load_dotenv + +load_dotenv() + +from app.runtime_env import apply_runtime_env_defaults + +apply_runtime_env_defaults() + +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s") +logger = logging.getLogger("prewarm_ocr_worker") + + +def main() -> None: + from vision_ocr.pipeline import OcrVisionPipeline + + logger.info("Loading OcrVisionPipeline...") + OcrVisionPipeline() + logger.info("OCR worker prewarm finished successfully.") + + +if __name__ == "__main__": + main() diff --git a/scripts/prewarm_render_worker.py b/scripts/prewarm_render_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..17a85e032d084a406e1c11bbffe3986141f95070 --- /dev/null +++ b/scripts/prewarm_render_worker.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +"""Docker build: load geometry_render only (no Orchestrator / no LLM / no OCR).""" + +from __future__ import annotations + +import logging +import os +import subprocess +import sys + +_APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if _APP_ROOT not in sys.path: + sys.path.insert(0, _APP_ROOT) + +os.chdir(_APP_ROOT) + +from dotenv import load_dotenv + +load_dotenv() + +from app.runtime_env import apply_runtime_env_defaults + +apply_runtime_env_defaults() + +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s") +logger = logging.getLogger("prewarm_render_worker") + + +def main() -> None: + from geometry_render.renderer import RendererAgent + + logger.info("Loading RendererAgent (geometry_render only)...") + RendererAgent() + try: + r = subprocess.run( + ["manim", "--version"], + capture_output=True, + text=True, + timeout=60, + ) + if r.returncode == 0: + logger.info("manim --version: %s", (r.stdout or r.stderr or "").strip()[:200]) + else: + logger.warning("manim --version failed rc=%s", r.returncode) + except FileNotFoundError: + logger.warning("manim CLI not found on PATH (skipping version check).") + except subprocess.TimeoutExpired: + logger.warning("manim --version timed out.") + + logger.info("Render worker prewarm finished successfully.") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_real_integration.sh b/scripts/run_real_integration.sh new file mode 100755 index 0000000000000000000000000000000000000000..4046e9d18981a3cc2e535bd8399257420fa4356c --- /dev/null +++ b/scripts/run_real_integration.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash +# Run backend integration tests. Usage: +# ./scripts/run_real_integration.sh # profile ci (default) +# ./scripts/run_real_integration.sh ci +# ./scripts/run_real_integration.sh real # heavy: workers, manim, OCR, full API suite +set -euo pipefail + +PROFILE="${1:-ci}" +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" +PY="${ROOT}/venv/bin/python" +if [[ ! -x "$PY" ]]; then + PY="python3" +fi + +export PYTHONPATH="$ROOT" +LOG_FILE="${LOG_FILE:-integration_run.log}" +JUNIT="${JUNIT:-pytest_integration.xml}" +REPORT_MD="${REPORT_MD:-integration_report.md}" +JSON_RESULTS="${JSON_RESULTS:-temp_suite_results.json}" + +log() { echo "[$(date '+%H:%M:%S')] $*" | tee -a "$LOG_FILE"; } + +log "Profile=$PROFILE working_dir=$ROOT" + +if [[ "$PROFILE" == "ci" ]]; then + export ALLOW_TEST_BYPASS="${ALLOW_TEST_BYPASS:-true}" + export LOG_LEVEL="${LOG_LEVEL:-info}" + export CELERY_TASK_ALWAYS_EAGER="${CELERY_TASK_ALWAYS_EAGER:-true}" + export CELERY_RESULT_BACKEND="${CELERY_RESULT_BACKEND:-rpc://}" + export MOCK_VIDEO="${MOCK_VIDEO:-true}" + + set +e + log "Phase A: default pytest (unit / mocked; excludes real_* markers per pytest.ini)" + "$PY" -m pytest tests/ -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" + P1=${PIPESTATUS[0]} + set -e + + log "Starting uvicorn for API phase..." + "$PY" -m uvicorn app.main:app --port 8000 >>uvicorn_integration.log 2>&1 & + SERVER_PID=$! + trap 'kill "$SERVER_PID" 2>/dev/null || true' EXIT + + for i in $(seq 1 25); do + if curl -sf "http://localhost:8000/" >/dev/null; then + log "API ready" + break + fi + sleep 2 + if [[ "$i" -eq 25 ]]; then + log "ERROR: API did not start" + exit 1 + fi + done + + PREP="$("$PY" scripts/prepare_api_test.py)" + echo "$PREP" | tee -a "$LOG_FILE" + export TEST_USER_ID="$(echo "$PREP" | grep "RESULT:USER_ID=" | cut -d'=' -f2)" + export TEST_SESSION_ID="$(echo "$PREP" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)" + if [[ -z "${TEST_USER_ID:-}" || -z "${TEST_SESSION_ID:-}" ]]; then + log "ERROR: prepare_api_test did not emit USER_ID / SESSION_ID" + exit 1 + fi + + set +e + log "Phase B: API smoke + full suite (real_api)" + "$PY" -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \ + -m "real_api" -s --tb=short --junitxml="$JUNIT" -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" + P2=${PIPESTATUS[0]} + set -e + + if [[ -f "$JSON_RESULTS" ]]; then + log "Generating Markdown report" + "$PY" scripts/generate_report.py "$JSON_RESULTS" "$REPORT_MD" "$JUNIT" + else + log "WARN: $JSON_RESULTS missing (suite may have failed before write)" + fi + + if [[ "$P1" -ne 0 || "$P2" -ne 0 ]]; then + log "FAIL: phase A exit=$P1 phase B exit=$P2" + exit 1 + fi + + log "Done CI profile. See $REPORT_MD and $LOG_FILE" + exit 0 +fi + +if [[ "$PROFILE" == "real" ]]; then + unset CELERY_TASK_ALWAYS_EAGER || true + export CELERY_TASK_ALWAYS_EAGER="${CELERY_TASK_ALWAYS_EAGER:-false}" + export MOCK_VIDEO="${MOCK_VIDEO:-false}" + export RUN_REAL_WORKER_OCR="${RUN_REAL_WORKER_OCR:-0}" + export RUN_REAL_WORKER_MANIM="${RUN_REAL_WORKER_MANIM:-0}" + + log "Phase A: default pytest (fast)" + "$PY" -m pytest tests/ -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" + + log "Phase B: real agents + orchestrator smoke (requires OpenRouter keys)" + "$PY" -m pytest tests/integration/test_agents_real.py tests/integration/test_orchestrator_smoke.py \ + -m "real_agents" -q --tb=short --junitxml="$JUNIT" -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true + + if [[ "${RUN_REAL_WORKER_OCR:-0}" == "1" ]] || [[ "${RUN_REAL_WORKER_OCR:-0}" =~ ^(true|yes)$ ]]; then + log "Phase C: OCR worker task (RUN_REAL_WORKER_OCR enabled)" + "$PY" -m pytest tests/integration/test_worker_ocr_real.py \ + -m "real_worker_ocr" -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true + else + log "Skipping OCR worker (set RUN_REAL_WORKER_OCR=1 to enable)" + fi + + if [[ "${RUN_REAL_WORKER_MANIM:-0}" == "1" ]]; then + log "Phase D: Manim + storage (RUN_REAL_WORKER_MANIM=1, MOCK_VIDEO=false)" + "$PY" -m pytest tests/integration/test_worker_manim_real.py -m "real_worker_manim" -s --tb=short \ + -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true + else + log "Skipping Manim integration (set RUN_REAL_WORKER_MANIM=1 to enable)" + fi + + log "Phase E: API real (expects TEST_BASE_URL or localhost:8000 with server already up)" + if curl -sf "http://localhost:8000/" >/dev/null 2>&1; then + PREP="$("$PY" scripts/prepare_api_test.py)" + export TEST_USER_ID="$(echo "$PREP" | grep "RESULT:USER_ID=" | cut -d'=' -f2)" + export TEST_SESSION_ID="$(echo "$PREP" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)" + "$PY" -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py tests/test_api_metadata_real.py \ + -m "real_api" -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true + else + log "WARN: No server on :8000 — skip API real phase (start backend first)" + fi + + log "Done REAL profile. Review $LOG_FILE" + exit 0 +fi + +echo "Unknown profile: $PROFILE (use ci or real)" +exit 1 diff --git a/scripts/test_LLM.py b/scripts/test_LLM.py new file mode 100644 index 0000000000000000000000000000000000000000..a43ecb2655082b9e64b53e61002917d356f5e99b --- /dev/null +++ b/scripts/test_LLM.py @@ -0,0 +1,142 @@ +import sys +import os +import time +import asyncio +import logging +from typing import List, Dict, Any +from dotenv import load_dotenv + +# Add the parent directory to sys.path to allow importing from 'app' +# This assumes the script is inside 'backend/scripts' and we want to import from 'backend/app' +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from app.url_utils import openai_compatible_api_key +from openai import AsyncOpenAI + +# Set up logger +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) + +# List of models to benchmark +MODELS_TO_TEST = [ + "nvidia/nemotron-3-super-120b-a12b:free", + "meta-llama/llama-3.3-70b-instruct:free", + "openai/gpt-oss-120b:free", + "z-ai/glm-4.5-air:free", + "minimax/minimax-m2.5:free", + "google/gemma-4-26b-a4b-it:free", + "google/gemma-4-31b-it:free", + "arcee-ai/trinity-large-preview:free", + "openai/gpt-oss-20b:free", + "nvidia/nemotron-3-nano-30b-a3b:free", + "nvidia/nemotron-nano-9b-v2:free", +] + +DEFAULT_QUERY = "Giải hệ phương trình sau: x + y = 10, 2x - y = 2. Trả về kết quả cuối cùng x và y." + +async def test_model(client: AsyncOpenAI, model: str, query: str) -> Dict[str, Any]: + """Test a single model and return performance metrics.""" + start_time = time.time() + result = { + "model": model, + "status": "success", + "duration": 0, + "content": "", + "error": None + } + + try: + response = await client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": query}], + timeout=60.0 + ) + result["duration"] = time.time() - start_time + result["content"] = response.choices[0].message.content.strip() + except Exception as e: + result["status"] = "failed" + result["duration"] = time.time() - start_time + result["error"] = str(e) + + return result + +async def main(): + # Load configuration from .env file inside backend directory + # If starting from root, backend/.env might be needed. If starting from backend/, .env is enough. + load_dotenv() + + # Try multiple common env keys for api key + api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY") + + if not api_key: + logger.error("❌ Error: NO OPENROUTER_API_KEY found in environment variables.") + logger.info("Check your .env file in the backend directory.") + return + + # Using the project's url_utils to maintain consistency with the main app + sanitized_key = openai_compatible_api_key(api_key) + + client = AsyncOpenAI( + api_key=sanitized_key, + base_url="https://openrouter.ai/api/v1", + default_headers={ + "HTTP-Referer": "https://mathsolver.ai", + "X-Title": "MathSolver LLM Benchmarker", + } + ) + + query = DEFAULT_QUERY + logger.info("=" * 80) + logger.info(f"🚀 LLM PERFORMANCE BENCHMARK") + logger.info(f"Query: {query}") + logger.info("=" * 80) + logger.info(f"Testing {len(MODELS_TO_TEST)} models sequentially with 30s delay...\n") + + results = [] + for i, model in enumerate(MODELS_TO_TEST): + if i > 0: + logger.info(f"⏳ Waiting 30s before testing next model...") + await asyncio.sleep(30) + + logger.info(f"[{i+1}/{len(MODELS_TO_TEST)}] Testing: {model}...") + res = await test_model(client, model, query) + results.append(res) + + # Immediate feedback + status_str = "✅ SUCCESS" if res["status"] == "success" else "❌ FAILED" + logger.info(f" Status: {status_str} | Time: {res['duration']:.2f}s") + + # Report Summary Table + logger.info("\n" + "=" * 80) + logger.info("📊 FINAL BENCHMARK SUMMARY") + logger.info("=" * 80) + header = f"{'MODEL':<45} | {'STATUS':<10} | {'TIME (s)':<10}" + logger.info(header) + logger.info("-" * len(header)) + + for res in results: + status_str = "✅ SUCCESS" if res["status"] == "success" else "❌ FAILED" + duration_str = f"{res['duration']:.2f}s" + logger.info(f"{res['model']:<45} | {status_str:<10} | {duration_str:<10}") + + logger.info("-" * len(header)) + + # Detailed report for successful ones + logger.info("\n📝 FULL RESPONSES:") + for res in results: + logger.info(f"\n{'='*20} [{res['model']}] {'='*20}") + if res["status"] == "success": + logger.info(res["content"]) + else: + logger.info(f"❌ Error: {res['error']}") + + logger.info("\n" + "=" * 80) + logger.info(f"Benchmark finished.") + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("\nBenchmark cancelled by user.") + except Exception as e: + logger.error(f"Unexpected error: {e}") diff --git a/scripts/test_engine_direct.py b/scripts/test_engine_direct.py new file mode 100644 index 0000000000000000000000000000000000000000..a26b8c05d1cb152886524278b66d0aeff8d57381 --- /dev/null +++ b/scripts/test_engine_direct.py @@ -0,0 +1,36 @@ +import asyncio +import os +import json +import logging +import sys + +# Add root directory to path to import app and agents +sys.path.append("/Volumes/WorkSpace/Project/MathSolver/backend") + +# Configure logging to stdout +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +from agents.orchestrator import Orchestrator + +async def main(): + orch = Orchestrator() + text = "Vẽ tam giác đều cạnh 5." + job_id = "test_direct_equilateral" + + print(f"\n--- Testing Orchestrator Direct: {text} ---") + + async def status_cb(status): + print(f" [STATUS] {status}") + + try: + result = await orch.run(text, job_id=job_id, status_callback=status_cb, request_video=False) + print("\n--- Final Result ---") + print(json.dumps(result, indent=2)) + except Exception as e: + print(f"\n--- ERROR ---") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/setup.sh b/setup.sh new file mode 100755 index 0000000000000000000000000000000000000000..c102430ffab0816b6f2c08295e6d887af1f2380e --- /dev/null +++ b/setup.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# MathSolver v3.1 Setup Script for macOS + +echo "🚀 Starting Environment Setup..." + +# 1. System Dependencies (Homebrew) +if command -v brew >/dev/null 2>&1; then + echo "📦 Installing system dependencies via Homebrew..." + brew install pango pkg-config glib librsvg +else + echo "⚠️ Homebrew not found. Please install it first: https://brew.sh/" + exit 1 +fi + +# 2. Python SSL Certificates +PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') +CERT_FILE="/Applications/Python ${PYTHON_VERSION}/Install Certificates.command" + +if [ -f "$CERT_FILE" ]; then + echo "🔐 Installing Python SSL certificates..." + sh "$CERT_FILE" +else + echo "ℹ️ SSL certificate installer not found at $CERT_FILE. Skipping..." +fi + +# 3. Virtual Environment +echo "🐍 Setting up Python Virtual Environment..." +cd backend +python3 -m venv venv +source venv/bin/activate + +# 4. Pip packages +echo "📦 Installing Python packages..." +pip install --upgrade pip +pip install -r requirements.txt + +# 5. Fix ManimPango (Crucial for macOS arm64) +echo "🛠️ Rebuilding ManimPango from source to ensure library linking..." +pip install --no-cache-dir --force-reinstall --no-binary manimpango manimpango + +echo "✅ Setup Complete!" +echo "To start the backend, run: source venv/bin/activate && uvicorn app.main:app --reload" diff --git a/solver/dsl_parser.py b/solver/dsl_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9e3322b2788b8480bcd251ac27cdb30a144ffb --- /dev/null +++ b/solver/dsl_parser.py @@ -0,0 +1,217 @@ +import re +import logging +from typing import List, Tuple, Dict, Any +from .models import Point, Constraint + +logger = logging.getLogger(__name__) + + +class DSLParser: + def parse(self, text: str) -> Tuple[List[Point], List[Constraint], bool]: + """Parse DSL text into points and constraints. Stateless per call.""" + points: Dict[str, Point] = {} + explicit_point_ids: List[str] = [] + constraints: List[Constraint] = [] + polygon_order: List[str] = [] + circles: List[Dict[str, Any]] = [] + segments: List[List[str]] = [] + lines_ext: List[List[str]] = [] + rays: List[List[str]] = [] + is_3d = False + + logger.info("==[DSLParser] Parsing DSL input==") + logger.debug(f"[DSLParser] Raw DSL:\n{text}") + + lines = text.strip().split('\n') + for line in lines: + line = line.strip() + if not line or line.startswith('//') or line.startswith('#'): + continue + + # POINT(A) or POINT(A, 0, 0, 5) + m = re.match(r'POINT\((\w+)(?:,\s*([\d\.-]+),\s*([\d\.-]+)(?:,\s*([\d\.-]+))?)?\)', line) + if m: + name = m.group(1) + x = float(m.group(2)) if m.group(2) else None + y = float(m.group(3)) if m.group(3) else None + z = float(m.group(4)) if m.group(4) else None + # z=0 with x,y is still the xy-plane; only treat as 3D when z is meaningfully non-zero. + # Otherwise POINT(A,0,0,0) incorrectly forced is_3d and broke 2D engine paths. + if z is not None and abs(z) > 1e-9: + is_3d = True + points[name] = Point(id=name, x=x, y=y, z=z) + if name not in explicit_point_ids: + explicit_point_ids.append(name) + logger.debug(f"[DSLParser] + POINT: {name} ({x}, {y}, {z})") + continue + + # LENGTH(AB, 5) + m = re.match(r'LENGTH\((\w+),\s*([\d\.]+)\)', line) + if m: + target, value = m.group(1), float(m.group(2)) + pts = [target[i:i+1] for i in range(len(target))] + constraints.append(Constraint(type='length', targets=pts, value=value)) + logger.debug(f"[DSLParser] + LENGTH: {pts} = {value}") + continue + + # ANGLE(A, 90) or ANGLE(A, 90deg) + m = re.match(r'ANGLE\((\w+),\s*([\d\.]+)(?:deg)?\)', line) + if m: + target, value = m.group(1), float(m.group(2)) + constraints.append(Constraint(type='angle', targets=[target], value=value)) + logger.debug(f"[DSLParser] + ANGLE: vertex={target}, degrees={value}") + continue + + # PARALLEL(AB, CD) + m = re.match(r'PARALLEL\((\w+),\s*(\w+)\)', line) + if m: + seg1, seg2 = m.group(1), m.group(2) + constraints.append(Constraint(type='parallel', targets=list(seg1) + list(seg2), value=0)) + logger.debug(f"[DSLParser] + PARALLEL: {seg1} || {seg2}") + continue + + # PERPENDICULAR(AB, CD) + m = re.match(r'PERPENDICULAR\((\w+),\s*(\w+)\)', line) + if m: + seg1, seg2 = m.group(1), m.group(2) + constraints.append(Constraint(type='perpendicular', targets=list(seg1) + list(seg2), value=0)) + logger.debug(f"[DSLParser] + PERPENDICULAR: {seg1} _|_ {seg2}") + continue + + # MIDPOINT(M, AB) — M is midpoint of AB + m = re.match(r'MIDPOINT\((\w+),\s*(\w+)\)', line) + if m: + mid, seg = m.group(1), m.group(2) + if mid not in points: + points[mid] = Point(id=mid) + pts = [mid] + [seg[i:i+1] for i in range(len(seg))] + constraints.append(Constraint(type='midpoint', targets=pts, value=0)) + logger.debug(f"[DSLParser] + MIDPOINT: {mid} = mid({seg})") + continue + + # SECTION(E, A, C, 0.66) — E lies on AC s.t. AE = 0.66 * AC + m = re.match(r'SECTION\((\w+),\s*(\w+),\s*(\w+),\s*([\d\.-]+)\)', line) + if m: + target, p1, p2, k = m.group(1), m.group(2), m.group(3), float(m.group(4)) + if target not in points: + points[target] = Point(id=target) + constraints.append(Constraint(type='section', targets=[target, p1, p2], value=k)) + logger.debug(f"[DSLParser] + SECTION: {target} = {p1} + {k}({p2}-{p1})") + continue + + # CIRCLE(O, r) + m = re.match(r'CIRCLE\((\w+),\s*([\d\.]+)\)', line) + if m: + center, radius = m.group(1), float(m.group(2)) + if center not in points: + points[center] = Point(id=center) + constraints.append(Constraint(type='circle', targets=[center], value=radius)) + circles.append({"center": center, "radius": radius}) + logger.debug(f"[DSLParser] + CIRCLE: center={center}, r={radius}") + continue + + # POLYGON_ORDER(A, B, C, D) — thứ tự nối điểm để vẽ đa giác + m = re.match(r'POLYGON_ORDER\(([^)]+)\)', line) + if m: + polygon_order = [p.strip() for p in m.group(1).split(',')] + logger.debug(f"[DSLParser] + POLYGON_ORDER: {polygon_order}") + continue + + # SEGMENT(M, N) — đoạn thẳng phụ cần vẽ + m = re.match(r'SEGMENT\((\w+),\s*(\w+)\)', line) + if m: + p1, p2 = m.group(1), m.group(2) + segments.append([p1, p2]) + constraints.append(Constraint(type='segment', targets=[p1, p2], value=0)) + logger.debug(f"[DSLParser] + SEGMENT: {p1}—{p2}") + continue + + # LINE(A, B) — infinite line + m = re.match(r'LINE\((\w+),\s*(\w+)\)', line) + if m: + p1, p2 = m.group(1), m.group(2) + lines_ext.append([p1, p2]) + constraints.append(Constraint(type='line', targets=[p1, p2], value=0)) + logger.debug(f"[DSLParser] + LINE: {p1}-{p2}") + continue + + # RAY(A, B) — ray AB starting at A + m = re.match(r'RAY\((\w+),\s*(\w+)\)', line) + if m: + p1, p2 = m.group(1), m.group(2) + rays.append([p1, p2]) + constraints.append(Constraint(type='ray', targets=[p1, p2], value=0)) + logger.debug(f"[DSLParser] + RAY: {p1}->{p2}") + continue + + # TRIANGLE(ABC) / PYRAMID(S_ABCD) / PRISM(ABC_DEF) + m = re.match(r'(TRIANGLE|PYRAMID|PRISM)\(([^)]+)\)', line) + if m: + pt_type = m.group(1) + targets = m.group(2) + if pt_type in ["PYRAMID", "PRISM"]: + is_3d = True + if pt_type == "TRIANGLE": + if not polygon_order: polygon_order = list(targets) + elif pt_type == "PYRAMID": + # S_ABCD -> S is apex, ABCD is base + if "_" in targets: + apex, base = targets.split("_") + # Add segments from apex to all base points + for p in base: + segments.append([apex, p]) + constraints.append(Constraint(type='segment', targets=[apex, p], value=0)) + if not polygon_order: polygon_order = list(base) + elif pt_type == "PRISM": + # ABC_DEF -> two bases + if "_" in targets: + b1, b2 = targets.split("_") + for p1, p2 in zip(b1, b2): + segments.append([p1, p2]) + constraints.append(Constraint(type='segment', targets=[p1, p2], value=0)) + logger.debug(f"[DSLParser] + {pt_type}: {targets}") + continue + + # SPHERE(O, r) + m = re.match(r'SPHERE\((\w+),\s*([\d\.]+)\)', line) + if m: + is_3d = True + center, radius = m.group(1), float(m.group(2)) + if center not in points: + points[center] = Point(id=center) + constraints.append(Constraint(type='sphere', targets=[center], value=radius)) + logger.debug(f"[DSLParser] + SPHERE: center={center}, r={radius}") + continue + + logger.warning(f"[DSLParser] ? Unrecognized DSL line: '{line}'") + + logger.info( + "[DSLParser] Parsed %d points, %d constraints, is_3d=%s.", + len(points), + len(constraints), + is_3d, + ) + + # Safety sweep: Ensure all points referenced in constraints actually exist in the points dictionary + for c in constraints: + for pid in c.targets: + # Some targets might be values or comma-separated strings (handled elsewhere), + # but most are single-character point IDs. + if isinstance(pid, str) and len(pid) == 1 and pid not in points: + points[pid] = Point(id=pid) + logger.debug(f"[DSLParser] ! Auto-declared missing point from constraint: {pid}") + + # Attach metadata to a synthetic constraint for downstream use + if polygon_order: + constraints.append(Constraint(type='polygon_order', targets=polygon_order, value=0)) + elif explicit_point_ids: + # Re-use polygon_order as a carrier for explicit points IF no real order was specified + constraints.append(Constraint(type='explicit_points', targets=explicit_point_ids, value=0)) + + # Add auxiliary metadata for lines and rays + if lines_ext: + constraints.append(Constraint(type='lines_metadata', targets=[",".join(l) for l in lines_ext], value=0)) + if rays: + constraints.append(Constraint(type='rays_metadata', targets=[",".join(l) for l in rays], value=0)) + + return list(points.values()), constraints, is_3d diff --git a/solver/engine.py b/solver/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..69bdbf1f9080d3e5d74ebbf0dd1a37fdd6e88438 --- /dev/null +++ b/solver/engine.py @@ -0,0 +1,426 @@ +import sympy as sp +import numpy as np +import logging +import string +from typing import List, Dict, Any +from .models import Point, Constraint + +logger = logging.getLogger(__name__) + + +class GeometryEngine: + def solve(self, points: List[Point], constraints: List[Constraint], is_3d: bool = False) -> Dict[str, Any] | None: + if not points: + logger.error("[GeometryEngine] No points to solve.") + return None + + logger.info(f"==[GeometryEngine] Starting solve with {len(points)} points, {len(constraints)} constraints (is_3d={is_3d})==") + + # ── Separate metadata constraints from real ones ────────────────────── + polygon_order: List[str] = [] + circles_meta: List[Dict] = [] + segments_meta: List[List[str]] = [] + real_constraints: List[Constraint] = [] + + for c in constraints: + if c.type == 'polygon_order': + polygon_order = list(c.targets) + elif c.type == 'explicit_points' and not polygon_order: + polygon_order = list(c.targets) + elif c.type == 'circle': + circles_meta.append({"center": c.targets[0], "radius": float(c.value)}) + real_constraints.append(c) + elif c.type == 'segment': + segments_meta.append(list(c.targets)) + # don't add to equations — pure drawing annotation + elif c.type == 'lines_metadata': + lines_meta_list = [t.split(',') for t in c.targets] + real_constraints.append(c) # for passing to builder? or just keep here + elif c.type == 'rays_metadata': + rays_meta_list = [t.split(',') for t in c.targets] + real_constraints.append(c) + else: + real_constraints.append(c) + + # ── Setup symbols ───────────────────────────────────────────────────── + point_vars: Dict[str, tuple] = {} + equations = [] + + # Convert to list for stable indexing and to handle both Dict and List inputs + pt_list = list(points.values()) if isinstance(points, dict) else points + + for p in pt_list: + x = sp.Symbol(f"{p.id}_x") + y = sp.Symbol(f"{p.id}_y") + z = sp.Symbol(f"{p.id}_z") + point_vars[p.id] = (x, y, z) + logger.debug(f"[GeometryEngine] Symbol: ({p.id}_x, {p.id}_y, {p.id}_z)") + + # If 2D problem, pin all z to 0 immediately + if not is_3d: + equations.append(z) + + # ── Anchor logic to fix translation + rotation DOF ──────────────────── + # Skip anchoring if points already have explicit coordinates that fix DOFs + + if len(pt_list) > 0: + p1 = pt_list[0] + # Translation: fix p1 at (0,0) or (0,0,0) + if p1.x is None: equations.append(point_vars[p1.id][0]); logger.debug(f"Anchor {p1.id}_x=0") + if p1.y is None: equations.append(point_vars[p1.id][1]); logger.debug(f"Anchor {p1.id}_y=0") + if is_3d and p1.z is None: + equations.append(point_vars[p1.id][2]); logger.debug(f"Anchor {p1.id}_z=0") + + if len(pt_list) > 1: + p2 = pt_list[1] + # Rotation: fix p2 on X-axis (y=0) + if p2.y is None: equations.append(point_vars[p2.id][1]); logger.debug(f"Anchor {p2.id}_y=0") + if is_3d and p2.z is None: + equations.append(point_vars[p2.id][2]); logger.debug(f"Anchor {p2.id}_z=0") + + if is_3d and len(pt_list) > 2: + p3 = pt_list[2] + # Planar rotation: fix p3 on XY-plane (z=0) + if p3.z is None: equations.append(point_vars[p3.id][2]); logger.debug(f"Anchor {p3.id}_z=0") + + # ── Build equations from explicit point coordinates ────────────────── + for p in pt_list: + if p.x is not None: + equations.append(point_vars[p.id][0] - p.x) + if p.y is not None: + equations.append(point_vars[p.id][1] - p.y) + if p.z is not None: + equations.append(point_vars[p.id][2] - p.z) + + # ── Build equations from constraints ────────────────────────────────── + for c in real_constraints: + logger.debug(f"[GeometryEngine] Processing constraint: type={c.type}, targets={c.targets}, value={c.value}") + + if c.type == 'length' and len(c.targets) == 2: + p1, p2 = c.targets + if p1 not in point_vars or p2 not in point_vars: + logger.warning(f"[GeometryEngine] Skip length: {c.targets} not in symbols.") + continue + v1, v2 = point_vars[p1], point_vars[p2] + # 3D distance + eq = (v2[0]-v1[0])**2 + (v2[1]-v1[1])**2 + (v2[2]-v1[2])**2 - float(c.value)**2 + equations.append(eq) + logger.debug(f"[GeometryEngine] -> Length eq (3D): |{p1}{p2}|² = {c.value}²") + + elif c.type == 'angle' and len(c.targets) >= 1: + # In 3D, 'angle' usually refers to the angle between two vectors (e.g., ∠BAC) + v_name = c.targets[0] + if v_name not in point_vars: + continue + # For simplicity, we assume the next two points in targets or fallback to first 2 others + if len(c.targets) >= 3: + p1_name, p2_name = c.targets[1], c.targets[2] + else: + other_pts = [p.id for p in pt_list if p.id != v_name][:2] + if len(other_pts) < 2: continue + p1_name, p2_name = other_pts + + pV = point_vars[v_name] + p1_vars = point_vars[p1_name] + p2_vars = point_vars[p2_name] + + # Vectors V1 and V2 + v1 = [p1_vars[i] - pV[i] for i in range(3)] + v2 = [p2_vars[i] - pV[i] for i in range(3)] + + # Dot product relation: v1.v2 = |v1||v2| cos(theta) + # But we use the tangent relation or square it to avoid sqrt if possible + # If 90 deg: dot product = 0 + if abs(float(c.value) - 90.0) < 1e-9: + eq = sum(v1[i]*v2[i] for i in range(3)) + logger.debug(f"[GeometryEngine] -> Angle eq at {v_name} (90° dot=0)") + else: + # Generic angle using law of cosines (squared) + cos_val = np.cos(np.deg2rad(float(c.value))) + d1_sq = sum(v1[i]**2 for i in range(3)) + d2_sq = sum(v2[i]**2 for i in range(3)) + dot = sum(v1[i]*v2[i] for i in range(3)) + eq = dot**2 - (cos_val**2) * d1_sq * d2_sq + # Note: this allows theta and 180-theta. + # Better: dot - cos(theta) * sqrt(d1_sq * d2_sq) = 0, but that has sqrt. + logger.debug(f"[GeometryEngine] -> Angle eq at {v_name} ({c.value}° cos² relation)") + equations.append(eq) + + elif c.type == 'parallel' and len(c.targets) == 4: + pA, pB, pC, pD = c.targets + if any(t not in point_vars for t in [pA, pB, pC, pD]): continue + va, vb, vc, vd = point_vars[pA], point_vars[pB], point_vars[pC], point_vars[pD] + # AB || CD means vector(AB) = lambda * vector(CD) + # In 3D, cross product = 0. (b-a) x (d-c) = 0 + v1 = [vb[i]-va[i] for i in range(3)] + v2 = [vd[i]-vc[i] for i in range(3)] + # Cross product components: + equations.append(v1[1]*v2[2] - v1[2]*v2[1]) + equations.append(v1[2]*v2[0] - v1[0]*v2[2]) + equations.append(v1[0]*v2[1] - v1[1]*v2[0]) + logger.debug(f"[GeometryEngine] -> Parallel eq (3D cross=0): {pA}{pB} || {pC}{pD}") + + elif c.type == 'perpendicular' and len(c.targets) == 4: + pA, pB, pC, pD = c.targets + if any(t not in point_vars for t in [pA, pB, pC, pD]): continue + va, vb, vc, vd = point_vars[pA], point_vars[pB], point_vars[pC], point_vars[pD] + # Dot product = 0 + dot = sum((vb[i]-va[i])*(vd[i]-vc[i]) for i in range(3)) + equations.append(dot) + logger.debug(f"[GeometryEngine] -> Perpendicular eq (3D dot=0): {pA}{pB} ⊥ {pC}{pD}") + + elif c.type == 'midpoint' and len(c.targets) == 3: + pM, pA, pB = c.targets + if any(t not in point_vars for t in [pM, pA, pB]): continue + vM, vA, vB = point_vars[pM], point_vars[pA], point_vars[pB] + for i in range(3): + equations.append(2*vM[i] - vA[i] - vB[i]) + logger.debug(f"[GeometryEngine] -> Midpoint eq (3D): {pM} = mid({pA},{pB})") + + elif c.type == 'section' and len(c.targets) == 3: + pE, pA, pC = c.targets + if any(t not in point_vars for t in [pE, pA, pC]): continue + vE, vA, vC = point_vars[pE], point_vars[pA], point_vars[pC] + k = float(c.value) + for i in range(3): + equations.append(vE[i] - (vA[i] + k * (vC[i] - vA[i]))) + logger.debug(f"[GeometryEngine] -> Section eq (3D): {pE} = {pA} + {k}({pC}-{pA})") + + elif c.type == 'circle': + # Circle doesn't add position constraints for center (already a point) + logger.debug(f"[GeometryEngine] -> Circle: center={c.targets[0]}, r={c.value} (meta only)") + + all_vars = [] + for v in point_vars.values(): + all_vars.extend(v) + + n_eqs = len(equations) + n_vars = len(all_vars) + logger.info(f"[GeometryEngine] Built {n_eqs} equations for {n_vars} unknowns.") + + # ── Strategy 1: SymPy symbolic ─────────────────────────────────────── + coords = self._try_symbolic(equations, all_vars, point_vars) + + # Extract lines/rays from constraints for builder + lines_ext = [] + rays_ext = [] + for c in constraints: + if c.type == 'lines_metadata': + lines_ext = [t.split(',') for t in c.targets] + if c.type == 'rays_metadata': + rays_ext = [t.split(',') for t in c.targets] + + if coords: + return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list) + + # ── Strategy 2: Numerical nsolve ───────────────────────────────────── + if n_eqs == n_vars: + coords = self._try_nsolve(equations, all_vars, point_vars, n_vars) + if coords: + return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list) + + # ── Strategy 3: Scipy least-squares ───────────────────────────────── + coords = self._try_lsq(equations, all_vars, point_vars, n_vars) + if coords: + return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list) + + # ── Strategy 4: Differential evolution ────────────────────────────── + coords = self._try_global(equations, all_vars, point_vars, n_vars) + if coords: + return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list) + + logger.error("[GeometryEngine] All strategies exhausted.") + return None + + # ─── Solving strategies ────────────────────────────────────────────────── + + def _try_symbolic(self, equations, all_vars, point_vars): + # Optimization: SymPy's symbolic solver becomes extremely slow for many variables. + # For 3D problems (usually 12-18+ variables), we prefer using numerical methods directly. + if len(all_vars) > 10: + logger.info(f"[GeometryEngine] Strategy 1: Skipping symbolic solve due to high variable count ({len(all_vars)}).") + return None + + try: + solution = sp.solve(equations, all_vars, dict=True) + if solution: + res = solution[0] + logger.info("[GeometryEngine] Strategy 1 (SymPy symbolic): SUCCESS.") + logger.debug(f"[GeometryEngine] Symbolic solution: {res}") + return {pid: [float(res.get(vx, 0.0)), float(res.get(vy, 0.0)), float(res.get(vz, 0.0))] + for pid, (vx, vy, vz) in point_vars.items()} + else: + logger.warning("[GeometryEngine] Strategy 1 returned no solution. Trying numerical...") + except Exception as e: + logger.warning(f"[GeometryEngine] Strategy 1 threw exception: {e}. Trying numerical...") + return None + + def _try_nsolve(self, equations, all_vars, point_vars, n_vars): + MAX_NSOLVE_ATTEMPTS = 15 + logger.info(f"[GeometryEngine] Strategy 2 (nsolve): square system ({n_vars}x{n_vars}). Trying {MAX_NSOLVE_ATTEMPTS} random starts...") + import random + for attempt in range(MAX_NSOLVE_ATTEMPTS): + try: + # Use varying scales for the random guesses to handle different problem sizes + scale = 10 if attempt < 5 else (100 if attempt < 10 else 1) + guesses = [random.uniform(-scale, scale) for _ in all_vars] + sol_vals = sp.nsolve(equations, all_vars, guesses, tol=1e-6, maxsteps=1000) + res = {var: float(val) for var, val in zip(all_vars, sol_vals)} + logger.info(f"[GeometryEngine] Strategy 2 (nsolve): SUCCESS on attempt {attempt + 1}.") + return {pid: [float(res.get(vx, 0.0)), float(res.get(vy, 0.0)), float(res.get(vz, 0.0))] + for pid, (vx, vy, vz) in point_vars.items()} + except Exception as e: + logger.debug(f"[GeometryEngine] nsolve attempt {attempt + 1} failed: {e}") + return None + + def _try_lsq(self, equations, all_vars, point_vars, n_vars): + logger.info("[GeometryEngine] Strategy 3 (scipy least-squares): minimizing residuals...") + try: + from scipy.optimize import minimize + eq_funcs = [sp.lambdify(all_vars, eq, 'numpy') for eq in equations] + + def objective(x): + return sum(float(f(*x))**2 for f in eq_funcs) + + best_res, best_val = None, float('inf') + # Increase restarts for better coverage of local minima + for i in range(12): + if i == 0: + x0 = [1.0]*n_vars + elif i < 4: + x0 = [np.random.uniform(-10, 10) for _ in range(n_vars)] + else: + x0 = [np.random.uniform(-100, 100) for _ in range(n_vars)] + + res = minimize(objective, x0, method='L-BFGS-B') + if res.fun < best_val: + best_val, best_res = res.fun, res + if best_val < 1e-6: + break + + TOLERANCE = 1e-4 + logger.info(f"[GeometryEngine] Strategy 3: best residual = {best_val:.2e} (tol={TOLERANCE})") + if best_val < TOLERANCE: + res = {var: float(val) for var, val in zip(all_vars, best_res.x)} + logger.info("[GeometryEngine] Strategy 3 (least-squares): SUCCESS.") + return {pid: [float(res.get(vx, 0)), float(res.get(vy, 0)), float(res.get(vz, 0))] + for pid, (vx, vy, vz) in point_vars.items()} + else: + logger.warning(f"[GeometryEngine] Strategy 3 failed: residual {best_val:.2e} > {TOLERANCE}") + except Exception as e: + logger.error(f"[GeometryEngine] Strategy 3 threw exception: {e}") + return None + + def _try_global(self, equations, all_vars, point_vars, n_vars): + logger.info("[GeometryEngine] Strategy 4 (Differential Evolution): global search...") + try: + from scipy.optimize import differential_evolution + bounds = [(-20, 20)] * n_vars + eq_funcs = [sp.lambdify(all_vars, eq, 'numpy') for eq in equations] + + def obj(x): + s = 0.0 + for f in eq_funcs: + try: + s += float(f(*x))**2 + except: + s += 1e6 + return s + + result = differential_evolution(obj, bounds, maxiter=500, popsize=15, mutation=(0.5, 1), recombination=0.7) + TOLERANCE = 1e-3 + logger.info(f"[GeometryEngine] Strategy 4: best residual = {result.fun:.2e} (tol={TOLERANCE})") + if result.fun < TOLERANCE: + res = {var: float(val) for var, val in zip(all_vars, result.x)} + logger.info("[GeometryEngine] Strategy 4 (global opt): SUCCESS.") + return {pid: [float(res.get(vx, 0)), float(res.get(vy, 0)), float(res.get(vz, 0))] + for pid, (vx, vy, vz) in point_vars.items()} + except Exception as e: + logger.error(f"[GeometryEngine] Strategy 4 threw exception: {e}") + return None + + # ─── Result builder ────────────────────────────────────────────────────── + + def _build_result( + self, + coords: Dict[str, List[float]], + polygon_order: List[str], + circles_meta: List[Dict], + segments_meta: List[List[str]], + lines_meta: List[List[str]], + rays_meta: List[List[str]], + pt_list: List[Point], + ) -> Dict[str, Any]: + """ + Build structured result including drawing phases for the renderer. + + drawing_phases: + Phase 1 — Base shape (main polygon) + Phase 2 — Auxiliary/derived points and segments + """ + all_ids = [p.id for p in pt_list] + + # 1. Infer/clean polygon_order + if not polygon_order: + # Fallback: use all declared point IDs sorted by conventional uppercase order. + # This is far safer than only looking for A/B/C/D. + base_pts = sorted( + all_ids, + key=lambda p: (string.ascii_uppercase.index(p) if p in string.ascii_uppercase else 100, p) + ) + polygon_order = base_pts + + base_ids = [pid for pid in polygon_order if pid in all_ids] + derived_ids = [pid for pid in all_ids if pid not in polygon_order] + + # 2. Collect unique segments to avoid redundancy (AB == BA) + drawn_segments = set() + + def add_segment(p1, p2, target_list): + if p1 == p2: + return + s = frozenset([p1, p2]) + if s not in drawn_segments: + drawn_segments.add(s) + target_list.append([p1, p2]) + + # Phase 1: Main polygon boundary + phase1_segments = [] + if len(base_ids) >= 2: + # Connect in sequence: A-B, B-C, etc. + for i in range(len(base_ids) - 1): + add_segment(base_ids[i], base_ids[i+1], phase1_segments) + + # ONLY close the loop if we have 3 or more points (a real polygon) + if len(base_ids) > 2: + add_segment(base_ids[-1], base_ids[0], phase1_segments) + + # Phase 2: Auxiliary segments from DSL + phase2_segments = [] + for p1, p2 in segments_meta: + add_segment(p1, p2, phase2_segments) + + drawing_phases = [ + { + "phase": 1, + "label": "Hình cơ bản", + "points": base_ids, + "segments": phase1_segments, + } + ] + if derived_ids or phase2_segments: + drawing_phases.append({ + "phase": 2, + "label": "Điểm và đoạn phụ", + "points": derived_ids, + "segments": phase2_segments, + }) + + return { + "coordinates": coords, + "polygon_order": polygon_order, + "circles": circles_meta, + "lines": lines_meta, + "rays": rays_meta, + "drawing_phases": drawing_phases, + } diff --git a/solver/models.py b/solver/models.py new file mode 100644 index 0000000000000000000000000000000000000000..42c26bfa3f7c313938b99a2d9026c1a2c5b217fb --- /dev/null +++ b/solver/models.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel +from typing import List, Dict, Union, Optional + +class Point(BaseModel): + id: str + x: Optional[float] = None + y: Optional[float] = None + z: Optional[float] = None + +class Constraint(BaseModel): + type: str # 'length', 'angle', 'parallel', etc. + targets: List[str] + value: Union[float, str] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d5690edb05a1c8c1a9a8d9d62947b83f249347 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Backend test package (enables ``tests.cases`` imports for pytest).""" diff --git a/tests/cases/__init__.py b/tests/cases/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..355cb06e638eaa24f3f4fd8086bed5f02c78906b --- /dev/null +++ b/tests/cases/__init__.py @@ -0,0 +1 @@ +"""Shared test case definitions for API and orchestrator suites.""" diff --git a/tests/cases/pipeline_cases.py b/tests/cases/pipeline_cases.py new file mode 100644 index 0000000000000000000000000000000000000000..d32551fbe1f919d2f53dc6b1116f7ab5156615e6 --- /dev/null +++ b/tests/cases/pipeline_cases.py @@ -0,0 +1,127 @@ +"""Shared geometry pipeline test cases (single-session, multi-turn).""" + +from __future__ import annotations + +from typing import Any + +QUERIES: list[dict[str, Any]] = [ + { + "id": "Q1", + "text": "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10", + "expect_pts": ["A", "B", "C", "D"], + "expect_phases": 1, + }, + { + "id": "Q2", + "text": "Tam giác ABC có AB=6, BC=8, AC=10", + "expect_pts": ["A", "B", "C"], + "expect_phases": 1, + }, + { + "id": "Q3", + "text": "Cho hình chữ nhật ABCD có AB=10 và AD=20. Gọi M là trung điểm của cạnh AB.", + "expect_pts": ["A", "B", "C", "D", "M"], + "expect_phases": 2, + }, + { + "id": "Q4", + "text": "Cho hình thang ABCD vuông tại A và D. AB=4, CD=8, AD=5.", + "expect_pts": ["A", "B", "C", "D"], + "expect_phases": 1, + }, + { + "id": "Q5", + "text": "Cho hình vuông ABCD có cạnh bằng 6.", + "expect_pts": ["A", "B", "C", "D"], + "expect_phases": 1, + }, + { + "id": "Q6", + "text": "Cho tam giác ABC vuông tại A. AB=3, AC=4. Vẽ đường cao AH.", + "expect_pts": ["A", "B", "C", "H"], + "expect_phases": 2, + }, + { + "id": "Q7", + "text": "Cho hình thoi ABCD có cạnh bằng 5 và góc A bằng 60 độ.", + "expect_pts": ["A", "B", "C", "D"], + "expect_phases": 1, + }, + { + "id": "Q8", + "text": "Cho đường tròn tâm O bán kính bằng 7.", + "expect_pts": ["O"], + "expect_phases": 1, + }, + { + "id": "Q9", + "text": "Cho hình bình hành ABCD có AB=8, AD=6. Gọi E là trung điểm của CD. Vẽ đoạn thẳng AE.", + "expect_pts": ["A", "B", "C", "D", "E"], + "expect_phases": 2, + }, + { + "id": "Q10-Step1", + "text": "Cho hình chữ nhật ABCD có AB=10, AD=5.", + "expect_pts": ["A", "B", "C", "D"], + "expect_phases": 1, + }, + { + "id": "Q11-Video", + "text": "Cho tam giác ABC đều cạnh 5. Vẽ đường tròn ngoại tiếp tam giác.", + "expect_pts": ["A", "B", "C"], + "expect_phases": 2, + "request_video": True, + }, + { + "id": "Q12-3D", + "text": "Cho hình chóp S.ABCD có đáy ABCD là hình vuông cạnh 10, đường cao SO=15 với O là tâm đáy.", + "expect_pts": ["S", "A", "B", "C", "D", "O"], + "expect_phases": 2, + }, +] + +Q10_FOLLOW_UP: dict[str, Any] = { + "id": "Q10-Step2", + "text": "Vẽ thêm đường chéo AC.", + "expect_pts": ["A", "B", "C", "D"], + "expect_phases": 2, +} + +# Second multi-turn flow: follow-up depends on prior triangle definition in the same session. +Q13_HISTORY_STEPS: list[dict[str, Any]] = [ + { + "id": "Q13-Step1", + "text": "Cho tam giác ABC với AB=5, BC=6, AC=7.", + "expect_pts": ["A", "B", "C"], + "expect_phases": 1, + }, + { + "id": "Q13-Step2", + "text": "Tính diện tích tam giác ABC (dùng các cạnh đã nêu ở trên).", + "expect_pts": ["A", "B", "C"], + "expect_phases": 1, + }, +] + + +def validate_q10_step2_dsl(dsl: str) -> bool: + """Multi-turn rectangle + diagonal: merged DSL should still describe polygon and diagonal.""" + if not dsl: + return False + return "POLYGON_ORDER" in dsl and "SEGMENT" in dsl + + +def validate_query_result(q: dict[str, Any], result_data: dict[str, Any]) -> list[str]: + """Return list of validation error strings (empty if pass).""" + errors: list[str] = [] + coords = result_data.get("coordinates", {}) or {} + for pt in q.get("expect_pts", []): + if pt not in coords: + errors.append(f"Missing point {pt}") + if coords and len(coords) > 1 and all(v == [0, 0, 0] for v in coords.values()): + errors.append("All points are at [0,0,0]") + phases = result_data.get("drawing_phases", []) or [] + min_phases = int(q.get("expect_phases", 1)) + if len(phases) < min_phases: + errors.append(f"Expected {min_phases} phases, got {len(phases)}") + return errors diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..855b0e6f604cbb8986b1287cf8955acce852c2d1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,17 @@ +"""Load backend/.env for all pytest runs so integration tests see credentials.""" + +from __future__ import annotations + +import os + +import pytest + + +def pytest_configure(config: pytest.Config) -> None: + try: + from dotenv import load_dotenv + + root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + load_dotenv(os.path.join(root, ".env"), override=False) + except Exception: + pass diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..414aedcdb3b00f107e895072cbe108ae98c11272 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests (real services, workers, LLM).""" diff --git a/tests/integration/test_agents_real.py b/tests/integration/test_agents_real.py new file mode 100644 index 0000000000000000000000000000000000000000..b7469ca3e0702aeb61c00c37f9a47ff13d867cfb --- /dev/null +++ b/tests/integration/test_agents_real.py @@ -0,0 +1,90 @@ +"""Smoke tests for individual agents against real LLM / rules (opt-in via markers).""" + +from __future__ import annotations + +import os + +import pytest + +from agents.geometry_agent import GeometryAgent +from agents.knowledge_agent import KnowledgeAgent +from agents.parser_agent import ParserAgent +from agents.solver_agent import SolverAgent +from solver.dsl_parser import DSLParser + + +def _openrouter_configured() -> bool: + return bool(os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")) + + +@pytest.mark.real_agents +@pytest.mark.asyncio +async def test_parser_agent_real(): + if not _openrouter_configured(): + pytest.skip("OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY not set") + agent = ParserAgent() + out = await agent.process("Cho hình vuông ABCD có cạnh bằng 4.") + assert isinstance(out, dict) + assert out.get("type") in (None, "square", "rectangle", "general") + assert "entities" in out + + +@pytest.mark.real_agents +@pytest.mark.asyncio +async def test_geometry_agent_real(): + if not _openrouter_configured(): + pytest.skip("OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY not set") + agent = GeometryAgent() + semantic = { + "type": "square", + "values": {"side": 4}, + "entities": ["A", "B", "C", "D"], + "analysis": "Hình vuông ABCD cạnh 4", + "input_text": "Cho hình vuông ABCD cạnh 4", + "target_question": None, + } + dsl = await agent.generate_dsl(semantic, previous_dsl=None) + assert isinstance(dsl, str) and len(dsl) > 10 + parser = DSLParser() + try: + points, _constraints, _is_3d = parser.parse(dsl) + except Exception as e: + pytest.fail(f"GeometryAgent output is not parseable DSL: {e}\n---\n{dsl[:800]}") + assert len(points) >= 1, "Expected at least one point from Geometry DSL" + + +@pytest.mark.real_agents +@pytest.mark.asyncio +async def test_solver_agent_real(): + if not _openrouter_configured(): + pytest.skip("OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY not set") + agent = SolverAgent() + semantic = { + "target_question": "Tính diện tích hình vuông ABCD.", + "input_text": "Hình vuông cạnh 4", + } + engine_result = { + "coordinates": { + "A": [0.0, 0.0, 0.0], + "B": [4.0, 0.0, 0.0], + "C": [4.0, 4.0, 0.0], + "D": [0.0, 4.0, 0.0], + } + } + sol = await agent.solve(semantic, engine_result) + assert isinstance(sol, dict) + assert "steps" in sol + assert sol.get("answer") is not None or len(sol.get("steps") or []) > 0 + + +def test_knowledge_agent_augment_semantic_data(): + """Rule-based augmentation; no API key required.""" + agent = KnowledgeAgent() + data = { + "type": "general", + "values": {"AB": 5}, + "input_text": "Cho hình vuông ABCD có cạnh bằng 5.", + } + out = agent.augment_semantic_data(dict(data)) + assert out.get("type") == "square" + assert out.get("values", {}).get("AB") == 5 diff --git a/tests/integration/test_orchestrator_smoke.py b/tests/integration/test_orchestrator_smoke.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e2407edb78043a66f2e3bb4af9a4798acf3e93 --- /dev/null +++ b/tests/integration/test_orchestrator_smoke.py @@ -0,0 +1,35 @@ +"""In-process orchestrator smoke (2 queries) — same stack as API without HTTP.""" + +from __future__ import annotations + +import os +import uuid + +import pytest + +from tests.cases.pipeline_cases import QUERIES + + +def _openrouter_configured() -> bool: + return bool(os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")) + + +@pytest.mark.orchestrator_local +@pytest.mark.real_agents +@pytest.mark.asyncio +async def test_orchestrator_two_queries_smoke(): + if not _openrouter_configured(): + pytest.skip("OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY not set") + + from agents.orchestrator import Orchestrator + + orch = Orchestrator() + # Avoid Q1-style rectangles first: LLM sometimes returns prose instead of DSL. + stable_ids = ("Q5", "Q2") + by_id = {q["id"]: q for q in QUERIES} + for qid in stable_ids: + q = by_id[qid] + jid = str(uuid.uuid4()) + result = await orch.run(text=q["text"], job_id=jid) + assert "error" not in result, f"{qid}: {result.get('error')}" + assert result.get("coordinates"), f"No coordinates for {qid}" diff --git a/tests/integration/test_worker_manim_real.py b/tests/integration/test_worker_manim_real.py new file mode 100644 index 0000000000000000000000000000000000000000..545d5c93fab4481d646dce74942047b6b730202f --- /dev/null +++ b/tests/integration/test_worker_manim_real.py @@ -0,0 +1,100 @@ +"""Real Manim render + Supabase upload via render_geometry_video (opt-in).""" + +from __future__ import annotations + +import os +import uuid + +import pytest + + +@pytest.mark.real_worker_manim +@pytest.mark.slow +def test_render_geometry_video_uploads_and_updates_job(): + if os.getenv("MOCK_VIDEO", "").lower() == "true": + pytest.skip("MOCK_VIDEO must be unset/false for real Manim") + + if os.getenv("RUN_REAL_WORKER_MANIM", "").lower() not in ("1", "true", "yes"): + pytest.skip("Set RUN_REAL_WORKER_MANIM=1 to run Manim + storage integration") + + if not os.getenv("SUPABASE_SERVICE_ROLE_KEY") and not os.getenv("SUPABASE_KEY"): + pytest.skip("Supabase credentials required") + + from app.supabase_client import get_supabase + from worker.tasks import render_geometry_video + + supabase = get_supabase() + # Same default as scripts/prepare_api_test.py when TEST_SUPABASE_USER_ID is unset. + default_test_user = "8cd3adb0-7964-4575-949c-d0cadcd8b679" + user_id = ( + os.getenv("TEST_SUPABASE_USER_ID") or os.getenv("TEST_USER_ID") or default_test_user + ).strip() + if not user_id: + pytest.skip("TEST_SUPABASE_USER_ID or TEST_USER_ID required for session FK") + + session_id = str(uuid.uuid4()) + job_id = str(uuid.uuid4()) + + supabase.table("sessions").insert( + { + "id": session_id, + "user_id": user_id, + "title": "pytest manim integration", + } + ).execute() + + supabase.table("jobs").insert( + { + "id": job_id, + "user_id": user_id, + "session_id": session_id, + "status": "processing", + "input_text": "pytest render_geometry_video", + } + ).execute() + + data = { + "session_id": session_id, + "coordinates": { + "A": [0, 0], + "B": [5, 0], + "C": [5, 5], + "D": [0, 5], + }, + "polygon_order": ["A", "B", "C", "D"], + "drawing_phases": [ + { + "phase": 1, + "label": "Base", + "points": ["A", "B", "C", "D"], + "segments": [["A", "B"], ["B", "C"], ["C", "D"], ["D", "A"]], + } + ], + "semantic_analysis": "pytest square video", + "geometry_dsl": "POINT(A)\nPOINT(B)\nLENGTH(AB,5)", + } + + try: + video_url = render_geometry_video.run(job_id, data) + assert video_url and isinstance(video_url, str), "Expected public video URL" + + job_res = supabase.table("jobs").select("status, result").eq("id", job_id).execute() + assert job_res.data and job_res.data[0].get("status") == "success" + assert job_res.data[0].get("result", {}).get("video_url") + finally: + try: + supabase.table("session_assets").delete().eq("job_id", job_id).execute() + except Exception: + pass + try: + supabase.table("messages").delete().eq("session_id", session_id).execute() + except Exception: + pass + try: + supabase.table("jobs").delete().eq("id", job_id).execute() + except Exception: + pass + try: + supabase.table("sessions").delete().eq("id", session_id).execute() + except Exception: + pass diff --git a/tests/integration/test_worker_ocr_real.py b/tests/integration/test_worker_ocr_real.py new file mode 100644 index 0000000000000000000000000000000000000000..caf5ca8fe0205cbb3c87e3c34584c502eba0ad0e --- /dev/null +++ b/tests/integration/test_worker_ocr_real.py @@ -0,0 +1,26 @@ +"""Celery OCR task against a public image URL (opt-in).""" + +from __future__ import annotations + +import os + +import pytest + + +@pytest.mark.real_worker_ocr +def test_run_ocr_from_url_celery_task(): + if os.getenv("RUN_REAL_WORKER_OCR", "").lower() not in ("1", "true", "yes"): + pytest.skip("Set RUN_REAL_WORKER_OCR=1 to run OCR worker integration") + + if not (os.getenv("CELERY_BROKER_URL") or os.getenv("REDIS_URL")): + pytest.skip("CELERY_BROKER_URL or REDIS_URL required for Celery") + + from worker.ocr_tasks import run_ocr_from_url + + url = os.getenv( + "TEST_OCR_IMAGE_URL", + "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/280px-PNG_transparency_demonstration_1.png", + ) + # .run() executes the task body in-process (same code path as the worker). + text = run_ocr_from_url.run(url) + assert isinstance(text, str) diff --git a/tests/test_3d_solver.py b/tests/test_3d_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..d6eb15c208ee258f0624f912615310463ab57493 --- /dev/null +++ b/tests/test_3d_solver.py @@ -0,0 +1,99 @@ +import pytest +from solver.dsl_parser import DSLParser +from solver.engine import GeometryEngine +from solver.models import Point, Constraint + +def test_solve_square_pyramid(): + """ + Test solving for a square pyramid S.ABCD. + Base ABCD is a square with side 10. + Height SO = 15, where O is the center of ABCD. + """ + dsl = """ + POINT(A, 0, 0, 0) + POINT(B, 10, 0, 0) + POINT(C, 10, 10, 0) + POINT(D, 0, 10, 0) + POINT(S) + POINT(O) + MIDPOINT(M1, AB) + MIDPOINT(M2, AC) + SECTION(O, A, C, 0.5) + LENGTH(SO, 15) + PERPENDICULAR(SO, AC) + PERPENDICULAR(SO, AB) + PYRAMID(S_ABCD) + """ + parser = DSLParser() + engine = GeometryEngine() + + points, constraints, is_3d = parser.parse(dsl) + result = engine.solve(points, constraints, is_3d) + + assert result is not None + coords = result["coordinates"] + + # Check base points + assert coords["A"] == [0.0, 0.0, 0.0] + assert coords["B"] == [10.0, 0.0, 0.0] + assert coords["C"] == [10.0, 10.0, 0.0] + assert coords["D"] == [0.0, 10.0, 0.0] + + # Check center O (should be (5, 5, 0)) + assert coords["O"][0] == pytest.approx(5.0) + assert coords["O"][1] == pytest.approx(5.0) + assert coords["O"][2] == pytest.approx(0.0) + + # Check apex S (should be (5, 5, 15) or (5, 5, -15)) + assert coords["S"][0] == pytest.approx(5.0) + assert coords["S"][1] == pytest.approx(5.0) + assert abs(coords["S"][2]) == pytest.approx(15.0) + +def test_solve_prism(): + """ + Triangular prism ABC_DEF. + Base ABC is right triangle at A. AB=3, AC=4. + Height AD=10. + """ + dsl = """ + POINT(A, 0, 0, 0) + POINT(B, 3, 0, 0) + POINT(C, 0, 4, 0) + POINT(D) + POINT(E) + POINT(F) + LENGTH(AD, 10) + PERPENDICULAR(AD, AB) + PERPENDICULAR(AD, AC) + PRISM(ABC_DEF) + """ + parser = DSLParser() + engine = GeometryEngine() + + points, constraints, is_3d = parser.parse(dsl) + result = engine.solve(points, constraints, is_3d) + + assert result is not None + coords = result["coordinates"] + + # D should be (0, 0, 10) + assert coords["D"][0] == pytest.approx(0.0, abs=1e-3) + assert coords["D"][1] == pytest.approx(0.0, abs=1e-3) + assert abs(coords["D"][2]) == pytest.approx(10.0, rel=1e-4, abs=1e-3) + +def test_explicit_z_zero_on_xy_plane_does_not_force_is_3d(): + """POINT with z=0 must not flip is_3d; 2D triangle stays 2D (regression vs POINT(A,x,y,0) bug).""" + dsl = """ + POINT(A, 0, 0, 0) + POINT(B, 3, 0, 0) + POINT(C, 0, 4, 0) + TRIANGLE(ABC) + """ + parser = DSLParser() + points, constraints, is_3d = parser.parse(dsl) + assert is_3d is False + assert len(points) >= 3 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_advanced_geometry.py b/tests/test_advanced_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f424859b95d3fba30b4ca7750ef864c9c2dab71d --- /dev/null +++ b/tests/test_advanced_geometry.py @@ -0,0 +1,102 @@ +import pytest +import asyncio +import logging +from solver.dsl_parser import DSLParser +from solver.engine import GeometryEngine + +logging.basicConfig(level=logging.DEBUG) + +@pytest.mark.asyncio +async def test_section_internal(): + print("\n--- Test: Section Point (Internal AE=2/3 AC) ---") + dsl = """ + POINT(A) + POINT(B) + POINT(C) + LENGTH(AB, 6) + LENGTH(BC, 6) + ANGLE(B, 90) + SECTION(E, A, C, 0.6667) + """ + parser = DSLParser() + engine = GeometryEngine() + + pts, constraints, is_3d = parser.parse(dsl) + result = engine.solve(pts, constraints, is_3d) + + if result: + coords = result['coordinates'] + print(f" A: {coords['A']}") + print(f" C: {coords['C']}") + print(f" E: {coords['E']}") + + # Verify AE = 0.6667 * AC + import math + def dist(p1, p2): return math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2) + + d_ac = dist(coords['A'], coords['C']) + d_ae = dist(coords['A'], coords['E']) + ratio = d_ae / d_ac + print(f" Calculated Ratio AE/AC: {ratio:.4f} (Expected: 0.6667)") + assert abs(ratio - 0.6667) < 1e-4 + else: + print(" ❌ Solve failed") + +@pytest.mark.asyncio +async def test_section_external(): + print("\n--- Test: Section Point (External AE=2*AC) ---") + dsl = """ + POINT(A) + POINT(C) + LENGTH(AC, 5) + SECTION(E, A, C, 2.0) + """ + parser = DSLParser() + engine = GeometryEngine() + + pts, constraints, is_3d = parser.parse(dsl) + result = engine.solve(pts, constraints, is_3d) + + if result: + coords = result['coordinates'] + print(f" A: {coords['A']}") + print(f" C: {coords['C']}") + print(f" E: {coords['E']}") + + import math + def dist(p1, p2): return math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2) + d_ac = dist(coords['A'], coords['C']) + d_ae = dist(coords['A'], coords['E']) + print(f" AE: {d_ae}, AC: {d_ac}, Ratio: {d_ae/d_ac}") + assert abs(d_ae/d_ac - 2.0) < 1e-4 + else: + print(" ❌ Solve failed") + +@pytest.mark.asyncio +async def test_line_ray_metadata(): + print("\n--- Test: Line and Ray Metadata ---") + dsl = """ + POINT(A) + POINT(B) + LINE(A, B) + RAY(A, B) + """ + parser = DSLParser() + engine = GeometryEngine() + + pts, constraints, is_3d = parser.parse(dsl) + result = engine.solve(pts, constraints, is_3d) + + if result: + print(f" Lines: {result.get('lines')}") + print(f" Rays: {result.get('rays')}") + assert ['A', 'B'] in result.get('lines', []) + assert ['A', 'B'] in result.get('rays', []) + print(" ✅ Metadata present") + else: + print(" ❌ Solve failed") + +if __name__ == "__main__": + asyncio.run(test_section_internal()) + asyncio.run(test_section_external()) + asyncio.run(test_line_ray_metadata()) diff --git a/tests/test_api_full_suite.py b/tests/test_api_full_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..74d25703e6383f296c5a9fe5f2b34aaa5969d78a --- /dev/null +++ b/tests/test_api_full_suite.py @@ -0,0 +1,224 @@ +import asyncio +import copy +import json +import os +import time + +import httpx +import pytest + +from tests.cases.pipeline_cases import ( + Q10_FOLLOW_UP, + Q13_HISTORY_STEPS, + QUERIES, + validate_q10_step2_dsl, + validate_query_result, +) + +BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000") +USER_ID = os.getenv("TEST_USER_ID") +SESSION_ID = os.getenv("TEST_SESSION_ID") + +test_stats: list[dict] = [] + + +_SOLVER_TRANSIENT = "Solver failed after multiple attempts" + + +async def run_single_api_query(client, q, headers, default_session_id: str | None): + print(f"\n🚀 [RUNNING] {q['id']}: {q['text']}") + start_time = time.time() + + payload = { + "text": q["text"], + "request_video": q.get("request_video", False), + } + + max_rounds = 3 + + try: + for round_idx in range(max_rounds): + if q.get("isolate", True): + session_resp = await client.post("/api/v1/sessions", headers=headers) + if session_resp.status_code != 200: + return { + "id": q["id"], + "query": q["text"], + "success": False, + "error": f"Session creation failed: {session_resp.text}", + } + session_id = session_resp.json()["id"] + else: + session_id = q.get("session_id", default_session_id) + + res = await client.post( + f"/api/v1/sessions/{session_id}/solve", + json=payload, + headers=headers, + ) + if res.status_code != 200: + print(f" ❌ FAILED: Status {res.status_code} - {res.text}") + return { + "id": q["id"], + "query": q["text"], + "success": False, + "error": f"HTTP {res.status_code}: {res.text}", + } + + job_id = res.json()["job_id"] + print(f" ✅ Job Created: {job_id}") + + max_attempts = 45 + result_data = None + last_error = None + for i in range(max_attempts): + await asyncio.sleep(4) + res = await client.get(f"/api/v1/solve/{job_id}", headers=headers) + data = res.json() + status = data.get("status") + print(f" - Polling ({i + 1}): {status}") + + if status == "success": + result_data = data["result"] + break + if status == "error": + last_error = data.get("result", {}).get("error") + print(f" ❌ ERROR: {last_error}") + err_s = str(last_error or "") + if _SOLVER_TRANSIENT in err_s and round_idx < max_rounds - 1: + print( + f" ↻ Retry {round_idx + 2}/{max_rounds} (transient solver/LLM flake)" + ) + result_data = None + break + return { + "id": q["id"], + "query": q["text"], + "success": False, + "error": last_error, + } + + if i == max_attempts - 1: + print(" ❌ TIMEOUT") + return {"id": q["id"], "query": q["text"], "success": False, "error": "Timeout"} + + if result_data is None: + continue + + elapsed = time.time() - start_time + errors = validate_query_result(q, result_data) + + if q.get("request_video") and not result_data.get("video_url"): + print(" ⚠️ Video requested but no URL found (Expected in some test envs)") + + if errors: + print(f" ❌ VALIDATION FAILED: {', '.join(errors)}") + return { + "id": q["id"], + "query": q["text"], + "success": False, + "error": "; ".join(errors), + "elapsed": elapsed, + "result": result_data, + } + + print(f" ✅ PASS ({elapsed:.2f}s)") + return { + "id": q["id"], + "query": q["text"], + "success": True, + "elapsed": elapsed, + "job_id": job_id, + "result": result_data, + } + + raise RuntimeError("run_single_api_query: retry loop fell through (bug)") + + except Exception as e: + print(f" ❌ EXCEPTION: {str(e)}") + return {"id": q["id"], "query": q["text"], "success": False, "error": str(e)} + + +@pytest.mark.real_api +@pytest.mark.slow +@pytest.mark.asyncio +async def test_full_api_suite(): + if not USER_ID or not SESSION_ID: + pytest.fail("TEST_USER_ID and TEST_SESSION_ID must be set") + + global test_stats + test_stats = [] + + headers = {"Authorization": f"Test {USER_ID}"} + + async with httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) as client: + for q in QUERIES: + if q["id"] == "Q10-Step1": + continue + qc = copy.deepcopy(q) + res = await run_single_api_query(client, qc, headers, SESSION_ID) + test_stats.append(res) + + print("\n--- Testing Multi-turn API Flow (Q10) ---") + shared_session_resp = await client.post("/api/v1/sessions", headers=headers) + assert shared_session_resp.status_code == 200 + shared_session = shared_session_resp.json()["id"] + + q10_1 = copy.deepcopy(next(q for q in QUERIES if q["id"] == "Q10-Step1")) + q10_1["session_id"] = shared_session + q10_1["isolate"] = False + res10_1 = await run_single_api_query(client, q10_1, headers, SESSION_ID) + test_stats.append(res10_1) + + if res10_1["success"]: + q10_2 = copy.deepcopy(Q10_FOLLOW_UP) + q10_2["session_id"] = shared_session + q10_2["isolate"] = False + res10_2 = await run_single_api_query(client, q10_2, headers, SESSION_ID) + + if res10_2["success"]: + dsl = res10_2.get("result", {}).get("geometry_dsl", "") or "" + if not validate_q10_step2_dsl(dsl): + res10_2["success"] = False + res10_2["error"] = "DSL did not merge history correctly" + + test_stats.append(res10_2) + + print("\n--- Testing Multi-turn API Flow (Q13 history in one session) ---") + q13_session_resp = await client.post("/api/v1/sessions", headers=headers) + assert q13_session_resp.status_code == 200 + q13_session = q13_session_resp.json()["id"] + + prev_ok = True + for step in Q13_HISTORY_STEPS: + if not prev_ok: + test_stats.append( + { + "id": step["id"], + "query": step["text"], + "success": False, + "error": "Skipped: previous step in Q13 failed", + } + ) + continue + sc = copy.deepcopy(step) + sc["session_id"] = q13_session + sc["isolate"] = False + out = await run_single_api_query(client, sc, headers, SESSION_ID) + test_stats.append(out) + prev_ok = bool(out.get("success")) + + with open("temp_suite_results.json", "w", encoding="utf-8") as f: + json.dump(test_stats, f, ensure_ascii=False, indent=2) + + failures = [r for r in test_stats if not r.get("success")] + if failures: + pytest.fail( + "Suite failures: " + + "; ".join(f"{r.get('id')}: {r.get('error')}" for r in failures[:5]) + + (f" (+{len(failures) - 5} more)" if len(failures) > 5 else "") + ) + + +if __name__ == "__main__": + asyncio.run(test_full_api_suite()) diff --git a/tests/test_api_metadata_real.py b/tests/test_api_metadata_real.py new file mode 100644 index 0000000000000000000000000000000000000000..01618fd0ac7941955cf966f5e81bc0302edd2868 --- /dev/null +++ b/tests/test_api_metadata_real.py @@ -0,0 +1,103 @@ +"""Verify assistant message metadata after process_session_job (Supabase + LLM).""" + +from __future__ import annotations + +import os +import uuid + +import pytest +from dotenv import load_dotenv + +load_dotenv() + +from app.models.schemas import SolveRequest +from app.routers.solve import process_session_job +from app.supabase_client import get_supabase + + +@pytest.mark.real_api +@pytest.mark.asyncio +async def test_metadata_persistence_after_solve(): + if not os.getenv("SUPABASE_SERVICE_ROLE_KEY") and not os.getenv("SUPABASE_KEY"): + pytest.skip("Supabase credentials not configured") + + user_id = os.getenv("TEST_SUPABASE_USER_ID") or os.getenv("TEST_USER_ID") + if not user_id: + pytest.skip("TEST_SUPABASE_USER_ID or TEST_USER_ID required") + + supabase = get_supabase() + session_id = str(uuid.uuid4()) + job_id = str(uuid.uuid4()) + + supabase.table("sessions").insert( + { + "id": session_id, + "user_id": user_id, + "title": "pytest metadata session", + } + ).execute() + + request = SolveRequest( + text="Cho hình chữ nhật ABCD có AB=10, AD=20. Vẽ đường thẳng d đi qua A và B.", + request_video=False, + ) + + supabase.table("jobs").insert( + { + "id": job_id, + "user_id": user_id, + "session_id": session_id, + "status": "processing", + "input_text": request.text, + } + ).execute() + supabase.table("messages").insert( + { + "session_id": session_id, + "role": "user", + "type": "text", + "content": request.text, + "metadata": {}, + } + ).execute() + + try: + await process_session_job(job_id, session_id, request, user_id) + + res = ( + supabase.table("messages") + .select("metadata") + .eq("session_id", session_id) + .eq("role", "assistant") + .order("created_at", desc=True) + .limit(1) + .execute() + ) + + assert res.data, "Expected at least one assistant message" + metadata = res.data[0].get("metadata") or {} + required = [ + "job_id", + "coordinates", + "polygon_order", + "drawing_phases", + "circles", + "lines", + "rays", + ] + missing = [f for f in required if f not in metadata] + assert not missing, f"Missing metadata fields: {missing}" + assert metadata.get("job_id") == job_id + finally: + try: + supabase.table("messages").delete().eq("session_id", session_id).execute() + except Exception: + pass + try: + supabase.table("jobs").delete().eq("session_id", session_id).execute() + except Exception: + pass + try: + supabase.table("sessions").delete().eq("id", session_id).execute() + except Exception: + pass diff --git a/tests/test_api_real_e2e.py b/tests/test_api_real_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f4dd9a69f25aa2c2eb03bde5f4bfed7fe1af00 --- /dev/null +++ b/tests/test_api_real_e2e.py @@ -0,0 +1,77 @@ +import os +import httpx +import time +import pytest +import logging + +# Configuration from environment +BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000") +USER_ID = os.getenv("TEST_USER_ID") +SESSION_ID = os.getenv("TEST_SESSION_ID") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +@pytest.mark.smoke +@pytest.mark.real_api +@pytest.mark.asyncio +async def test_api_e2e_flow(): + if not USER_ID or not SESSION_ID: + pytest.fail("TEST_USER_ID and TEST_SESSION_ID must be set") + + auth_headers = {"Authorization": f"Test {USER_ID}"} + + async with httpx.AsyncClient(base_url=BASE_URL, timeout=30.0) as client: + # 1. Health check + print("\n[1/3] Checking API Health...") + res = await client.get("/") + assert res.status_code == 200 + assert "running" in res.json()["message"].lower() + print(" ✅ Health check passed") + + # 2. Submit Solve Request + print(f"\n[2/3] Submitting solve request for session {SESSION_ID}...") + payload = { + "text": "Cho hình chữ nhật ABCD có AB=5, AD=10. Tính diện tích.", + "request_video": False + } + res = await client.post(f"/api/v1/sessions/{SESSION_ID}/solve", json=payload, headers=auth_headers) + + if res.status_code != 200: + print(f" ❌ FAILED: {res.text}") + assert res.status_code == 200 + + data = res.json() + job_id = data["job_id"] + assert job_id is not None + print(f" ✅ Request accepted. Job ID: {job_id}") + + # 3. Polling Job Status + print("\n[3/3] Polling job status...") + max_attempts = 15 + for i in range(max_attempts): + time.sleep(2) # Simple sleep between polls + res = await client.get(f"/api/v1/solve/{job_id}", headers=auth_headers) + assert res.status_code == 200 + job_data = res.json() + status = job_data["status"] + print(f" Attempt {i+1}: Status = {status}") + + if status == "success": + print(" ✅ SUCCESS: API pipeline completed successfully.") + result = job_data.get("result", {}) + assert "coordinates" in result + assert "geometry_dsl" in result + return + + if status == "error": + error_msg = job_data.get("result", {}).get("error", "Unknown error") + pytest.fail(f"Job failed with error: {error_msg}") + + if i == max_attempts - 1: + pytest.fail("Timeout waiting for job completion") + +if __name__ == "__main__": + # This allows running the script directly if needed + import asyncio + asyncio.run(test_api_e2e_flow()) diff --git a/tests/test_chat_image_validate.py b/tests/test_chat_image_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..57535f8f8473b9c5e664a183ffe55b2a1130feb2 --- /dev/null +++ b/tests/test_chat_image_validate.py @@ -0,0 +1,26 @@ +"""Unit tests for chat image validation (no Supabase / FastAPI app import).""" + +import pytest +from fastapi import HTTPException + +from app.chat_image_upload import validate_chat_image_bytes + +_VALID_PNG = b"\x89PNG\r\n\x1a\n" + b"\x00" * 32 + + +def test_validate_png_ok(): + ext, mime = validate_chat_image_bytes("a.png", _VALID_PNG, "image/png") + assert ext == ".png" + assert mime == "image/png" + + +def test_validate_rejects_bad_magic(): + with pytest.raises(HTTPException) as exc: + validate_chat_image_bytes("a.png", b"xxxxxxxxxxxx", "image/png") + assert exc.value.status_code == 400 + + +def test_validate_rejects_empty(): + with pytest.raises(HTTPException) as exc: + validate_chat_image_bytes("a.png", b"", "image/png") + assert exc.value.status_code == 400 diff --git a/tests/test_job_poll.py b/tests/test_job_poll.py new file mode 100644 index 0000000000000000000000000000000000000000..1660b2665a3b010249610d6263e276a78179a90b --- /dev/null +++ b/tests/test_job_poll.py @@ -0,0 +1,32 @@ +"""Job poll normalization for FE contract.""" + +import uuid + +from app.job_poll import normalize_job_row_for_client + + +def test_normalize_adds_job_id_and_parses_result_json_string(): + jid = str(uuid.uuid4()) + row = { + "id": jid, + "status": "success", + "user_id": uuid.uuid4(), + "session_id": uuid.uuid4(), + "result": '{"coordinates": {"A": [0, 1]}}', + "input_text": "x", + } + out = normalize_job_row_for_client(row) + assert out["job_id"] == jid + assert out["id"] == jid + assert out["status"] == "success" + assert isinstance(out["result"], dict) + assert out["result"]["coordinates"]["A"] == [0, 1] + assert isinstance(out["user_id"], str) + assert isinstance(out["session_id"], str) + + +def test_normalize_keeps_dict_result(): + row = {"id": "j1", "status": "processing", "result": None} + out = normalize_job_row_for_client(row) + assert out["job_id"] == "j1" + assert out["result"] is None diff --git a/tests/test_ocr_preview.py b/tests/test_ocr_preview.py new file mode 100644 index 0000000000000000000000000000000000000000..11e0bbe0bcf3e5f346fa553e71ea77e02ef68ae4 --- /dev/null +++ b/tests/test_ocr_preview.py @@ -0,0 +1,100 @@ +"""Tests for POST /api/v1/sessions/{session_id}/ocr_preview (auth + owner + merge).""" + +from __future__ import annotations + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +os.environ.setdefault("ALLOW_TEST_BYPASS", "true") + +from app.main import app # noqa: E402 + +_VALID_SESSION_ID = "00000000-0000-0000-0000-000000000099" + + +@pytest.fixture +def auth_headers(): + return {"Authorization": "Test test-user-ocr-preview"} + + +@pytest.mark.asyncio +async def test_ocr_preview_requires_auth(): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview", + files={"file": ("t.png", b"\x89PNG\r\n\x1a\n", "image/png")}, + data={"user_message": "hello"}, + ) + assert res.status_code == 401 + + +@pytest.mark.asyncio +async def test_ocr_preview_forbidden_when_not_owner(auth_headers): + with patch("app.routers.solve.session_owned_by_user", return_value=False): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview", + headers=auth_headers, + files={"file": ("t.png", b"\x89PNG\r\n\x1a\n", "image/png")}, + data={"user_message": "note"}, + ) + assert res.status_code == 403 + + +@pytest.mark.asyncio +async def test_ocr_preview_success_merges_draft(auth_headers): + mock_orch = MagicMock() + mock_orch.ocr_agent.process_image = AsyncMock(return_value="OCR_LINE") + + with ( + patch("app.routers.solve.session_owned_by_user", return_value=True), + patch("app.routers.solve.get_orchestrator", return_value=mock_orch), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview", + headers=auth_headers, + files={"file": ("t.png", b"\x89PNG\r\n\x1a\n", "image/png")}, + data={"user_message": " my note "}, + ) + assert res.status_code == 200, res.text + data = res.json() + assert data["ocr_text"] == "OCR_LINE" + assert data["user_message"] == "my note" + assert data["combined_draft"] == "my note\n\nOCR_LINE" + mock_orch.ocr_agent.process_image.assert_called_once() + + +@pytest.mark.asyncio +async def test_ocr_preview_rejects_oversized_file(auth_headers): + mock_orch = MagicMock() + mock_orch.ocr_agent.process_image = AsyncMock(return_value="") + + big = b"x" * (11 * 1024 * 1024) + with ( + patch("app.routers.solve.session_owned_by_user", return_value=True), + patch("app.routers.solve.get_orchestrator", return_value=mock_orch), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview", + headers=auth_headers, + files={"file": ("huge.png", big, "image/png")}, + ) + assert res.status_code == 413 + mock_orch.ocr_agent.process_image.assert_not_called() + + +@pytest.mark.asyncio +async def test_ocr_preview_rejects_empty_file(auth_headers): + with patch("app.routers.solve.session_owned_by_user", return_value=True): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/ocr_preview", + headers=auth_headers, + files={"file": ("empty.png", b"", "image/png")}, + ) + assert res.status_code == 400 diff --git a/tests/test_real_llm.py b/tests/test_real_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..3658696404965051700d5e1de84b30bacb1bc4dd --- /dev/null +++ b/tests/test_real_llm.py @@ -0,0 +1,44 @@ +import asyncio +import logging +import os + +import pytest +from dotenv import load_dotenv + +from app.llm_client import get_llm_client + +logging.basicConfig(level=logging.INFO) +load_dotenv() + + +def _openrouter_configured() -> bool: + return bool(os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")) + + +@pytest.mark.real_agents +@pytest.mark.asyncio +async def test_real_llm(): + if not _openrouter_configured(): + pytest.skip("OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY not set") + + client = get_llm_client() + if getattr(client, "client", None) is None: + pytest.skip("LLM client not configured") + + content = await client.chat_completions_create( + messages=[ + { + "role": "system", + "content": ( + "You are a Geometry Expert. Give a short step-by-step reasoning for the distance " + "between midpoints M of AB and N of AD in rectangle ABCD with AB=10 and AD=20." + ), + }, + {"role": "user", "content": "Solve briefly."}, + ] + ) + assert isinstance(content, str) and len(content.strip()) > 20 + + +if __name__ == "__main__": + asyncio.run(test_real_llm()) diff --git a/tests/test_schema_helpers.py b/tests/test_schema_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..d22731058cda0ab4f9d9054bf729645ac5d91d59 --- /dev/null +++ b/tests/test_schema_helpers.py @@ -0,0 +1,10 @@ +"""Unit tests that avoid importing the full FastAPI app (no Supabase).""" + +from app.ocr_text_merge import build_combined_ocr_preview_draft + + +def test_build_combined_ocr_preview_draft(): + assert build_combined_ocr_preview_draft(None, "only ocr") == "only ocr" + assert build_combined_ocr_preview_draft("", "only ocr") == "only ocr" + assert build_combined_ocr_preview_draft(" caption ", "") == "caption" + assert build_combined_ocr_preview_draft("a", "b") == "a\n\nb" diff --git a/tests/test_solve_multipart.py b/tests/test_solve_multipart.py new file mode 100644 index 0000000000000000000000000000000000000000..73e2d2b227eb7f8935b6240ad6d60abdd30ad49c --- /dev/null +++ b/tests/test_solve_multipart.py @@ -0,0 +1,119 @@ +"""Tests for POST /api/v1/sessions/{session_id}/solve_multipart.""" + +from __future__ import annotations + +import os +from unittest.mock import MagicMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +os.environ.setdefault("ALLOW_TEST_BYPASS", "true") + +from app.main import app # noqa: E402 +from app.models.schemas import SolveResponse # noqa: E402 + +_VALID_SESSION_ID = "00000000-0000-0000-0000-000000000088" + +# PNG signature + padding (>= 12 bytes) for magic check in validate_chat_image_bytes +_VALID_PNG_BODY = b"\x89PNG\r\n\x1a\n" + b"\x00" * 32 + + +@pytest.fixture +def auth_headers(): + return {"Authorization": "Test test-user-solve-mp"} + + +@pytest.mark.asyncio +async def test_solve_multipart_requires_auth(): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart", + files={"file": ("t.png", _VALID_PNG_BODY, "image/png")}, + data={"text": "hi"}, + ) + assert res.status_code == 401 + + +@pytest.mark.asyncio +async def test_solve_multipart_forbidden(auth_headers): + with patch("app.routers.solve.session_owned_by_user", return_value=False): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart", + headers=auth_headers, + files={"file": ("t.png", _VALID_PNG_BODY, "image/png")}, + data={"text": "hi"}, + ) + assert res.status_code == 403 + + +@pytest.mark.asyncio +async def test_solve_multipart_empty_text(auth_headers): + with patch("app.routers.solve.session_owned_by_user", return_value=True): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart", + headers=auth_headers, + files={"file": ("t.png", _VALID_PNG_BODY, "image/png")}, + data={"text": " "}, + ) + assert res.status_code == 400 + + +@pytest.mark.asyncio +async def test_solve_multipart_bad_magic(auth_headers): + with patch("app.routers.solve.session_owned_by_user", return_value=True): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart", + headers=auth_headers, + files={"file": ("t.png", b"not-a-real-png!!", "image/png")}, + data={"text": "problem text"}, + ) + assert res.status_code == 400 + + +@pytest.mark.asyncio +async def test_solve_multipart_upload_then_enqueue(auth_headers): + up = { + "public_url": "https://example.test/bucket/sessions/s1/image_v1_j.png", + "storage_path": f"sessions/{_VALID_SESSION_ID}/image_v1_job.png", + "version": 1, + "session_asset_id": "00000000-0000-0000-0000-000000000099", + } + captured = {} + + def fake_enqueue(supabase, background_tasks, session_id, user_id, uid, request, message_metadata, job_id): + captured["metadata"] = message_metadata + captured["job_id"] = job_id + captured["request"] = request + return SolveResponse(job_id=job_id, status="processing") + + with ( + patch("app.routers.solve.session_owned_by_user", return_value=True), + patch("app.routers.solve.upload_session_chat_image", return_value=up) as up_mock, + patch("app.routers.solve._enqueue_solve_common", side_effect=fake_enqueue), + ): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + res = await client.post( + f"/api/v1/sessions/{_VALID_SESSION_ID}/solve_multipart", + headers=auth_headers, + files={"file": ("t.png", _VALID_PNG_BODY, "image/png")}, + data={"text": " my problem "}, + ) + assert res.status_code == 200, res.text + data = res.json() + assert data["status"] == "processing" + jid = data["job_id"] + assert jid + up_mock.assert_called_once() + call_args = up_mock.call_args[0] + assert call_args[0] == _VALID_SESSION_ID + assert call_args[1] == jid + assert len(call_args[2]) == len(_VALID_PNG_BODY) + att = captured["metadata"].get("attachment", {}) + assert att.get("size_bytes") == len(_VALID_PNG_BODY) + assert att.get("public_url") == up["public_url"] + assert captured["request"].text == "my problem" + assert captured["request"].image_url == up["public_url"] diff --git a/tests/test_solver.py b/tests/test_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..555d1dbc10a5e9dc79c1b94e8201d81b9f10e463 --- /dev/null +++ b/tests/test_solver.py @@ -0,0 +1,44 @@ +import sys +import os +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +from solver.engine import GeometryEngine +from solver.models import Point, Constraint + +def test_triangle_abc(): + engine = GeometryEngine() + + # Triangle ABC: AB=5, AC=7, angle A=60 + points = [ + Point(id="A"), + Point(id="B"), + Point(id="C") + ] + + constraints = [ + Constraint(type="length", targets=["A", "B"], value=5.0), + Constraint(type="length", targets=["A", "C"], value=7.0), + Constraint(type="angle", targets=["A"], value=60.0) # Angle at A + ] + + print("Solving for Triangle ABC (AB=5, AC=7, angle A=60)...") + results = engine.solve(points, constraints) + + if results: + coords = results["coordinates"] + print("Success! Coordinates:") + for pid, c in coords.items(): + print(f"Point {pid}: {c}") + + # Verify distance AB + dist_ab = ((coords["B"][0] - coords["A"][0])**2 + (coords["B"][1] - coords["A"][1])**2)**0.5 + print(f"Verified AB distance: {dist_ab:.2f}") + + # Verify distance AC + dist_ac = ((coords["C"][0] - coords["A"][0])**2 + (coords["C"][1] - coords["A"][1])**2)**0.5 + print(f"Verified AC distance: {dist_ac:.2f}") + else: + print("Solver failed.") + +if __name__ == "__main__": + test_triangle_abc() diff --git a/tests/verify_db_metadata.py b/tests/verify_db_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..04a0be951da10dfd4b4dbbe0e8da6cb9e90e1b62 --- /dev/null +++ b/tests/verify_db_metadata.py @@ -0,0 +1,39 @@ +import os +import json +from app.supabase_client import get_supabase + +def verify_metadata(): + supabase = get_supabase() + + # Get the 5 most recent assistant messages + res = supabase.table("messages") \ + .select("id, role, content, metadata, created_at") \ + .eq("role", "assistant") \ + .order("created_at", desc=True) \ + .limit(5) \ + .execute() + + if not res.data: + print("No assistant messages found.") + return + + for i, msg in enumerate(res.data): + print(f"\n--- Message {i+1} (ID: {msg['id']}, Created: {msg['created_at']}) ---") + metadata = msg.get("metadata", {}) + + required_fields = ["job_id", "coordinates", "polygon_order", "drawing_phases", "circles"] + missing = [f for f in required_fields if f not in metadata] + + if not missing: + print("✅ All mandatory fields present in metadata.") + # Print a snippet of the data + print(f" - job_id: {metadata.get('job_id')}") + print(f" - polygon_order: {metadata.get('polygon_order')}") + print(f" - drawing_phases count: {len(metadata.get('drawing_phases', []))}") + print(f" - circles count: {len(metadata.get('circles', []))}") + else: + print(f"❌ Missing fields in metadata: {missing}") + print(f" Metadata keys: {list(metadata.keys())}") + +if __name__ == "__main__": + verify_metadata() diff --git a/vision_ocr/__init__.py b/vision_ocr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..495a12157a66b17d7c1b79475b6576b361323593 --- /dev/null +++ b/vision_ocr/__init__.py @@ -0,0 +1,6 @@ +"""Vision-only OCR (YOLO layout load / PaddleOCR / Pix2Tex). No LLM — safe for dedicated OCR workers.""" + +from .compat import allow_ultralytics_weights +from .pipeline import OcrVisionPipeline + +__all__ = ["OcrVisionPipeline", "allow_ultralytics_weights"] diff --git a/vision_ocr/compat.py b/vision_ocr/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..cccb635a3b7298de7ebc06ca9ba8f09ec996ed45 --- /dev/null +++ b/vision_ocr/compat.py @@ -0,0 +1,33 @@ +"""PyTorch 2.6+ defaults weights_only=True; Ultralytics YOLO .pt checkpoints unpickle full nn graphs (trusted official weights).""" + +from __future__ import annotations + +import functools + +_torch_load_patched = False + + +def allow_ultralytics_weights() -> None: + """ + Official yolov8n.pt is a trusted checkpoint. PyTorch 2.6+ safe unpickling would require + allowlisting many torch.nn globals; loading with weights_only=False matches Ultralytics + upstream behavior for local .pt files. + """ + global _torch_load_patched + if _torch_load_patched: + return + try: + import torch + + _orig = torch.load + + @functools.wraps(_orig) + def _load(*args, **kwargs): + if "weights_only" not in kwargs: + kwargs["weights_only"] = False + return _orig(*args, **kwargs) + + torch.load = _load + _torch_load_patched = True + except Exception: + pass diff --git a/vision_ocr/pipeline.py b/vision_ocr/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..33b741cbdcfb036130b935ec0e74f37ed7f4a77e --- /dev/null +++ b/vision_ocr/pipeline.py @@ -0,0 +1,245 @@ +"""OCR vision stack only (no LLM). Used by Celery OCR worker and composed by ``agents.ocr_agent``.""" + +from __future__ import annotations + +import logging +import os +import uuid +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np + +from .compat import allow_ultralytics_weights + +logger = logging.getLogger(__name__) + +_OCR_MAX_EDGE = 2000 +_CROP_PAD = 4 + + +class OcrVisionPipeline: + """ + Hybrid pipeline: + 1. YOLO for layout analysis (weights preloaded; layout path reserved). + 2. PaddleOCR for Vietnamese text extraction. + 3. Pix2Tex for LaTeX formula extraction. + """ + + def __init__(self) -> None: + logger.info("[OcrVisionPipeline] Initializing engines...") + + try: + from ultralytics import YOLO + + allow_ultralytics_weights() + logger.info("[OcrVisionPipeline] Loading YOLO...") + self.layout_model = YOLO("yolov8n.pt") + logger.info("[OcrVisionPipeline] YOLO initialized.") + except Exception as e: + logger.error("[OcrVisionPipeline] YOLO init failed: %s", e) + self.layout_model = None + + try: + from paddleocr import PaddleOCR + + logger.info("[OcrVisionPipeline] Loading PaddleOCR...") + self.text_model = PaddleOCR(use_angle_cls=True, lang="vi") + logger.info("[OcrVisionPipeline] PaddleOCR (vi) initialized.") + except Exception as e: + logger.error("[OcrVisionPipeline] PaddleOCR init failed: %s", e) + self.text_model = None + + try: + from pix2tex.cli import LatexOCR + + logger.info("[OcrVisionPipeline] Loading Pix2Tex...") + self.math_model = LatexOCR() + logger.info("[OcrVisionPipeline] Pix2Tex initialized.") + except Exception as e: + logger.error("[OcrVisionPipeline] Pix2Tex init failed: %s", e) + self.math_model = None + + def _preprocess_image_for_ocr(self, src_path: str) -> Tuple[str, bool]: + """Resize large images, CLAHE on luminance; returns path (may be new temp file).""" + img = cv2.imread(src_path, cv2.IMREAD_COLOR) + if img is None: + g = cv2.imread(src_path, cv2.IMREAD_GRAYSCALE) + if g is None: + logger.warning("[OcrVisionPipeline] OpenCV could not read %s; using original.", src_path) + return src_path, False + img = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR) + h, w = img.shape[:2] + max_dim = max(h, w) + if max_dim > _OCR_MAX_EDGE: + scale = _OCR_MAX_EDGE / max_dim + img = cv2.resize( + img, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_AREA + ) + logger.info("[OcrVisionPipeline] Resized for OCR to max edge %s", _OCR_MAX_EDGE) + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray) + den = cv2.fastNlMeansDenoising(gray, None, 8, 7, 21) + out = f"temp_ocr_prep_{uuid.uuid4().hex}.png" + cv2.imwrite(out, den) + return out, True + + def _load_bgr_for_crops(self, path: str) -> Optional[np.ndarray]: + im = cv2.imread(path, cv2.IMREAD_COLOR) + if im is None: + g = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + if g is None: + return None + im = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR) + return im + + def _crop_from_quad(self, img_bgr: np.ndarray, bbox) -> Optional[np.ndarray]: + try: + pts = np.array(bbox, dtype=np.float32) + xs = pts[:, 0] + ys = pts[:, 1] + H, W = img_bgr.shape[:2] + x1 = max(0, int(xs.min()) - _CROP_PAD) + y1 = max(0, int(ys.min()) - _CROP_PAD) + x2 = min(W, int(xs.max()) + _CROP_PAD) + y2 = min(H, int(ys.max()) + _CROP_PAD) + if x2 <= x1 or y2 <= y1: + return None + return img_bgr[y1:y2, x1:x2].copy() + except Exception as e: + logger.debug("[OcrVisionPipeline] crop failed: %s", e) + return None + + def _latex_from_crop_bgr(self, crop_bgr: np.ndarray) -> Optional[str]: + if self.math_model is None or crop_bgr is None or crop_bgr.size == 0: + return None + ch, cw = crop_bgr.shape[:2] + if ch < 10 or cw < 10: + return None + try: + from PIL import Image + + rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB) + pil = Image.fromarray(rgb) + out = self.math_model(pil) + if isinstance(out, str) and out.strip(): + return out.strip() + except Exception as e: + logger.debug("[OcrVisionPipeline] Pix2Tex on crop failed: %s", e) + return None + + def _maybe_math_from_crop(self, img_bgr: Optional[np.ndarray], bbox, text: str) -> str: + if img_bgr is None or not self.math_model: + return text + is_math_hint = any( + c in text for c in ["\\", "^", "_", "{", "}", "=", "+", "-", "*", "/"] + ) + if not is_math_hint: + return text + crop = self._crop_from_quad(img_bgr, bbox) + latex = self._latex_from_crop_bgr(crop) if crop is not None else None + if latex: + logger.info("[OcrVisionPipeline] Pix2Tex replaced line fragment (len=%s)", len(latex)) + return f"${latex}$" + return text + + async def process_image(self, image_path: str) -> str: + """Return assembled raw OCR text (no LLM).""" + logger.info("==[OcrVisionPipeline] Processing: %s==", image_path) + + if not os.path.exists(image_path): + return f"Error: File {image_path} not found." + + prep_path, prep_cleanup = self._preprocess_image_for_ocr(image_path) + paddle_path = prep_path if prep_cleanup else image_path + img_bgr = self._load_bgr_for_crops(prep_path if prep_cleanup else image_path) + + raw_fragments: List[Dict[str, Any]] = [] + + try: + if self.text_model: + logger.info("[OcrVisionPipeline] Running PaddleOCR on %s...", paddle_path) + result = self.text_model.ocr(paddle_path) + logger.info("[OcrVisionPipeline] PaddleOCR raw result: %s", result) + + if not result: + logger.warning("[OcrVisionPipeline] PaddleOCR returned no results.") + return "" + + if isinstance(result[0], dict): + res_dict = result[0] + rec_texts = res_dict.get("rec_texts", []) + rec_scores = res_dict.get("rec_scores", []) + rec_polys = res_dict.get("rec_polys", []) + + for i in range(len(rec_texts)): + text = rec_texts[i] + bbox = rec_polys[i] + score = rec_scores[i] if i < len(rec_scores) else None + if score is not None and float(score) < 0.45: + logger.debug( + "[OcrVisionPipeline] Low-confidence line (score=%s): %s", + score, + text[:80], + ) + + y_top = int(min(p[1] for p in bbox)) if hasattr(bbox, "__iter__") else 0 + content = self._maybe_math_from_crop(img_bgr, bbox, text) + raw_fragments.append({"y": y_top, "content": content, "type": "text"}) + elif isinstance(result[0], list): + for line in result[0]: + bbox = line[0] + text = line[1][0] + score = line[1][1] if len(line[1]) > 1 else None + if score is not None and float(score) < 0.45: + logger.debug( + "[OcrVisionPipeline] Low-confidence line (score=%s): %s", + score, + text[:80], + ) + + y_top = int(bbox[0][1]) + content = self._maybe_math_from_crop(img_bgr, bbox, text) + raw_fragments.append({"y": y_top, "content": content, "type": "text"}) + finally: + if prep_cleanup and os.path.exists(prep_path): + try: + os.remove(prep_path) + except OSError: + pass + + raw_fragments.sort(key=lambda x: x["y"]) + combined_text = "\n".join([f["content"] for f in raw_fragments]) + + logger.info( + "[OcrVisionPipeline] Raw OCR output assembled:\n---\n%s\n---", combined_text + ) + + if not combined_text.strip(): + logger.warning("[OcrVisionPipeline] No text detected.") + return "" + + return combined_text + + async def process_url(self, url: str) -> str: + """Download image and run ``process_image`` (raw only).""" + import httpx + + from app.url_utils import sanitize_url + + url = sanitize_url(url) + if not url: + return "Error: Empty image URL after cleanup." + + async with httpx.AsyncClient() as client: + resp = await client.get(url) + if resp.status_code == 200: + temp_path = "temp_url_image.png" + with open(temp_path, "wb") as f: + f.write(resp.content) + try: + return await self.process_image(temp_path) + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + return f"Error: Failed to fetch image from URL {url}" diff --git a/worker/asset_manager.py b/worker/asset_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..373068effee2eac53bc61b475d0e235366265a88 --- /dev/null +++ b/worker/asset_manager.py @@ -0,0 +1,72 @@ +import os +import uuid +import logging +from typing import Tuple +from app.supabase_client import get_supabase + +logger = logging.getLogger(__name__) + +def get_next_version(session_id: str, asset_type: str = "video") -> int: + """ + Query session_assets to find the latest version for this session/type. + """ + supabase = get_supabase() + try: + res = ( + supabase.table("session_assets") + .select("version") + .eq("session_id", session_id) + .eq("asset_type", asset_type) + .order("version", desc=True) + .limit(1) + .execute() + ) + if res.data: + return res.data[0]["version"] + 1 + return 1 + except Exception as e: + logger.error(f"Error fetching version: {e}") + return 1 + +def upload_session_asset( + session_id: str, + job_id: str, + file_bytes: bytes, + asset_type: str, + ext: str +) -> Tuple[str, str]: + """ + Upload file to Supabase Storage with versioned path and record in session_assets. + Returns (storage_path, public_url). + """ + supabase = get_supabase() + bucket_name = os.getenv("SUPABASE_BUCKET", "video") + + version = get_next_version(session_id, asset_type) + + # Structure: sessions/{session_id}/{asset_type}_v{version}_{job_id}.{ext} + file_name = f"{asset_type}_v{version}_{job_id}.{ext}" + storage_path = f"sessions/{session_id}/{file_name}" + + # 1. Upload to Storage + content_type = "video/mp4" if ext == "mp4" else "image/png" + supabase.storage.from_(bucket_name).upload( + path=storage_path, + file=file_bytes, + file_options={"content-type": content_type} + ) + + # 2. Get Public URL + public_url = supabase.storage.from_(bucket_name).get_public_url(storage_path) + + # 3. Record in DB + supabase.table("session_assets").insert({ + "session_id": session_id, + "job_id": job_id, + "asset_type": asset_type, + "storage_path": storage_path, + "public_url": public_url, + "version": version + }).execute() + + return storage_path, public_url diff --git a/worker/celery_app.py b/worker/celery_app.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d452682ea17b645dc6a146026d08fb96f0f9d3 --- /dev/null +++ b/worker/celery_app.py @@ -0,0 +1,75 @@ +import os +from celery import Celery +from dotenv import load_dotenv + +from app.url_utils import sanitize_env + +# Load environment variables early +load_dotenv() + +from app.runtime_env import apply_runtime_env_defaults + +apply_runtime_env_defaults() + +from app.logging_setup import setup_application_logging + +setup_application_logging() + + +def _celery_include_modules() -> list[str]: + """ + Load only task modules for queues this process consumes (see CELERY_WORKER_QUEUES). + OCR-only Spaces must not import worker.tasks (Manim / Supabase render path). + """ + raw = (os.getenv("CELERY_WORKER_QUEUES") or "").strip().lower() + if not raw: + return ["worker.tasks", "worker.ocr_tasks"] + parts = [p.strip() for p in raw.split(",") if p.strip()] + seen: set[str] = set() + out: list[str] = [] + for p in parts: + mod = None + if p == "render": + mod = "worker.tasks" + elif p == "ocr": + mod = "worker.ocr_tasks" + if mod and mod not in seen: + seen.add(mod) + out.append(mod) + return out if out else ["worker.tasks", "worker.ocr_tasks"] + + +_broker_raw = os.getenv("CELERY_BROKER_URL") or os.getenv("REDIS_URL") or "redis://localhost:6379/0" +_backend_raw = os.getenv("CELERY_RESULT_BACKEND") or os.getenv("REDIS_URL") or "redis://localhost:6379/1" + +BROKER_URL = sanitize_env(_broker_raw) or _broker_raw.strip() +RESULT_BACKEND = sanitize_env(_backend_raw) or _backend_raw.strip() + +celery_app = Celery( + "math_solver", + broker=BROKER_URL, + backend=RESULT_BACKEND, + include=_celery_include_modules(), +) + +# Fix for SSL if using rediss:// +if BROKER_URL.startswith("rediss://"): + celery_app.conf.broker_use_ssl = { + "ssl_cert_reqs": "none", + } +if RESULT_BACKEND.startswith("rediss://"): + celery_app.conf.redis_backend_use_ssl = { + "ssl_cert_reqs": "none", + } + +celery_app.conf.update( + task_serializer="json", + accept_content=["json"], + result_serializer="json", + timezone="UTC", + enable_utc=True, + task_routes={ + "worker.tasks.render_geometry_video": {"queue": "render"}, + "worker.ocr_tasks.run_ocr_from_url": {"queue": "ocr"}, + }, +) diff --git a/worker/ocr_tasks.py b/worker/ocr_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..a02fd01c76757ca39f79d09901fe90dac9ca5b2f --- /dev/null +++ b/worker/ocr_tasks.py @@ -0,0 +1,25 @@ +"""Celery tasks for OCR-only worker queue (`ocr`).""" + +from __future__ import annotations + +import asyncio +import logging + +from worker.celery_app import celery_app + +logger = logging.getLogger(__name__) + + +@celery_app.task(name="worker.ocr_tasks.run_ocr_from_url") +def run_ocr_from_url(image_url: str) -> str: + """ + Download image from public URL and run OCR models only (YOLO / PaddleOCR / Pix2Tex). + LLM post-processing runs on the API via ``OCRAgent.refine_with_llm`` after the result is returned. + """ + from vision_ocr.pipeline import OcrVisionPipeline + + pipeline = OcrVisionPipeline() + logger.info("[run_ocr_from_url] starting OCR for url host=%s", image_url.split("/")[2] if "/" in image_url else "?") + text = asyncio.run(pipeline.process_url(image_url)) + logger.info("[run_ocr_from_url] done, text_len=%s", len(text or "")) + return text if text is not None else "" diff --git a/worker/tasks.py b/worker/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..adfa5de889346f49331b1d96fe448bcaac1fcfa6 --- /dev/null +++ b/worker/tasks.py @@ -0,0 +1,76 @@ +import os +from .celery_app import celery_app +from geometry_render.renderer import RendererAgent +from app.supabase_client import get_supabase +from .asset_manager import upload_session_asset + +@celery_app.task(name="worker.tasks.render_geometry_video") +def render_geometry_video(job_id: str, data: dict): + renderer = RendererAgent() + supabase = get_supabase() + session_id = data.get("session_id") + + # 1. Generate Manim script + script = renderer.generate_manim_script(data) + + # 2. Run Manim and get local video file path + video_local_path = renderer.run_manim(script, job_id) + + if not video_local_path or not os.path.exists(video_local_path): + raise Exception(f"Manim rendering failed for job {job_id}") + + try: + with open(video_local_path, "rb") as f: + video_content = f.read() + + # 3. Upload to Supabase using Asset Manager (Versioning) + # If no session_id (unlikely in this flow), fallback to simple upload + if session_id: + storage_path, video_url = upload_session_asset( + session_id=session_id, + job_id=job_id, + file_bytes=video_content, + asset_type="video", + ext="mp4" + ) + else: + # Legacy fallback + bucket_name = os.getenv("SUPABASE_BUCKET", "video") + file_name = f"{job_id}.mp4" + supabase.storage.from_(bucket_name).upload(path=file_name, file=video_content) + video_url = supabase.storage.from_(bucket_name).get_public_url(file_name) + + # 4. Update Job status and Final Result in Supabase Database + final_result = data.copy() + final_result["video_url"] = video_url + + supabase.table("jobs").update({ + "status": "success", + "result": final_result + }).eq("id", job_id).execute() + + # 5. Save message history (Assistant answer) + if session_id: + supabase.table("messages").insert({ + "session_id": session_id, + "role": "assistant", + "type": "analysis", + "content": data.get("semantic_analysis", "🎬 Video minh họa đã sẵn sàng."), + "metadata": { + "job_id": job_id, + "video_url": video_url, + "coordinates": data.get("coordinates"), + "geometry_dsl": data.get("geometry_dsl"), + "polygon_order": data.get("polygon_order", []), + "drawing_phases": data.get("drawing_phases", []), + "circles": data.get("circles", []), + "lines": data.get("lines", []), + "rays": data.get("rays", []), + } + }).execute() + + return video_url + finally: + # 6. Cleanup local file + if os.path.exists(video_local_path): + os.remove(video_local_path) diff --git a/worker_health.py b/worker_health.py new file mode 100644 index 0000000000000000000000000000000000000000..35ae72fe359a282355f7eba79a3ac12aca99390d --- /dev/null +++ b/worker_health.py @@ -0,0 +1,52 @@ +import os +import subprocess +from contextlib import asynccontextmanager + +from dotenv import load_dotenv +from fastapi import FastAPI +import uvicorn + +load_dotenv() +from app.runtime_env import apply_runtime_env_defaults + +apply_runtime_env_defaults() + +from app.logging_setup import setup_application_logging + +setup_application_logging() + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Start Celery worker in the background + print("🚀 Starting Celery worker in background...") + # Using subprocess.Popen to avoid blocking the main thread + # Must match task_routes in worker/celery_app.py (queues: render, ocr). No Celery "solve" queue — solve runs on the API. + queues = os.environ.get("CELERY_WORKER_QUEUES", "render").strip() or "render" + print(f"🧵 Celery consuming queues: {queues}", flush=True) + process = subprocess.Popen( + [ + "celery", + "-A", + "worker.celery_app", + "worker", + "--loglevel=info", + "--concurrency=1", # minimize RAM spikes on HF Spaces + "-Q", + queues, + ] + ) + yield + # Cleanup + print("🛑 Shutting down Celery worker...") + process.terminate() + +app = FastAPI(lifespan=lifespan) + +@app.get("/") +def health_check(): + return {"status": "ok", "worker": "running"} + +if __name__ == "__main__": + port = int(os.environ.get("PORT", 7860)) + print(f"📡 Starting Health Check API on port {port}...", flush=True) + uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")