Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import shutil | |
| import uuid | |
| import time | |
| import asyncio | |
| from pathlib import Path | |
| from contextlib import asynccontextmanager | |
| from typing import Annotated, Optional | |
| import torch | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, File, HTTPException, UploadFile, status | |
| from pydantic import BaseModel, Field | |
| from src.g3_batch_prediction import G3BatchPredictor | |
| from src.utils import load_images_as_base64 | |
ENV = os.getenv("ENV")
cred_json = os.getenv("GOOGLE_CREDENTIALS_JSON")


def write_google_credentials(
    raw_json: str, file_path: str = "google-credentials.json"
) -> str:
    """Validate *raw_json*, write it to *file_path* and point Google auth at it.

    Sets ``GOOGLE_APPLICATION_CREDENTIALS`` to the absolute path of the written
    file, so the setting stays valid even if the working directory changes.

    Returns:
        The absolute path of the credentials file.

    Raises:
        json.JSONDecodeError: if *raw_json* is not valid JSON (nothing is written).
        OSError: if the file cannot be created.
    """
    # Parse first so an invalid secret never reaches the filesystem.
    json.loads(raw_json)
    # The file holds a secret: newly created files get owner-only permissions
    # instead of the process umask default (often world-readable).
    fd = os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(raw_json)
    abs_path = os.path.abspath(file_path)
    # Google auth libraries pick up the credentials from this variable.
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = abs_path
    return abs_path


if ENV == "hf":
    if cred_json:
        try:
            saved_path = write_google_credentials(cred_json)
            print("[INFO] Google credentials saved to", saved_path)
        except json.JSONDecodeError:
            print("[ERROR] GOOGLE_CREDENTIALS_JSON is not valid JSON")
    else:
        print("[ERROR] GOOGLE_CREDENTIALS_JSON is missing")
else:
    # DEV mode (local)
    print("[INFO] ENV != hf → skip Google credentials setup")
class EvidenceResponse(BaseModel):
    """One piece of supporting evidence attached to a location prediction."""

    # Free-text reasoning that backs the prediction (required).
    analysis: str = Field(description="A supporting analysis for the prediction.")
    # Supporting material; an empty list when the caller supplies none.
    references: list[str] = Field(
        default=[],
        description="Links or base64-encoded JPEG supporting the analysis.",
    )
class LocationPredictionResponse(BaseModel):
    """Predicted location: coordinates, a place description and its evidence."""

    # Geographic coordinates, both expressed in degrees.
    latitude: float = Field(
        description="Latitude of the predicted location, in degree.",
    )
    longitude: float = Field(
        description="Longitude of the predicted location, in degree.",
    )
    # Human-readable name/description of the place.
    location: str = Field(
        description="Textual description of the predicted location.",
    )
    # Analyses supporting the prediction (required, may be empty).
    evidence: list[EvidenceResponse] = Field(
        description="List of supporting analyses for the prediction.",
    )
class PredictionResponse(BaseModel):
    """Result payload of a finished job: the prediction plus optional extras."""

    # Coordinates, place description and supporting evidence.
    prediction: LocationPredictionResponse = Field(
        description="The location prediction and accompanying analysis.",
    )
    # Concatenated transcripts; None when none were produced.
    transcript: str | None = Field(
        default=None,
        description="The extracted and concatenated transcripts, if any.",
    )
    # Media strings returned alongside the prediction; None when absent.
    media: list[str] | None = Field(
        default=None,
        description="List of media files processed during prediction.",
    )
class JobStatus(BaseModel):
    """Public view of one background prediction job."""

    # Hex id generated with uuid4 when the job is submitted.
    job_id: str
    # Lifecycle state; this file sets "queued", "running", "succeeded" or "failed".
    status: str
    # Failure reason for a "failed" job; None otherwise.
    message: Optional[str] = Field(default=None)
    # Final payload once the job has succeeded.
    result: Optional[PredictionResponse] = Field(default=None)
    # Wall-clock timestamps (time.time()) of submission and of the last change.
    created_at: float
    updated_at: float
# Shared predictor instance; assigned in ``lifespan`` at startup, deleted on shutdown.
predictor: G3BatchPredictor

# Upper bound on concurrently admitted jobs (env MAX_CONCURRENT, default 10).
# Clamped to at least 1: asyncio.Semaphore(0) would never admit a job, leaving
# every submission stuck in the queue forever, and a negative value makes
# Semaphore raise ValueError at import time.
MAX_CONCURRENT = max(1, int(os.getenv("MAX_CONCURRENT", "10")))

# In-memory job registry keyed by job id; every access goes through jobs_lock.
# NOTE(review): nothing in this file evicts entries, so finished jobs (including
# their media payloads) accumulate for the lifetime of the process.
jobs: dict[str, dict] = {}
jobs_lock = asyncio.Lock()
worker_sem = asyncio.Semaphore(MAX_CONCURRENT)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook for the FastAPI app.

    On startup: loads environment variables via ``load_dotenv()``, writes the
    app's OpenAPI schema to ``openapi.json`` and builds the shared
    ``G3BatchPredictor`` (on CUDA when ``torch`` reports it available, else CPU).
    On shutdown the ``predictor`` global is deleted.

    NOTE(review): ``load_dotenv()`` runs here, after the module-level credentials
    setup has already read ``ENV``/``GOOGLE_CREDENTIALS_JSON``, so ``.env`` values
    do not influence that setup.
    """
    global predictor
    load_dotenv()
    with open("openapi.json", "wt", encoding="utf-8") as api_file:
        json.dump(app.openapi(), api_file, indent=4)
    predictor = G3BatchPredictor(device="cuda" if torch.cuda.is_available() else "cpu")
    try:
        yield
    finally:
        # Release the predictor even if the app's lifetime ends with an exception.
        del predictor
_APP_DESCRIPTION = (
    "An endpoint to predict GPS coordinate from static image, using G3 Framework."
)

# The FastAPI application; ``lifespan`` manages the shared predictor.
app = FastAPI(
    title="G3",
    description=_APP_DESCRIPTION,
    lifespan=lifespan,
)
| async def _update_job(job_id: str, **fields) -> dict: | |
| async with jobs_lock: | |
| job = jobs[job_id] | |
| job.update(fields) | |
| job["updated_at"] = time.time() | |
| return job.copy() | |
| async def _get_job(job_id: str) -> dict | None: | |
| async with jobs_lock: | |
| job = jobs.get(job_id) | |
| return None if job is None else job.copy() | |
# Every job stages its inputs in the predictor's single shared ``input_dir``
# (after ``clear_directories()``), and the transcript/media are fetched from
# shared state afterwards (``predictor.get_transcript()`` and the argument-less
# ``load_images_as_base64()``). Two jobs interleaving at an ``await`` would
# therefore clobber each other's inputs and results; this lock gives one job
# at a time exclusive use of the predictor.
_predictor_lock = asyncio.Lock()


def _stage_inputs(job_dir: Path, input_dir: Path) -> None:
    """Copy every regular file from *job_dir* into *input_dir* (blocking I/O)."""
    os.makedirs(input_dir, exist_ok=True)
    for file_path in job_dir.iterdir():
        if file_path.is_file():
            shutil.copy(file_path, input_dir / file_path.name)


async def _run_job(job_id: str, job_dir: Path) -> None:
    """Execute one submitted prediction job and record the outcome in ``jobs``.

    The job stays "queued" until it holds a ``worker_sem`` slot and the predictor
    lock. It is then marked "running", its uploaded files are staged into the
    predictor's input directory and ``predictor.predict`` runs with
    ``gemini-2.5-pro``. The job ends as "succeeded" (with ``result``) or
    "failed" (with ``message``). *job_dir* is removed in every case.
    """
    try:
        async with worker_sem, _predictor_lock:
            try:
                # Only now does the job actually run; before this it was queued.
                await _update_job(job_id, status="running", message=None)
                predictor.clear_directories()
                # Copying uploads (possibly large videos) is blocking disk I/O;
                # do it in a worker thread so status polling stays responsive.
                await asyncio.to_thread(
                    _stage_inputs, job_dir, Path(predictor.input_dir)
                )
                response = await predictor.predict(model_name="gemini-2.5-pro")
                prediction = LocationPredictionResponse(
                    latitude=response.latitude,
                    longitude=response.longitude,
                    location=response.location,
                    evidence=[
                        EvidenceResponse(analysis=ev.analysis, references=ev.references)
                        for ev in response.evidence
                    ],
                )
                transcript = predictor.get_transcript()
                images_b64 = load_images_as_base64()
                result = PredictionResponse(
                    prediction=prediction,
                    transcript=transcript,
                    media=images_b64,
                )
                await _update_job(job_id, status="succeeded", result=result)
            except Exception as exc:
                # Some exceptions stringify to ""; fall back to the type name so
                # a failed job always carries a non-empty reason.
                await _update_job(
                    job_id, status="failed", message=str(exc) or type(exc).__name__
                )
    finally:
        shutil.rmtree(job_dir, ignore_errors=True)
async def predict_endpoint(
    files: Annotated[
        list[UploadFile],
        File(description="Input images, videos and metadata json."),
    ],
) -> JobStatus:
    """Save the uploads under ``uploads/<job_id>`` and enqueue a prediction job.

    Returns:
        The freshly registered job record (status "queued").

    Raises:
        HTTPException: 400 if no files were sent; 500 if the uploads could not
            be saved or the job could not be registered.
    """
    if not files:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No files provided")

    job_id = uuid.uuid4().hex
    job_dir = Path("uploads") / job_id
    os.makedirs(job_dir, exist_ok=True)
    try:
        for file in files:
            # Filenames are client-controlled: keep only the final component so a
            # name like "../../app.py" cannot write outside job_dir.
            filepath = job_dir / _safe_upload_name(file.filename)
            with open(filepath, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
        now = time.time()
        async with jobs_lock:
            jobs[job_id] = {
                "job_id": job_id,
                "status": "queued",
                "message": None,
                "result": None,
                "created_at": now,
                "updated_at": now,
            }
    except Exception as e:
        # Nothing has been scheduled yet, so discarding the partial upload is safe.
        shutil.rmtree(job_dir, ignore_errors=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to enqueue job: {e}",
        ) from e

    # Scheduled outside the try above so a later failure can never delete
    # job_dir from under a job that is already running.
    task = asyncio.create_task(_run_job(job_id, job_dir))
    # asyncio keeps only weak references to tasks; hold a strong one until the
    # job finishes so it cannot be garbage-collected mid-run.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    job = await _get_job(job_id)
    return job  # type: ignore[return-value]


# Strong references to in-flight job tasks created by predict_endpoint.
_background_tasks: set[asyncio.Task] = set()


def _safe_upload_name(filename: str | None) -> str:
    """Reduce a client-supplied upload name to a bare file name.

    Directory components (``/`` or ``\\``) are dropped. Missing, empty, ``.``
    and ``..`` names are replaced by a random hex name.
    """
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        return uuid.uuid4().hex
    return name
async def get_job_status(job_id: str) -> JobStatus:
    """Return the current snapshot of job *job_id*.

    Raises:
        HTTPException: 404 when no job with that id is registered.
    """
    snapshot = await _get_job(job_id)
    if snapshot is not None:
        return snapshot  # type: ignore[return-value]
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
async def openapi():
    """Return the application's generated OpenAPI schema (``app.openapi()``)."""
    schema = app.openapi()
    return schema