Spaces:
Sleeping
Sleeping
File size: 4,465 Bytes
395651c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | from __future__ import annotations
import logging
import os
import tempfile
import time
import uuid
import warnings

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
# Populate os.environ from a local .env file (python-dotenv) before any of the
# configuration below reads environment variables.
load_dotenv()
from app.runtime_env import apply_runtime_env_defaults
# Project hook that presumably fills in default environment variables; it runs
# before the imports below so they observe the final environment.
apply_runtime_env_defaults()
# Opt out of albumentations' update check. Set before anything imports
# albumentations, since the flag is presumably read at import time.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
# Suppress UserWarnings attributed to pydantic and albumentations modules.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging
# Configure logging before the application modules are imported below.
setup_application_logging()
# Application imports (routers, agents, clients) — deliberately placed after logging setup.
from app.dependencies import get_current_user_id
from app.ocr_local_file import ocr_from_local_image_path
from app.routers import auth, sessions, solve
from agents.ocr_agent import OCRAgent
from app.routers.solve import get_orchestrator
from app.job_poll import normalize_job_row_for_client
from app.supabase_client import get_supabase
from app.websocket_manager import register_websocket_routes
# Module logger plus the dedicated access-log channel written by the HTTP middleware.
logger, _access = logging.getLogger("app.main"), logging.getLogger(ACCESS_LOGGER_NAME)
app = FastAPI(title="Visual Math Solver API v5.1")
@app.middleware("http")
async def access_log_middleware(request: Request, call_next):
    """Emit one access-log line per HTTP request, filtered by the LOG_LEVEL mode.

    Modes (as returned by ``get_log_level()``):
      * ``debug`` / ``info`` -- every request, at INFO, with latency.
      * ``warning`` -- only 4xx (at WARNING) and 5xx (at ERROR), with latency.
      * ``error`` -- only 4xx/5xx, at ERROR, without latency.

    If the downstream handler raises instead of returning a response, an ERROR
    line is written in every mode and the exception is re-raised unchanged, so
    crashing requests still show up in the access log.
    """
    start = time.perf_counter()
    method = request.method
    path = request.url.path
    try:
        response = await call_next(request)
    except Exception:
        ms = (time.perf_counter() - start) * 1000
        # No exc_info here: the traceback is presumably reported by the
        # server's own error handling once the exception propagates, so
        # logging it again would duplicate it.
        _access.error("%s %s -> unhandled exception (%.0fms)", method, path, ms)
        raise
    ms = (time.perf_counter() - start) * 1000
    mode = get_log_level()
    status = response.status_code
    if mode in ("debug", "info"):
        _access.info("%s %s -> %s (%.0fms)", method, path, status, ms)
    elif mode == "warning":
        if status >= 500:
            _access.error("%s %s -> %s (%.0fms)", method, path, status, ms)
        elif status >= 400:
            _access.warning("%s %s -> %s (%.0fms)", method, path, status, ms)
    elif mode == "error":
        if status >= 400:
            _access.error("%s %s -> %s", method, path, status)
    return response
from worker.celery_app import BROKER_URL
# Keep only what follows the last "@" so any "user:password@" prefix of the
# broker URL is dropped from the startup log (a URL without "@" is kept whole).
_broker_tail = BROKER_URL.rsplit("@", 1)[-1]
# Verbose modes announce startup at INFO; quieter modes bump it to WARNING so
# the banner is still visible.
_log_startup = logger.info if get_log_level() in ("debug", "info") else logger.warning
_log_startup("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)
del _log_startup
def _cors_allow_origins() -> list[str]:
    """Return the browser origins allowed to call this API via CORS.

    Reads an optional comma-separated list from the ``CORS_ALLOW_ORIGINS``
    environment variable (e.g. ``https://app.example.com,http://localhost:3000``).
    When the variable is unset or has no non-blank entries, the built-in
    localhost front-end origins are used, matching the previous hard-coded list.
    """
    default_origins = [
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:3005",
    ]
    raw = os.getenv("CORS_ALLOW_ORIGINS", "")
    # Browsers send Origin without a trailing slash; normalize so a
    # copy-pasted "https://host/" entry still matches.
    origins = [item.strip().rstrip("/") for item in raw.split(",") if item.strip()]
    return origins or default_origins


app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins(),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the REST routers (auth, sessions, solve — in that order), then the
# WebSocket endpoints.
for _router_module in (auth, sessions, solve):
    app.include_router(_router_module.router)
del _router_module
register_websocket_routes(app)
def get_ocr_agent() -> OCRAgent:
    """Return the OCR agent owned by the shared solve orchestrator.

    Reusing the orchestrator's instance avoids loading the OCR model a second time.
    """
    orchestrator = get_orchestrator()
    return orchestrator.ocr_agent


# Supabase client created once at import time and shared by the route handlers.
supabase_client = get_supabase()
@app.get("/")
def read_root():
    """Root endpoint: confirm the API is running and report its version."""
    payload = dict(message="Visual Math Solver API v5.1 is running", version="5.1")
    return payload
@app.post("/api/v1/ocr")
async def upload_ocr(
    file: UploadFile = File(...),
    _user_id=Depends(get_current_user_id),
):
    """Run OCR on an uploaded image and return the recognized text.

    Requires an authenticated user (enforced by the ``get_current_user_id``
    dependency). The upload is written to a temporary file because
    ``ocr_from_local_image_path`` takes a filesystem path; that file is always
    deleted afterwards, whether OCR succeeds or fails.

    Returns:
        ``{"text": <OCR output>}``.
    """
    # mkstemp atomically creates a uniquely named file (owner-only permissions)
    # in the system temp dir rather than the process CWD. It is created before
    # the try so the finally below can clean it up even when reading or
    # writing the upload fails.
    fd, temp_path = tempfile.mkstemp(prefix="temp_", suffix=".png")
    try:
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await file.read())
        text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent())
        return {"text": text}
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
@app.get("/api/v1/solve/{job_id}")
async def get_job_status(
    job_id: str,
    user_id=Depends(get_current_user_id),
):
    """Return a solve job's current state to its owner (polling fallback when WS fails).

    Raises:
        HTTPException: 404 if no job row has ``id == job_id``; 403 if the row
            belongs to a different user.
    """

    def _fetch_job():
        return supabase_client.table("jobs").select("*").eq("id", job_id).execute()

    # The Supabase query is synchronous (``execute()`` is not awaited), so run
    # it in the worker thread pool instead of blocking the event loop.
    response = await run_in_threadpool(_fetch_job)
    if not response.data:
        raise HTTPException(status_code=404, detail="Job not found")
    job = response.data[0]
    # NOTE(review): rows whose user_id is NULL are readable by any authenticated
    # user; confirm this is intended for ownerless jobs.
    if job.get("user_id") is not None and str(job["user_id"]) != str(user_id):
        raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.")
    # Stable contract for FE poll (job_id alias, parsed result JSON, string UUIDs)
    return normalize_job_row_for_client(job)
|