File size: 4,465 Bytes
395651c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations

import logging
import os
import time
import uuid
import warnings

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from starlette.requests import Request

# Load variables from a .env file before any project module is imported, so
# modules that read the environment at import time presumably see them.
load_dotenv()

from app.runtime_env import apply_runtime_env_defaults

# Project helper (see app.runtime_env). It is called before the remaining
# project imports so that whatever defaults it sets are already in place.
apply_runtime_env_defaults()

# Set before agents.ocr_agent is imported below. This presumably disables
# albumentations' update check — confirm against the albumentations docs.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
# Silence UserWarnings emitted from the pydantic and albumentations modules.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")

from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging

# Configure logging before importing the routers/agents below. This is
# presumably so that anything they log at import time uses the configured setup.
setup_application_logging()

# Routers and other project modules (intentionally imported after logging setup)
from app.dependencies import get_current_user_id
from app.ocr_local_file import ocr_from_local_image_path
from app.routers import auth, sessions, solve
from agents.ocr_agent import OCRAgent
from app.routers.solve import get_orchestrator
from app.job_poll import normalize_job_row_for_client
from app.supabase_client import get_supabase
from app.websocket_manager import register_websocket_routes

logger = logging.getLogger("app.main")
# Dedicated logger for per-request access lines (used by access_log_middleware).
_access = logging.getLogger(ACCESS_LOGGER_NAME)

app = FastAPI(title="Visual Math Solver API v5.1")


def _log_access(request: Request, status: int, ms: float) -> None:
    """Emit one access-log record for *request*, filtered by the current LOG_LEVEL.

    - ``debug`` / ``info``: every request at INFO, with timing.
    - ``warning``: only 4xx (WARNING) and 5xx (ERROR), with timing.
    - ``error``: only 4xx/5xx, all at ERROR, without timing.
    """
    mode = get_log_level()
    method = request.method
    path = request.url.path

    if mode in ("debug", "info"):
        _access.info("%s %s -> %s (%.0fms)", method, path, status, ms)
    elif mode == "warning":
        if status >= 500:
            _access.error("%s %s -> %s (%.0fms)", method, path, status, ms)
        elif status >= 400:
            _access.warning("%s %s -> %s (%.0fms)", method, path, status, ms)
    elif mode == "error":
        if status >= 400:
            _access.error("%s %s -> %s", method, path, status)


@app.middleware("http")
async def access_log_middleware(request: Request, call_next):
    """Write one access-log line per HTTP request, according to LOG_LEVEL.

    ``debug``/``info`` log every request. ``warning`` logs only 4xx/5xx
    (5xx at ERROR). ``error`` logs only 4xx/5xx, at ERROR.

    If a downstream handler raises instead of returning a response, the request
    is logged as a 500 and the exception is re-raised unchanged. Without this,
    exactly the failing requests would be missing from the access log.
    """
    start = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        # No Response object exists. Log the request as a 500 (the status the
        # server error handler will presumably send), then let it propagate.
        _log_access(request, 500, (time.perf_counter() - start) * 1000)
        raise
    _log_access(request, response.status_code, (time.perf_counter() - start) * 1000)
    return response


from worker.celery_app import BROKER_URL

# Drop everything up to the last "@" (if any), e.g. scheme and user:password,
# so broker credentials are never written to the log.
_broker_tail = BROKER_URL.rpartition("@")[2]
# Startup banner: INFO in verbose modes, otherwise WARNING. This is presumably
# so it is still emitted when the configured level filters out INFO records.
_emit_startup = logger.info if get_log_level() in ("debug", "info") else logger.warning
_emit_startup("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)

# Origins allowed to make (credentialed) cross-origin requests. The defaults
# cover the local dev frontends. Set CORS_ALLOW_ORIGINS to a comma-separated
# list to override them, e.g. for a deployed frontend. Avoid "*" here, because
# allow_credentials=True is set below.
_DEFAULT_CORS_ORIGINS = (
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "http://localhost:3005",
)
# An unset or blank variable (or one containing only commas/spaces) falls back
# to the defaults.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or list(_DEFAULT_CORS_ORIGINS)

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the HTTP routers in a fixed order: auth, sessions, solve.
for _router_module in (auth, sessions, solve):
    app.include_router(_router_module.router)
del _router_module

# WebSocket endpoints are registered through their own helper.
register_websocket_routes(app)


def get_ocr_agent() -> OCRAgent:
    """Return the OCR agent owned by the shared solve orchestrator.

    Reusing the orchestrator's instance means the OCR model is not loaded a
    second time for the standalone OCR endpoint.
    """
    orchestrator = get_orchestrator()
    return orchestrator.ocr_agent


# Module-level Supabase client, created once when this module is imported.
# It is used by the job-status endpoint further down in this file.
supabase_client = get_supabase()


@app.get("/")
def read_root():
    """Root endpoint: return a static 'API is running' message and the API version."""
    payload = {
        "message": "Visual Math Solver API v5.1 is running",
        "version": "5.1",
    }
    return payload


@app.post("/api/v1/ocr")
async def upload_ocr(
    file: UploadFile = File(...),
    _user_id=Depends(get_current_user_id),
):
    """Run OCR on an uploaded image and return the recognized text.

    Requires an authenticated user. ``_user_id`` is not used in the body; the
    dependency only enforces authentication. The upload is written to a
    uniquely named temp file in the current working directory. That file always
    gets a ``.png`` suffix, whatever the uploaded format. The file is passed to
    the shared OCR agent and is always deleted afterwards.

    Returns:
        ``{"text": <OCR result>}``.
    """
    temp_path = f"temp_{uuid.uuid4()}.png"
    # Read the upload before creating anything on disk, and create the file
    # inside ``try``. That way a failed read or write never leaves a stray temp file.
    contents = await file.read()
    try:
        with open(temp_path, "wb") as buffer:
            buffer.write(contents)
        text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent())
        return {"text": text}
    finally:
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            # open() itself failed, so there is nothing to clean up.
            pass


@app.get("/api/v1/solve/{job_id}")
async def get_job_status(
    job_id: str,
    user_id=Depends(get_current_user_id),
):
    """Return the status of a solve job for polling (fallback when the WebSocket fails).

    Owner-only: a job whose ``user_id`` is set can only be read by that user.
    A job whose ``user_id`` is NULL is returned to any authenticated caller.

    Args:
        job_id: Value matched against the ``id`` column of the ``jobs`` table.
        user_id: Authenticated caller, injected by ``get_current_user_id``.

    Returns:
        The job row normalized by ``normalize_job_row_for_client``.

    Raises:
        HTTPException: 404 if no job has this id; 403 if it belongs to another user.
    """
    # Local import keeps this fix self-contained; starlette is already a dependency.
    from starlette.concurrency import run_in_threadpool

    def _fetch_job_rows():
        # ``execute()`` is synchronous (its result is used without ``await``).
        # Calling it directly here would block the event loop for the whole
        # DB round-trip.
        return supabase_client.table("jobs").select("*").eq("id", job_id).execute()

    response = await run_in_threadpool(_fetch_job_rows)
    if not response.data:
        raise HTTPException(status_code=404, detail="Job not found")
    job = response.data[0]
    owner_id = job.get("user_id")
    if owner_id is not None and str(owner_id) != str(user_id):
        raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.")
    # Stable contract for FE poll (job_id alias, parsed result JSON, string UUIDs)
    return normalize_job_row_for_client(job)