Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import logging | |
| import os | |
| import uuid | |
| from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile | |
| from agents.orchestrator import Orchestrator | |
| from app.chat_image_upload import upload_session_chat_image, validate_chat_image_bytes | |
| from app.ocr_celery import ocr_celery_enabled | |
| from app.ocr_local_file import ocr_from_local_image_path | |
| from app.dependencies import get_current_user_id | |
| from app.errors import format_error_for_user | |
| from app.logutil import log_pipeline_failure, log_pipeline_success, log_step | |
| from app.models.schemas import ( | |
| OcrPreviewResponse, | |
| RenderVideoRequest, | |
| RenderVideoResponse, | |
| SolveRequest, | |
| SolveResponse, | |
| ) | |
| from app.ocr_text_merge import build_combined_ocr_preview_draft | |
| from app.session_cache import invalidate_for_user, session_owned_by_user | |
| from app.supabase_client import get_supabase | |
# Module-level logger for this router.
logger = logging.getLogger(__name__)
# Endpoints in this module are served under /api/v1/sessions and tagged "Solve"
# in the generated OpenAPI schema.
# NOTE(review): no @router.<method>(...) decorators are visible on the handlers
# in this view — confirm the handlers below are actually registered on `router`.
router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"])
# Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py).
# Importing this module therefore pays the model-loading cost once; handlers
# reach this single shared instance through get_orchestrator().
ORCHESTRATOR = Orchestrator()
def get_orchestrator() -> Orchestrator:
    """Return the shared ``Orchestrator`` that was constructed at import time.

    Handlers go through this accessor rather than reading the module global
    directly, which keeps the lookup in one place.
    """
    shared = ORCHESTRATOR
    return shared
# Largest upload (in bytes) accepted by the OCR preview endpoint: 10 MiB.
_OCR_PREVIEW_MAX_BYTES = 10 << 20
def _assert_session_owner(supabase, session_id: str, user_id, uid: str, op: str) -> None:
    """Raise HTTP 403 unless ``session_id`` belongs to the given user.

    Ownership is decided by ``session_owned_by_user``, which is handed a
    callback that queries the ``sessions`` table filtered on both ``id`` and
    ``user_id``. Whether that callback runs on every call or only on a cache
    miss is up to ``app.session_cache`` (presumably the latter — confirm there).

    Args:
        supabase: Supabase client used by the ownership query callback.
        session_id: Session being accessed.
        user_id: Authenticated user id as produced by the auth dependency.
        uid: ``str(user_id)``, passed to ``session_owned_by_user``.
        op: Label recorded on the ``db_select`` log step.

    Raises:
        HTTPException: 403 when the session is not owned by the user.
    """

    def _lookup_in_db() -> bool:
        query = (
            supabase.table("sessions")
            .select("id")
            .eq("id", session_id)
            .eq("user_id", user_id)
        )
        rows = query.execute()
        log_step("db_select", table="sessions", op=op, session_id=session_id)
        return bool(rows.data)

    if session_owned_by_user(session_id, uid, _lookup_in_db):
        return
    log_pipeline_failure("solve_request", error="forbidden", session_id=session_id)
    raise HTTPException(status_code=403, detail="Forbidden: You do not own this session.")
def _enqueue_solve_common(
    supabase,
    background_tasks: BackgroundTasks,
    session_id: str,
    user_id,
    uid: str,
    request: SolveRequest,
    message_metadata: dict,
    job_id: str,
) -> SolveResponse:
    """Persist the user's question, create its job and schedule the solve pipeline.

    In order:
      1. insert the user message (with ``message_metadata``) into ``messages``;
      2. insert a ``processing`` row into ``jobs`` under ``job_id``;
      3. schedule ``process_session_job`` on ``background_tasks``;
      4. if the session still carries the default title, replace it with the
         first 50 characters of the question (plus "..." when truncated);
      5. invalidate the user's cached session data.

    Returns:
        ``SolveResponse`` for ``job_id`` with ``status="processing"``.
    """
    user_row = {
        "session_id": session_id,
        "role": "user",
        "type": "text",
        "content": request.text,
        "metadata": message_metadata,
    }
    supabase.table("messages").insert(user_row).execute()
    log_step("db_insert", table="messages", op="user_message", session_id=session_id)

    job_row = {
        "id": job_id,
        "user_id": user_id,
        "session_id": session_id,
        "status": "processing",
        "input_text": request.text,
    }
    supabase.table("jobs").insert(job_row).execute()
    log_step("db_insert", table="jobs", job_id=job_id)

    background_tasks.add_task(process_session_job, job_id, session_id, request, str(user_id))

    # Only a session still named with the default placeholder gets auto-titled.
    title_rows = supabase.table("sessions").select("title").eq("id", session_id).execute().data
    if title_rows and title_rows[0]["title"] == "Bài toán mới":
        question = request.text
        ellipsis = "..." if len(question) > 50 else ""
        supabase.table("sessions").update({"title": question[:50] + ellipsis}).eq(
            "id", session_id
        ).execute()
        log_step("db_update", table="sessions", op="title_from_first_message")

    invalidate_for_user(uid)
    log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id)
    return SolveResponse(job_id=job_id, status="processing")
async def ocr_preview(
    session_id: str,
    user_id=Depends(get_current_user_id),
    file: UploadFile = File(...),
    user_message: str | None = Form(None),
):
    """Run OCR on an uploaded image and merge it with an optional user message.

    Does not insert messages or start a solve job. After the user confirms,
    the client should call POST .../solve with ``text=combined_draft`` (possibly
    edited) and omit ``image_url`` so the image is not OCR'd a second time.

    Returns:
        ``OcrPreviewResponse`` with the raw OCR text, the stripped user message
        and the merged draft from ``build_combined_ocr_preview_draft``.

    Raises:
        HTTPException: 403 if the caller does not own the session, 413 if the
            upload exceeds ``_OCR_PREVIEW_MAX_BYTES``, 400 if it is empty.
    """
    import tempfile

    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_ocr_preview")

    # Read at most one byte past the limit: enough to detect an oversize upload
    # without loading an arbitrarily large file into memory.
    body = await file.read(_OCR_PREVIEW_MAX_BYTES + 1)
    if len(body) > _OCR_PREVIEW_MAX_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"Image too large (max {_OCR_PREVIEW_MAX_BYTES // (1024 * 1024)} MB).",
        )
    if not body:
        raise HTTPException(status_code=400, detail="Empty file.")
    if ocr_celery_enabled():
        validate_chat_image_bytes(file.filename, body, file.content_type)

    # Keep a recognised image extension; anything else (including none) -> .png.
    suffix = os.path.splitext(file.filename or "")[1].lower()
    if suffix not in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"):
        suffix = ".png"

    # Use the system temp dir instead of the process CWD (which may be
    # read-only or shared); mkstemp also guarantees a fresh, unique file.
    fd, temp_path = tempfile.mkstemp(prefix="temp_ocr_preview_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(body)
        ocr_text = await ocr_from_local_image_path(
            temp_path, file.filename, get_orchestrator().ocr_agent
        )
        if ocr_text is None:
            ocr_text = ""
    finally:
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass

    um = (user_message or "").strip()
    combined = build_combined_ocr_preview_draft(user_message, ocr_text)
    log_step("ocr_preview_done", session_id=session_id, ocr_len=len(ocr_text), user_len=len(um))
    return OcrPreviewResponse(
        ocr_text=ocr_text,
        user_message=um,
        combined_draft=combined,
    )
async def solve_problem(
    session_id: str,
    request: SolveRequest,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
):
    """Submit a geometry question within a session.

    Verifies that the caller owns the session, stores the question in the chat
    history and schedules the solve pipeline. That pipeline only solves the
    problem and produces a static figure; video is requested via render_video.
    """
    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check")

    metadata = {}
    if request.image_url:
        metadata["image_url"] = request.image_url

    return _enqueue_solve_common(
        supabase,
        background_tasks,
        session_id,
        user_id,
        uid,
        request,
        message_metadata=metadata,
        job_id=str(uuid.uuid4()),
    )
async def solve_multipart(
    session_id: str,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
    text: str = Form(...),
    file: UploadFile = File(...),
):
    """Submit question text plus an image in a single multipart request.

    Validates the image and uploads it via ``upload_session_chat_image``. It then
    stores the user message with attachment metadata (URL, storage path, size,
    content type, original filename, asset id) and enqueues a solve. The solve
    request's ``image_url`` is the uploaded public URL, so the pipeline
    receives the image as well.

    Raises:
        HTTPException: 403 if the caller does not own the session, 400 if the
            stripped text is empty.
    """
    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_solve_multipart")

    question = (text or "").strip()
    if not question:
        raise HTTPException(status_code=400, detail="text must not be empty.")

    raw = await file.read()
    ext, content_type = validate_chat_image_bytes(file.filename, raw, file.content_type)
    job_id = str(uuid.uuid4())
    uploaded = upload_session_chat_image(session_id, job_id, raw, ext, content_type)
    image_url = uploaded["public_url"]

    attachment = {
        "public_url": image_url,
        "storage_path": uploaded["storage_path"],
        "size_bytes": len(raw),
        "content_type": content_type,
        "original_filename": file.filename or "",
        "session_asset_id": uploaded.get("session_asset_id"),
    }
    return _enqueue_solve_common(
        supabase,
        background_tasks,
        session_id,
        user_id,
        uid,
        SolveRequest(text=question, image_url=image_url),
        {"image_url": image_url, "attachment": attachment},
        job_id,
    )
async def render_video(
    session_id: str,
    request: RenderVideoRequest,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
):
    """Request a Manim video built from the session's most recent geometry state.

    Scans the 10 newest assistant messages (newest first) for metadata that has
    both ``geometry_dsl`` and ``coordinates``. When ``request.job_id`` is set,
    only the message whose metadata carries that ``job_id`` qualifies. Then
    inserts a ``rendering_queued`` job row and schedules ``process_render_job``.

    Raises:
        HTTPException: 403 if the caller does not own the session; 404 if no
            suitable geometry metadata is found.
    """
    supabase = get_supabase()
    # 1. Ownership check — same shared helper (and same 403) as the other endpoints.
    _assert_session_owner(
        supabase, session_id, user_id, str(user_id), "owner_check_render_video"
    )

    # 2. Find the assistant message holding geometry metadata (the requested
    #    job_id, or else the newest one among the last 10 assistant messages).
    msg_res = (
        supabase.table("messages")
        .select("metadata")
        .eq("session_id", session_id)
        .eq("role", "assistant")
        .order("created_at", desc=True)
        .limit(10)
        .execute()
    )
    latest_geometry = None
    for msg in msg_res.data or []:
        # A NULL metadata column comes back as None; dict.get's default only
        # applies to a *missing* key, so normalize before calling .get().
        meta = msg.get("metadata") or {}
        if not isinstance(meta, dict):
            continue
        # If a specific job was requested, only its message qualifies.
        if request.job_id and meta.get("job_id") != request.job_id:
            continue
        # Geometry data is required to render anything.
        if meta.get("geometry_dsl") and meta.get("coordinates"):
            latest_geometry = meta
            break
    if latest_geometry is None:
        raise HTTPException(status_code=404, detail="Không tìm thấy dữ liệu hình học để render video.")

    # 3. Create the rendering job row.
    job_id = str(uuid.uuid4())
    supabase.table("jobs").insert({
        "id": job_id,
        "user_id": user_id,
        "session_id": session_id,
        "status": "rendering_queued",
        "input_text": f"Render video requested at {job_id}",
    }).execute()

    # 4. Dispatch the background task.
    background_tasks.add_task(process_render_job, job_id, session_id, latest_geometry)
    return RenderVideoResponse(job_id=job_id, status="rendering_queued")
async def process_session_job(
    job_id: str, session_id: str, request: SolveRequest, user_id: str
):
    """Background solve pipeline for one job: solve, build the static figure, persist.

    Loads the session's full message history (oldest first) and runs the
    orchestrator on ``request.text`` / ``request.image_url``. Intermediate
    statuses are forwarded through ``notify_status``. The result is then stored
    on the ``jobs`` row, and an assistant message is appended whose metadata
    (``job_id``, ``geometry_dsl``, ``coordinates``, ...) ``render_video`` later
    reads.

    On failure the job is marked ``error``, a user-facing error message is
    appended and an ``error`` status is pushed. The final notification is sent
    only after persistence has completed, so a failed notification can no
    longer rewrite an already-persisted successful job as an error.

    ``user_id`` is part of the scheduled call's signature but is not used here.
    """
    from app.websocket_manager import notify_status

    async def status_update(status: str):
        await notify_status(job_id, {"status": status, "job_id": job_id})

    async def notify_final(payload: dict) -> None:
        # Last step of a background task: there is no caller to propagate to,
        # so record the failure instead of letting it escape.
        try:
            await notify_status(job_id, payload)
        except Exception:
            logger.exception("Failed to send final status for session job %s", job_id)

    supabase = get_supabase()
    try:
        history_res = (
            supabase.table("messages")
            .select("*")
            .eq("session_id", session_id)
            .order("created_at", desc=False)
            .execute()
        )
        history = history_res.data or []
        result = await get_orchestrator().run(
            request.text,
            request.image_url,
            job_id=job_id,
            session_id=session_id,
            status_callback=status_update,
            history=history,
        )
        failed = "error" in result
        status = "error" if failed else result.get("status", "error")
        supabase.table("jobs").update({"status": status, "result": result}).eq(
            "id", job_id
        ).execute()
        supabase.table("messages").insert(
            {
                "session_id": session_id,
                "role": "assistant",
                "type": "error" if failed else "analysis",
                "content": (
                    result["error"]
                    if failed
                    else result.get("semantic_analysis", "Đã có lỗi xảy ra.")
                ),
                "metadata": {
                    "job_id": job_id,
                    "coordinates": result.get("coordinates"),
                    "geometry_dsl": result.get("geometry_dsl"),
                    "polygon_order": result.get("polygon_order", []),
                    "drawing_phases": result.get("drawing_phases", []),
                    "circles": result.get("circles", []),
                    "lines": result.get("lines", []),
                    "rays": result.get("rays", []),
                    "solution": result.get("solution"),
                    "is_3d": result.get("is_3d", False),
                },
            }
        ).execute()
    except Exception as e:
        logger.exception("Error processing session job %s", job_id)
        error_msg = format_error_for_user(e)
        # Best-effort persistence: if the database is what failed, still push
        # the error status below so the client is not left waiting.
        try:
            supabase = get_supabase()
            supabase.table("jobs").update(
                {"status": "error", "result": {"error": str(e)}}
            ).eq("id", job_id).execute()
            supabase.table("messages").insert(
                {
                    "session_id": session_id,
                    "role": "assistant",
                    "type": "error",
                    "content": error_msg,
                    "metadata": {"job_id": job_id},
                }
            ).execute()
        except Exception:
            logger.exception("Failed to persist error state for session job %s", job_id)
        await notify_final({"status": "error", "job_id": job_id, "error": error_msg})
        return

    await notify_final({"status": status, "job_id": job_id, "result": result})
async def process_render_job(job_id: str, session_id: str, geometry_data: dict):
    """Dispatch a Celery video-render task built from stored geometry metadata.

    Pushes a ``rendering_queued`` status, assembles the render payload from
    ``geometry_data`` (with defaults for missing keys) and enqueues
    ``worker.tasks.render_geometry_video``. If the worker task cannot be
    imported or dispatched, the job row is marked ``error`` and an ``error``
    status is pushed, so the job is not left ``rendering_queued`` forever.
    """
    from app.websocket_manager import notify_status

    await notify_status(job_id, {"status": "rendering_queued", "job_id": job_id})

    # Prepare payload for Celery (similar to what orchestrator used to do)
    result_payload = {
        "geometry_dsl": geometry_data.get("geometry_dsl"),
        "coordinates": geometry_data.get("coordinates"),
        "polygon_order": geometry_data.get("polygon_order", []),
        "drawing_phases": geometry_data.get("drawing_phases", []),
        "circles": geometry_data.get("circles", []),
        "lines": geometry_data.get("lines", []),
        "rays": geometry_data.get("rays", []),
        "semantic": geometry_data.get("semantic", {}),
        "semantic_analysis": geometry_data.get("semantic_analysis", "🎬 Video minh họa dựng từ trạng thái gần nhất."),
        "session_id": session_id,
    }
    try:
        # Imported inside the try so a broken or unconfigured worker module is
        # reported as a dispatch failure instead of escaping silently.
        from worker.tasks import render_geometry_video

        logger.info("[RenderJob] Attempting to dispatch Celery task for job %s...", job_id)
        render_geometry_video.delay(job_id, result_payload)
    except Exception as e:
        logger.exception("[RenderJob] FAILED to dispatch Celery task: %s", e)
        # Best-effort DB update; the client notification below must still go out.
        try:
            supabase = get_supabase()
            supabase.table("jobs").update(
                {"status": "error", "result": {"error": f"Task dispatch failed: {e}"}}
            ).eq("id", job_id).execute()
        except Exception:
            logger.exception("[RenderJob] Failed to mark job %s as error", job_id)
        await notify_status(job_id, {"status": "error", "job_id": job_id, "error": str(e)})
    else:
        logger.info("[RenderJob] SUCCESS: Dispatched Celery task for job %s", job_id)