Fred808 commited on
Commit
1c54af5
·
verified ·
1 Parent(s): 0ed5299

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +855 -0
app.py ADDED
@@ -0,0 +1,855 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import asyncio
5
+ import aiohttp
6
+ from typing import Dict, List, Set, Optional
7
+ from urllib.parse import quote, urljoin
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from datasets import Dataset, DatasetDict
11
+ import huggingface_hub
12
+
13
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, status
14
+ from fastapi.responses import JSONResponse
15
+ from pydantic import BaseModel, Field
16
+ import uvicorn
17
+ import aiohttp
18
+
19
# Path for storing caption data
CAPTIONS_DIR = Path("captions_data")
CAPTIONS_DIR.mkdir(exist_ok=True)  # idempotent; created at import time

# Hugging Face configuration
HF_TOKEN = os.getenv("HF_TOKEN")  # write token used for dataset pushes
HF_DATASET_ID = os.getenv("HF_DATASET_ID", "fred808/helium")  # target dataset repo

# Fail fast at import time: uploads are impossible without a token.
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")
+
30
def get_caption_file_path(course: str) -> Path:
    """Map a course name to its on-disk captions JSON file."""
    # Percent-encode everything (safe='') so slashes etc. can't escape the dir.
    encoded = quote(course, safe='')
    return CAPTIONS_DIR / f"{encoded}_captions.json"
+
35
def save_captions_to_file(course: str, captions: List[Dict]) -> None:
    """Persist the caption list for *course* as pretty-printed JSON.

    Errors are reported to stdout and swallowed (best-effort persistence).
    """
    try:
        target = get_caption_file_path(course)
        with target.open('w', encoding='utf-8') as fh:
            json.dump(captions, fh, indent=2, ensure_ascii=False)
        print(f"✓ Saved {len(captions)} captions for {course}")
    except Exception as exc:
        print(f"Error saving captions for {course}: {exc}")
+
45
def load_captions_from_file(course: str) -> List[Dict]:
    """Load previously saved captions for *course*; [] when none exist.

    Read/parse errors are reported to stdout and treated as "no captions".
    """
    try:
        source = get_caption_file_path(course)
        if source.exists():
            with source.open('r', encoding='utf-8') as fh:
                stored = json.load(fh)
            print(f"✓ Loaded {len(stored)} existing captions for {course}")
            return stored
    except Exception as exc:
        print(f"Error loading captions for {course}: {exc}")
    return []
+
58
# Configuration
# Upstream server hosting course folders and their extracted frame images.
SOURCE_SERVER = "https://samelias1-vs2.hf.space"
# Pool of remote captioning endpoints; each exposes /analyze (used here)
# and /health (queried by get_model_info to learn the loaded model).
CAPTION_SERVERS = [
    "https://fred808-pil-4-1.hf.space/analyze",
    "https://fred808-pil-4-2.hf.space/analyze",
    "https://fred808-pil-4-3.hf.space/analyze",
    "https://fred1012-fred1012-gw0j2h.hf.space/analyze",
    "https://fred1012-fred1012-wqs6c2.hf.space/analyze",
    "https://fred1012-fred1012-oncray.hf.space/analyze",
    "https://fred1012-fred1012-4goge7.hf.space/analyze",
    "https://fred1012-fred1012-z0eh7m.hf.space/analyze",
    "https://fred1012-fred1012-u95rte.hf.space/analyze",
    "https://fred1012-fred1012-igje22.hf.space/analyze",
    "https://fred1012-fred1012-ibkuf8.hf.space/analyze",
    "https://fred1012-fred1012-nwqthy.hf.space/analyze",
    "https://fred1012-fred1012-4ldqj4.hf.space/analyze",
    "https://fred1012-fred1012-pivlzg.hf.space/analyze",
    "https://fred1012-fred1012-ptlc5u.hf.space/analyze",
    "https://fred1012-fred1012-u7lh57.hf.space/analyze",
    "https://fred1012-fred1012-q8djv1.hf.space/analyze",
    "https://fredalone-fredalone-ozugrp.hf.space/analyze",
    "https://fredalone-fredalone-9brxj2.hf.space/analyze",
    "https://fredalone-fredalone-p8vq9a.hf.space/analyze",
    "https://fredalone-fredalone-vbli2y.hf.space/analyze",
    "https://fredalone-fredalone-uggger.hf.space/analyze",
    "https://fredalone-fredalone-nmi7e8.hf.space/analyze",
    "https://fredalone-fredalone-d1f26d.hf.space/analyze",
    "https://fredalone-fredalone-461jp2.hf.space/analyze",
    "https://fredalone-fredalone-3enfg4.hf.space/analyze",
    "https://fredalone-fredalone-dqdbpv.hf.space/analyze",
    "https://fredalone-fredalone-ivtjua.hf.space/analyze",
    "https://fredalone-fredalone-6bezt2.hf.space/analyze",
    "https://fredalone-fredalone-e0wfnk.hf.space/analyze",
    "https://fredalone-fredalone-zu2t7j.hf.space/analyze",
    "https://fredalone-fredalone-dqtv1o.hf.space/analyze",
    "https://fredalone-fredalone-wclyog.hf.space/analyze",
    "https://fredalone-fredalone-t27vig.hf.space/analyze",
    "https://fredalone-fredalone-gahbxh.hf.space/analyze",
    "https://fredalone-fredalone-kw2po4.hf.space/analyze",
    "https://fredalone-fredalone-8h285h.hf.space/analyze"
]
# Only servers whose /health reports this model are used (substring match
# in processing_loop).
MODEL_TYPE = "Florence-2-large"  # Explicitly request large model
+
101
# FastAPI Models
class CourseInfo(BaseModel):
    """A single course entry as reported by the source server."""
    # Name of the course folder on the source server.
    course_folder: str
+
105
class ImageInfo(BaseModel):
    """A single frame image within a course folder."""
    # Image file name (e.g. "frame_0001.png"), relative to the course folder.
    filename: str
+
108
class CaptionRequest(BaseModel):
    """Request payload for a caption server's /analyze endpoint."""
    # Publicly reachable URL of the image to caption.
    image_url: str
    # Captioning model to request; defaults to the required large model.
    model_choice: str = MODEL_TYPE
+
112
class CaptionResponse(BaseModel):
    """Response shape returned by a caption server."""
    # True when the server produced a caption.
    success: bool
    caption: Optional[str] = None  # set on success
    error: Optional[str] = None  # set on failure
+
117
class ServerStatus(BaseModel):
    """Snapshot of one caption server's state, served by /servers/status."""
    # The server's /analyze endpoint URL.
    url: str
    # Model name reported by the server's /health (or "unknown").
    model: str
    # True while a request is in flight on this server.
    busy: bool
    # Number of images this server has captioned successfully.
    total_processed: int
    # Cumulative seconds spent on requests to this server.
    total_time: float
    # Average throughput (total_processed / total_time).
    fps: float
+
125
class ProcessingStatus(BaseModel):
    """Per-course progress summary shape (see /processing/status)."""
    # Course folder name.
    course: str
    # Total images tracked for the course.
    total_images: int
    # Images captioned so far.
    processed_images: int
    # processed_images / total_images * 100.
    progress_percent: float
    # "completed" or "processing".
    status: str
+
132
class StartProcessingRequest(BaseModel):
    """Body for POST /processing/start."""
    courses: Optional[List[str]] = None  # If None, process all courses
    continuous: bool = True  # Default to continuous like original
+
136
# FastAPI App
# Single application instance; routes are registered on it below.
app = FastAPI(
    title="Caption Coordinator API",
    description="Distributed caption processing coordinator",
    version="1.0.0"
)
+
143
# Global state
# All of this is in-process, single-event-loop state (no locking needed as
# long as everything runs on the one asyncio loop).
processed_images: Dict[str, Set[str]] = {}  # {course: set(image_names)}
course_captions: Dict[str, List[Dict]] = {}  # {course: [{image, caption, metadata}]}
failed_images: Dict[str, Set[str]] = {}  # {course: set(image_names)}
servers = []  # CaptionServer pool; populated by initialize_servers()
is_processing = False  # flag the background loop checks to keep running
current_processing_task = None  # asyncio.Task of the active processing loop
auto_start_processing = True  # Set to False if you don't want auto-start

# Map of course -> vs2 callback URL
# "*" is used as a wildcard key when VS2 registers without naming a course.
pending_vs2_callbacks: Dict[str, str] = {}
+
155
class CaptionServer:
    """Tracks one remote captioning endpoint and its throughput stats."""

    def __init__(self, url):
        self.url = url              # the server's /analyze endpoint
        self.busy = False           # True while a request is in flight
        self.model = "unknown"      # later filled from the server's /health
        self.total_processed = 0    # successful captions served
        self.total_time = 0         # cumulative seconds spent on requests

    @property
    def fps(self):
        """Average images/second; 0 before any timed work has completed."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0
+
167
# Initialize servers
def initialize_servers():
    """(Re)build the global caption-server pool from CAPTION_SERVERS."""
    global servers
    pool = []
    for endpoint in CAPTION_SERVERS:
        pool.append(CaptionServer(endpoint))
    servers = pool
+
172
# API Routes
@app.get("/")
async def root():
    """Service banner with the coordinator's high-level state."""
    return dict(
        message="Caption Coordinator API",
        status="running",
        auto_processing=auto_start_processing,
        is_processing=is_processing,
    )
+
182
@app.get("/health")
async def health():
    """Liveness probe plus a quick summary of caption-server availability."""
    idle = sum(1 for s in servers if not s.busy)
    return {
        "status": "healthy",
        "servers_available": idle,
        "total_servers": len(servers),
        "is_processing": is_processing,
        "auto_processing": auto_start_processing,
    }
+
192
@app.get("/courses")
async def get_courses():
    """Proxy the source server's course list, returning just folder names.

    Raises HTTP 500 when the upstream request or parsing fails.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{SOURCE_SERVER}/courses") as resp:
                payload = await resp.json()
        if isinstance(payload, dict) and 'courses' in payload:
            return [entry['course_folder'] for entry in payload['courses'] if isinstance(entry, dict)]
        return []
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error fetching courses: {exc}")
+
205
+
206
@app.post("/vs2/register")
async def vs2_register(payload: Dict):
    """Register a VS2 callback and optionally start processing for the given course.
    Expected payload: {"course": "course_name", "callback_url": "http://vs2-host/flow/done", "start": true}
    """
    try:
        course_name = payload.get("course")
        callback_url = payload.get("callback_url")
        should_start = payload.get("start", True)

        if not callback_url:
            raise HTTPException(status_code=400, detail="callback_url is required")

        # Remember where to POST when captioning finishes; "*" is the
        # wildcard slot used when no course name was supplied.
        key = course_name if course_name else "*"
        pending_vs2_callbacks[key] = callback_url

        # Optionally kick off a one-shot processing run for this course,
        # unless a run is already in flight.
        if should_start:
            global is_processing, current_processing_task
            if not is_processing:
                is_processing = True
                target_courses = [course_name] if course_name else None
                current_processing_task = asyncio.create_task(processing_loop(target_courses, False))

        return {"registered": True, "course": course_name}
    except HTTPException:
        raise  # don't wrap deliberate HTTP errors in a 500
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
+
240
@app.get("/courses/{course}/images")
async def get_course_images(course: str):
    """Fetch the image listing for a course from the source server.

    Raises HTTP 500 when the upstream request or parsing fails.
    """
    try:
        # Frames live in "<course>_frames"; don't double-append the suffix.
        folder = course if course.endswith("_frames") else f"{course}_frames"
        listing_url = f"{SOURCE_SERVER}/images/{quote(folder)}"
        async with aiohttp.ClientSession() as session:
            async with session.get(listing_url) as resp:
                payload = await resp.json()
        if isinstance(payload, dict) and 'images' in payload:
            return payload['images']
        return []
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error fetching images: {exc}")
+
255
@app.get("/servers/status")
async def get_servers_status():
    """Snapshot the state and throughput of every caption server."""
    return [
        ServerStatus(
            url=s.url,
            model=s.model,
            busy=s.busy,
            total_processed=s.total_processed,
            total_time=s.total_time,
            fps=s.fps,
        )
        for s in servers
    ]
+
270
@app.get("/processing/status")
async def get_processing_status():
    """Get current processing status"""
    # Per-course progress snapshot built from the in-memory tracking dicts.
    status_info = {}
    for course in processed_images:
        # NOTE(review): `total` is the size of the *successfully processed*
        # set, not the full image count of the course, so progress_percent
        # will read near 100% whenever any work has been done — confirm
        # whether the true course total should be used here instead.
        total = len(processed_images[course])
        processed = len(course_captions.get(course, []))
        failed = len(failed_images.get(course, set()))
        status_info[course] = {
            "course": course,
            "total_images": total,
            "processed_images": processed,
            "failed_images": failed,
            "progress_percent": (processed / total * 100) if total > 0 else 0,
            "status": "completed" if processed + failed >= total else "processing"
        }
    return status_info
+
288
@app.post("/processing/start")
async def start_processing(request: StartProcessingRequest = StartProcessingRequest()):
    """Launch the background caption-processing loop.

    Rejects with HTTP 400 if a run is already active.
    """
    global is_processing, current_processing_task

    if is_processing:
        raise HTTPException(status_code=400, detail="Processing is already running")

    is_processing = True
    loop_coro = processing_loop(request.courses, request.continuous)
    current_processing_task = asyncio.create_task(loop_coro)

    return {
        "message": "Processing started",
        "continuous": request.continuous,
        "specific_courses": request.courses,
    }
+
307
@app.post("/processing/stop")
async def stop_processing():
    """Cancel the background processing task, if one is running.

    Rejects with HTTP 400 if nothing is in flight.
    """
    global is_processing, current_processing_task

    if not is_processing:
        raise HTTPException(status_code=400, detail="Processing is not running")

    is_processing = False
    task = current_processing_task
    if task:
        task.cancel()
        try:
            await task  # wait for the loop to unwind
        except asyncio.CancelledError:
            pass  # expected on cancellation
        current_processing_task = None

    return {"message": "Processing stopped"}
+
326
@app.get("/captions/{course}")
async def get_captions(course: str):
    """Return all stored captions for a course (empty list when none)."""
    stored = load_captions_from_file(course)
    return {
        "course": course,
        "total_captions": len(stored),
        "captions": stored,
    }
+
336
@app.delete("/captions/{course}")
async def delete_captions(course: str):
    """Delete stored captions for a course, on disk and in memory.

    Returns 404 when no caption file exists, 500 on any deletion error.
    (Bug fix: the original raised its 404 inside the try block, where the
    broad `except Exception` re-wrapped it as a 500.)
    """
    file_path = get_caption_file_path(course)
    if not file_path.exists():
        raise HTTPException(status_code=404, detail=f"No captions found for {course}")
    try:
        file_path.unlink()
        # Drop any in-memory tracking state for the course.
        processed_images.pop(course, None)
        course_captions.pop(course, None)
        failed_images.pop(course, None)
        return {"message": f"Captions for {course} deleted"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting captions: {e}")
+
355
# Core processing functions
async def fetch_courses() -> List[str]:
    """Return course folder names advertised by the source server."""
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{SOURCE_SERVER}/courses") as resp:
            payload = await resp.json()
    if not (isinstance(payload, dict) and 'courses' in payload):
        return []
    return [entry['course_folder'] for entry in payload['courses'] if isinstance(entry, dict)]
+
365
async def fetch_course_images(course: str) -> List[Dict]:
    """Return the image listing for *course* from the source server."""
    # Frames live in "<course>_frames"; don't double-append the suffix.
    folder = course if course.endswith("_frames") else f"{course}_frames"
    listing_url = f"{SOURCE_SERVER}/images/{quote(folder)}"
    async with aiohttp.ClientSession() as session:
        async with session.get(listing_url) as resp:
            payload = await resp.json()
    if isinstance(payload, dict) and 'images' in payload:
        return payload['images']
    return []
+
376
async def get_caption(server: str, image_url: str) -> Optional[Dict]:
    """Request a caption for *image_url* from one caption server.

    Returns the decoded JSON response, or None on any network/parse error
    (the original annotation claimed `Dict` but the error path returns None).
    Uses an explicit ClientTimeout; passing a bare number to session.get is
    deprecated in aiohttp.
    """
    params = {
        'image_url': image_url,
        'model_choice': MODEL_TYPE
    }
    try:
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(server, params=params) as resp:
                return await resp.json()
    except Exception as e:
        print(f"Error from {server}: {e}")
        return None
+
390
async def get_model_info():
    """Query each caption server's /health endpoint for its loaded model.

    Returns one {'url', 'model'} entry per CAPTION_SERVERS item, in the
    same order. Bug fix: unreachable servers used to be skipped entirely,
    which shrank the list and misaligned `zip(model_info, servers)` in
    processing_loop, assigning models to the wrong servers. Failures now
    yield a placeholder entry with model 'unknown' so positions line up.
    """
    model_info = []
    async with aiohttp.ClientSession() as session:
        for server in CAPTION_SERVERS:
            try:
                # /health lives next to /analyze on each server.
                health_url = server.rsplit('/analyze', 1)[0] + '/health'
                async with session.get(health_url) as resp:
                    info = await resp.json()
                model_info.append({
                    'url': server,
                    'model': info.get('model_choice', 'unknown')
                })
            except Exception as e:
                print(f"Couldn't get model info from {server}: {e}")
                # Keep list aligned with CAPTION_SERVERS.
                model_info.append({'url': server, 'model': 'unknown'})
    return model_info
+
407
+
408
async def wait_for_vs2_ready(course: str, timeout: Optional[int] = None, interval: int = 5):
    """Poll the SOURCE_SERVER /vs2/state endpoint until VS2 reports 'ready' for the given course.
    If timeout is None, this will poll indefinitely until VS2 is ready or idle.

    Returns True once VS2 is 'ready' or 'idle'/unknown for the course;
    raises Exception when *timeout* seconds elapse first.
    """
    url = f"{SOURCE_SERVER}/vs2/state"
    elapsed = 0  # approximate: counts poll intervals, not wall-clock time
    async with aiohttp.ClientSession() as session:
        while True:
            try:
                async with session.get(url, timeout=10) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        # data may be either {'state': ..., 'current_course': ...} or {'states': {...}}
                        state = data.get('state') or None
                        current = data.get('current_course') or data.get('current_file')
                        if state is None and 'states' in data:
                            # per-course states dict was returned
                            states = data['states']
                            state = states.get(course)
                            current = course

                        print(f"VS2 state: {state}, current: {current}")
                        # If VS2 explicitly ready for this course, proceed
                        if state == 'ready':
                            return True
                        # If VS2 idle for this course (or unknown), proceed
                        # (an unknown course is treated as "nothing to wait for").
                        if state in (None, 'idle'):
                            return True
                    else:
                        print(f"VS2 state endpoint returned {resp.status}")
            except Exception as e:
                # Network errors are non-fatal; keep polling.
                print(f"Could not query VS2 state: {e}")

            # if timeout set and exceeded, raise; otherwise continue indefinitely
            if timeout is not None:
                elapsed += interval
                if elapsed >= timeout:
                    raise Exception(f"Timeout waiting for VS2 to be ready for course {course}")

            await asyncio.sleep(interval)
+
449
async def process_image(server: CaptionServer, course: str, image: Dict) -> Optional[Dict]:
    """Caption one image on one server.

    Marks the server busy for the duration, accumulates its timing stats,
    and returns a caption-metadata dict on success or None on any failure
    (busy server, bad response, timeout, exception).
    """
    if server.busy:
        # Caller raced another assignment; treat as a failed attempt.
        return None

    server.busy = True
    start_time = time.time()

    try:
        # Structure URL correctly: /images/COURSE_NAME_frames/IMAGE.png
        course_frames = f"{course}_frames" if not course.endswith("_frames") else course
        image_url = urljoin(SOURCE_SERVER, f"/images/{quote(course_frames)}/{quote(image['filename'])}")
        result = await get_caption(server.url, image_url)

        # Time is counted even for failed attempts, so fps reflects real load.
        processing_time = time.time() - start_time
        server.total_time += processing_time

        if result and result.get('success') and result.get('caption'):
            server.total_processed += 1
            metadata = {
                "image": image['filename'],
                "caption": result['caption'],
                "server": server.url,
                "processing_time": processing_time,
                "timestamp": datetime.now().isoformat()
            }
            print(f"Server {server.url} processed {image['filename']} in {processing_time:.2f}s ({server.fps:.2f} fps)")
            return metadata
        else:
            # Server responded but no caption (might be error or empty response)
            error_msg = result.get('error', 'Unknown error') if result else 'No response'
            print(f"Server {server.url} failed for {image['filename']}: {error_msg}")
            return None

    except asyncio.TimeoutError:
        print(f"Server {server.url} timeout for {image['filename']}")
        return None
    except Exception as e:
        print(f"Error processing {image['filename']} on {server.url}: {e}")
        return None

    finally:
        # Always release the server, whatever happened above.
        server.busy = False
+
493
async def upload_to_huggingface(course: str, metadata_list: List[Dict]) -> bool:
    """Upload course captions to the configured Hugging Face dataset.

    Builds a columnar dataset from the caption metadata, pushes it under a
    per-course config name, then notifies VS2 via any registered callback.
    Returns True on success, False on any upload error.

    NOTE(review): huggingface_hub.login and Dataset.push_to_hub are
    synchronous network calls executed inside a coroutine — they block the
    event loop for the duration of the upload; consider run_in_executor.
    """
    try:
        print(f"📤 Uploading {len(metadata_list)} captions for {course} to Hugging Face...")

        # Prepare data for Hugging Face dataset (column-oriented layout
        # expected by Dataset.from_dict).
        dataset_data = {
            "course": [],
            "image_filename": [],
            "caption": [],
            "processing_server": [],
            "processing_time": [],
            "timestamp": []
        }

        for metadata in metadata_list:
            dataset_data["course"].append(course)
            dataset_data["image_filename"].append(metadata["image"])
            dataset_data["caption"].append(metadata["caption"])
            dataset_data["processing_server"].append(metadata["server"])
            dataset_data["processing_time"].append(metadata["processing_time"])
            dataset_data["timestamp"].append(metadata["timestamp"])

        # Create dataset
        dataset = Dataset.from_dict(dataset_data)

        # Login to Hugging Face (re-run on every upload; harmless but redundant)
        huggingface_hub.login(token=HF_TOKEN)

        # Push to hub; config_name must be a valid identifier, hence the
        # slash/space replacement.
        dataset.push_to_hub(
            HF_DATASET_ID,
            config_name=course.replace("/", "_").replace(" ", "_"),
            split="train",  # You can change this to "train", "validation", "test" as needed
            commit_message=f"Add captions for course {course} - {len(metadata_list)} images"
        )

        print(f"✅ Successfully uploaded {len(metadata_list)} captions for {course} to {HF_DATASET_ID}")
        # Notify VS2 (if VS2 provided a callback for this course);
        # notification failure must not mark the upload as failed.
        try:
            await notify_vs2_flow_done(course, success=True)
        except Exception as e:
            print(f"Warning: failed to notify VS2 about completion for {course}: {e}")
        return True

    except Exception as e:
        print(f"❌ Error uploading to Hugging Face: {e}")
        return False
+
542
+
543
async def notify_vs2_flow_done(course: str, success: bool):
    """If VS2 provided a callback URL for this course, POST a completion signal.

    Looks up an exact course match first, then falls back to partial-name
    matches and the "*" wildcard registered by /vs2/register when no course
    name was given. (Bug fix: the original fallback tested `key in course`,
    which can never match the "*" wildcard key.)
    Callback errors are logged, never raised.
    """
    callback = pending_vs2_callbacks.get(course)
    if not callback:
        # try fallback: wildcard, or callbacks registered under partial names
        for key, cb in pending_vs2_callbacks.items():
            if key == "*" or key in course:
                callback = cb
                break
    if not callback:
        # nothing to do
        return

    payload = {
        "course": course,
        "status": "done" if success else "failed",
        "timestamp": datetime.now().isoformat()
    }

    print(f"Notifying VS2 at {callback} about course {course} -> {payload['status']}")
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(callback, json=payload, timeout=30) as resp:
                if resp.status >= 400:
                    text = await resp.text()
                    print(f"VS2 callback returned {resp.status}: {text}")
    except Exception as e:
        print(f"Error notifying VS2 callback {callback}: {e}")
+
572
async def process_course(course: str, servers: List[CaptionServer]):
    """Process all images in a course across the available caption servers.

    Resumes from previously saved captions, retries each failing image up
    to 5 times, persists progress after every batch with new results, and
    uploads the finished course to Hugging Face.
    Bug fixes: log lines printed the literal "(unknown)" instead of the
    image filename; the unused `batch_size` local is removed.
    """
    # Initialize per-course tracking, resuming from any saved captions.
    if course not in processed_images:
        processed_images[course] = set()
    if course not in course_captions:
        course_captions[course] = load_captions_from_file(course)
        # Update processed images set from loaded captions
        for cap in course_captions[course]:
            processed_images[course].add(cap['image'])
    if course not in failed_images:
        failed_images[course] = set()

    # Get list of images
    images = await fetch_course_images(course)
    if not images:
        print(f"No images found for course {course}")
        return

    print(f"\nProcessing {len(images)} images for course {course}")

    # Images still needing work, with per-image retry bookkeeping (5 retries).
    pending_images = {}
    for img in images:
        filename = img['filename']
        if filename not in processed_images[course] and filename not in failed_images[course]:
            pending_images[filename] = {'image': img, 'retries': 0, 'max_retries': 5}

    if not pending_images:
        print(f"All images already processed or failed for course {course}")
        print(f"- Processed: {len(processed_images[course])}, Failed: {len(failed_images[course])}")

        # If course is completed, make sure it was uploaded to Hugging Face.
        if len(processed_images[course]) + len(failed_images[course]) >= len(images):
            if course_captions[course]:
                print(f"📤 Course {course} completed, uploading to Hugging Face...")
                await upload_to_huggingface(course, course_captions[course])
        return

    print(f"Images to process: {len(pending_images)} (already processed: {len(processed_images[course])}, failed: {len(failed_images[course])})")

    processed_in_this_run = 0

    while pending_images and is_processing:
        # Assign one pending image to every idle server.
        tasks = []
        assigned_images = []

        for server in servers:
            if not server.busy and pending_images:
                filename, img_data = next(iter(pending_images.items()))
                img = img_data['image']
                tasks.append(process_image(server, course, img))
                assigned_images.append((filename, img, img_data['retries']))
                # Remove from pending while in flight; re-added on failure.
                del pending_images[filename]

        if not tasks:
            # All servers busy — back off briefly.
            await asyncio.sleep(0.1)
            continue

        # Process images in parallel across servers.
        results = await asyncio.gather(*tasks)

        # Record successes; requeue or give up on failures.
        has_new_results = False
        for (filename, img, current_retries), result in zip(assigned_images, results):
            if result:
                processed_images[course].add(filename)
                course_captions[course].append(result)
                has_new_results = True
                processed_in_this_run += 1
                print(f"✓ Successfully processed {filename}")
            elif current_retries < 5:  # max_retries
                # Put back in pending with an incremented retry count.
                pending_images[filename] = {
                    'image': img,
                    'retries': current_retries + 1,
                    'max_retries': 5
                }
                print(f"↻ Retry {current_retries + 1}/5 for {filename}")
            else:
                # Max retries exceeded, mark as failed.
                failed_images[course].add(filename)
                print(f"✗ Failed to process {filename} after 5 retries")

        # Save progress after each batch with new results.
        if has_new_results:
            save_captions_to_file(course, course_captions[course])

        # Show progress.
        total = len(images)
        done = len(processed_images[course])
        failed_count = len(failed_images[course])
        pending_count = len(pending_images)
        progress_percent = (done / total * 100) if total > 0 else 0

        print(f"\rProgress: {done}/{total} ({progress_percent:.1f}%) - {pending_count} pending, {failed_count} failed, {processed_in_this_run} new", end="", flush=True)

        # Small delay to prevent overwhelming the servers.
        await asyncio.sleep(0.5)

    # Final status for this course.
    total = len(images)
    done = len(processed_images[course])
    failed_count = len(failed_images[course])

    if done + failed_count >= total:
        if failed_count > 0:
            print(f"\n✓ Course {course} completed with {failed_count} failed images")
        else:
            print(f"\n✓ Course {course} fully completed")

        # Upload to Hugging Face when course is completed.
        if course_captions[course]:
            print(f"📤 Uploading {len(course_captions[course])} captions to Hugging Face...")
            success = await upload_to_huggingface(course, course_captions[course])
            if success:
                print(f"✅ Successfully uploaded {course} to Hugging Face")
            else:
                print(f"❌ Failed to upload {course} to Hugging Face")
    else:
        print(f"\n→ Course {course} partially completed: {done}/{total} processed, {failed_count} failed")
+
704
async def processing_loop(specific_courses: Optional[List[str]] = None, continuous: bool = True):
    """Main processing loop.

    Filters the server pool to those running MODEL_TYPE, then repeatedly
    discovers courses (or uses *specific_courses*), waits for VS2 frame
    extraction, and processes each course. Runs until `is_processing` is
    cleared, the task is cancelled, or one pass completes when
    *continuous* is False.
    """
    global is_processing

    # Get model information and verify Florence-2-large availability.
    model_info = await get_model_info()
    print("\nCaption Servers:")
    available_servers = []
    # NOTE(review): this zip assumes get_model_info() returns one entry per
    # configured server, in order; if a failed health check is skipped there,
    # models get assigned to the wrong CaptionServer objects — verify.
    for info, server in zip(model_info, servers):
        server.model = info['model']
        if MODEL_TYPE in info.get('model', ''):
            available_servers.append(server)
            print(f"✓ {server.url} confirmed {MODEL_TYPE}")
        else:
            print(f"✗ {server.url} using {server.model} - skipping (requires {MODEL_TYPE})")

    if not available_servers:
        print(f"\nError: No servers with {MODEL_TYPE} available!")
        is_processing = False
        return

    # Update servers list to only use those with the large model.
    processing_servers = available_servers
    print(f"\nUsing {len(processing_servers)} servers with {MODEL_TYPE}")

    # Check for existing caption files and report (resume visibility only).
    existing_captions = list(CAPTIONS_DIR.glob("*_captions.json"))
    if existing_captions:
        print("\nFound existing caption files:")
        for cap_file in existing_captions:
            course = cap_file.stem.replace("_captions", "")
            try:
                with open(cap_file, 'r', encoding='utf-8') as f:
                    captions = json.load(f)
                print(f"- {course}: {len(captions)} captions")
            except Exception as e:
                print(f"- Error reading {cap_file.name}: {e}")
        print()

    start_time = time.time()
    iteration = 0

    while is_processing:
        try:
            iteration += 1
            print(f"\n{'='*50}")
            print(f"Processing Iteration {iteration}")
            print(f"{'='*50}")

            # Get available courses (explicit list overrides discovery).
            if specific_courses:
                courses = specific_courses
                print(f"Processing specific courses: {courses}")
            else:
                courses = await fetch_courses()
                print(f"Found {len(courses)} courses")

            if not courses:
                print("No courses found, waiting...")
                if not continuous:
                    break
                await asyncio.sleep(10)
                continue

            # Process each course with all available servers.
            for course in courses:
                if not is_processing:
                    break

                print(f"\n--- Processing course: {course} ---")
                # Before processing, ensure VS2 has finished extracting
                # frames for this course; readiness errors are non-fatal.
                try:
                    await wait_for_vs2_ready(course)
                except Exception as e:
                    print(f"Warning: error while checking VS2 readiness for {course}: {e}")

                await process_course(course, processing_servers)

            # Show server stats.
            print("\nServer Stats:")
            total_processed = sum(s.total_processed for s in processing_servers)
            elapsed = time.time() - start_time
            if elapsed > 0:
                print(f"Total images processed: {total_processed}")
                print(f"Overall speed: {total_processed/elapsed:.2f} fps")
            for s in processing_servers:
                print(f"- {s.url}: {s.total_processed} images, {s.fps:.2f} fps")
            print()

            if not continuous:
                print("One-time processing completed")
                break

            # Wait before next check.
            print("Waiting for new courses...")
            await asyncio.sleep(5)

        except asyncio.CancelledError:
            # Raised by /processing/stop cancelling the task.
            print("Processing cancelled")
            break
        except Exception as e:
            # Keep the loop alive on unexpected errors; back off briefly.
            print(f"Error in processing loop: {str(e)}")
            import traceback
            traceback.print_exc()
            await asyncio.sleep(10)

    is_processing = False
    print("Processing loop stopped")
+
813
# Startup event
# NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
# lifespan handlers — confirm the installed FastAPI version supports it.
@app.on_event("startup")
async def startup_event():
    """Initialize servers and start processing on startup"""
    initialize_servers()
    print("Caption Coordinator API started")
    print(f"Source server: {SOURCE_SERVER}")
    print(f"Caption servers: {len(CAPTION_SERVERS)}")
    print(f"Hugging Face dataset: {HF_DATASET_ID}")
    # HF_TOKEN is always set here — the module raises at import otherwise.
    print(f"HF Token: {'✅ Set' if HF_TOKEN else '❌ Missing'}")

    # Start processing automatically (like original main()).
    if auto_start_processing:
        print("Auto-starting processing loop...")
        global is_processing, current_processing_task
        is_processing = True
        current_processing_task = asyncio.create_task(processing_loop())
+
831
+
832
@app.post("/vs2/ready")
async def vs2_ready(course: str, callback_url: str = None):
    """Called by VS2 when it has finished extracting frames for a course.
    VS2 should POST course (string) and its callback_url (where Flow will POST when captioning is done).
    """
    if not course:
        raise HTTPException(status_code=400, detail="course is required")

    # Remember where to report completion, if VS2 gave us an address.
    if callback_url:
        pending_vs2_callbacks[course] = callback_url
        print(f"Registered VS2 callback for {course} -> {callback_url}")

    # Acknowledge. The processing loop will discover the new course via SOURCE_SERVER /courses.
    return {"status": "accepted", "course": course, "callback_url": callback_url}
+
847
+
848
@app.get("/vs2/callbacks")
async def list_vs2_callbacks():
    """List pending VS2 callbacks (debug)"""
    # Returns the live course -> callback_url mapping as-is.
    return pending_vs2_callbacks
+
853
+
854
if __name__ == "__main__":
    # reload=True requires an import string ("module:app"); when passed an
    # app object, uvicorn warns and silently disables auto-reload.
    # This file is app.py, so the import string is "app:app".
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)