3v324v23 committed on
Commit
ae0724c
·
1 Parent(s): 9887761

update flow polling

Browse files
Files changed (1) hide show
  1. app.py +107 -26
app.py CHANGED
@@ -2,6 +2,9 @@ import json
2
  import os
3
  import shutil
4
  import uuid
 
 
 
5
  from contextlib import asynccontextmanager
6
  from typing import Annotated, Optional
7
 
@@ -88,8 +91,22 @@ class PredictionResponse(BaseModel):
88
  )
89
 
90
 
 
 
 
 
 
 
 
 
 
91
  predictor: G3BatchPredictor
92
 
 
 
 
 
 
93
 
94
  @asynccontextmanager
95
  async def lifespan(app: FastAPI):
@@ -114,51 +131,115 @@ app = FastAPI(
114
  )
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  @app.post(
118
  "/g3/predict",
119
- description="Provide location prediction.",
 
120
  )
121
  async def predict_endpoint(
122
  files: Annotated[
123
  list[UploadFile],
124
  File(description="Input images, videos and metadata json."),
125
  ],
126
- ) -> PredictionResponse:
127
- # Write files to disk
 
 
 
 
 
 
128
  try:
129
- predictor.clear_directories()
130
  for file in files:
131
  filename = file.filename if file.filename is not None else uuid.uuid4().hex
132
- filepath = predictor.input_dir / filename
133
- os.makedirs(predictor.input_dir, exist_ok=True)
134
  with open(filepath, "wb") as buffer:
135
  shutil.copyfileobj(file.file, buffer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  except Exception as e:
 
137
  raise HTTPException(
138
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
139
- detail=f"Failed to save file: {e}",
140
  )
141
 
142
- # Get prediction
143
- response = await predictor.predict(model_name="gemini-2.5-pro")
144
- # response = predictor.get_response(response)
145
- prediction = LocationPredictionResponse(
146
- latitude=response.latitude,
147
- longitude=response.longitude,
148
- location=response.location,
149
- evidence=[
150
- EvidenceResponse(analysis=ev.analysis, references=ev.references)
151
- for ev in response.evidence
152
- ],
153
- )
154
- # Get transcript if available
155
- transcript = predictor.get_transcript()
156
-
157
- # Get media files if available
158
- images_b64 = load_images_as_base64()
159
 
160
- # Clear directories
161
- return PredictionResponse(prediction=prediction, transcript=transcript, media=images_b64)
 
 
 
 
 
 
 
 
162
 
163
 
164
  @app.get(
 
2
  import os
3
  import shutil
4
  import uuid
5
+ import time
6
+ import asyncio
7
+ from pathlib import Path
8
  from contextlib import asynccontextmanager
9
  from typing import Annotated, Optional
10
 
 
91
  )
92
 
93
 
94
class JobStatus(BaseModel):
    """Status/result envelope for an asynchronous prediction job.

    Returned by the submit endpoint (initially ``queued``) and by the
    polling endpoint; ``result`` is populated once the job succeeds.
    """

    # Hex UUID identifying the job.
    job_id: str
    # Lifecycle state: "queued", "running", "succeeded", or "failed".
    status: str
    # Human-readable failure detail when status == "failed".
    message: Optional[str] = None
    # Final prediction payload, set only on success.
    result: Optional[PredictionResponse] = None
    # Unix timestamps (seconds) for creation and last state change.
    created_at: float
    updated_at: float
101
+
102
+
103
  predictor: G3BatchPredictor
104
 
105
+ MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT", "10"))
106
+ jobs: dict[str, dict] = {}
107
+ jobs_lock = asyncio.Lock()
108
+ worker_sem = asyncio.Semaphore(MAX_CONCURRENT)
109
+
110
 
111
  @asynccontextmanager
112
  async def lifespan(app: FastAPI):
 
131
  )
132
 
133
 
134
async def _update_job(job_id: str, **fields) -> dict:
    """Atomically merge *fields* into a job record and bump its timestamp.

    Returns a shallow copy of the updated record. Raises KeyError if the
    job id is unknown.
    """
    async with jobs_lock:
        record = jobs[job_id]
        record.update(fields)
        record["updated_at"] = time.time()
        snapshot = dict(record)
    return snapshot
140
+
141
+
142
async def _get_job(job_id: str) -> dict | None:
    """Return a shallow copy of the job record, or None if the id is unknown."""
    async with jobs_lock:
        record = jobs.get(job_id)
        if record is None:
            return None
        return dict(record)
146
+
147
+
148
async def _run_job(job_id: str, job_dir: Path) -> None:
    """Execute one prediction job end-to-end and record its outcome.

    Stages the job's uploaded files into the predictor's input directory,
    runs the model, and stores either the result or the failure message in
    the in-memory job table. The per-job upload directory is always removed
    when the job finishes, regardless of outcome.
    """
    async with worker_sem:
        # Mark "running" only once a worker slot is actually held; before
        # that the job is still effectively queued. (The original flipped
        # the status before acquiring the semaphore, so queued jobs
        # falsely reported as running.)
        await _update_job(job_id, status="running", message=None)
        try:
            # NOTE(review): predictor.input_dir is shared process-wide, so
            # two jobs running concurrently would clobber each other's
            # inputs via clear_directories(). This is only safe if
            # MAX_CONCURRENT == 1 or the predictor becomes job-scoped —
            # confirm before raising the limit.
            predictor.clear_directories()

            os.makedirs(predictor.input_dir, exist_ok=True)
            for file_path in job_dir.iterdir():
                if file_path.is_file():
                    shutil.copy(file_path, predictor.input_dir / file_path.name)

            response = await predictor.predict(model_name="gemini-2.5-pro")
            prediction = LocationPredictionResponse(
                latitude=response.latitude,
                longitude=response.longitude,
                location=response.location,
                evidence=[
                    EvidenceResponse(analysis=ev.analysis, references=ev.references)
                    for ev in response.evidence
                ],
            )
            result = PredictionResponse(
                prediction=prediction,
                transcript=predictor.get_transcript(),
                media=load_images_as_base64(),
            )
            await _update_job(job_id, status="succeeded", result=result)
        except Exception as e:
            # Record a terminal state so pollers never see a job stuck in
            # "running" after a crash.
            await _update_job(job_id, status="failed", message=str(e))
        finally:
            # Best-effort cleanup of the per-job upload staging directory.
            shutil.rmtree(job_dir, ignore_errors=True)
184
+
185
+
186
# Strong references to in-flight background tasks: the event loop keeps
# only a weak reference to tasks, so without this a running job could be
# garbage-collected before it completes.
_background_tasks: set = set()


@app.post(
    "/g3/predict",
    description="Provide location prediction (async, polling via job_id).",
    response_model=JobStatus,
)
async def predict_endpoint(
    files: Annotated[
        list[UploadFile],
        File(description="Input images, videos and metadata json."),
    ],
) -> JobStatus:
    """Accept uploads, enqueue a prediction job, and return its initial status.

    The caller polls ``GET /g3/predict/{job_id}`` for progress and the
    final result.

    Raises:
        HTTPException 400: no files were uploaded.
        HTTPException 500: the uploads could not be staged or the job
            could not be enqueued.
    """
    if not files:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No files provided")

    job_id = uuid.uuid4().hex
    job_dir = Path("uploads") / job_id
    os.makedirs(job_dir, exist_ok=True)

    try:
        for file in files:
            filename = file.filename if file.filename is not None else uuid.uuid4().hex
            # basename() strips any client-supplied directory components,
            # preventing path traversal out of job_dir; fall back to a
            # fresh UUID when the sanitized name is empty.
            safe_name = os.path.basename(filename) or uuid.uuid4().hex
            filepath = job_dir / safe_name
            with open(filepath, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

        now = time.time()
        async with jobs_lock:
            jobs[job_id] = {
                "job_id": job_id,
                "status": "queued",
                "message": None,
                "result": None,
                "created_at": now,
                "updated_at": now,
            }

        # Keep a strong reference to the task until it finishes.
        task = asyncio.create_task(_run_job(job_id, job_dir))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

        job = await _get_job(job_id)
        return job  # type: ignore[return-value]
    except Exception as e:
        # Remove the staging dir so failed submissions leave no residue.
        shutil.rmtree(job_dir, ignore_errors=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to enqueue job: {e}",
        )
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
@app.get(
    "/g3/predict/{job_id}",
    description="Get prediction job status/result.",
    response_model=JobStatus,
)
async def get_job_status(job_id: str) -> JobStatus:
    """Poll a prediction job by id; responds 404 when the id is unknown."""
    snapshot = await _get_job(job_id)
    if snapshot is not None:
        return snapshot  # type: ignore[return-value]
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Job not found")
243
 
244
 
245
  @app.get(