Fred808 commited on
Commit
0cb4e16
·
verified ·
1 Parent(s): f4440e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +686 -676
app.py CHANGED
@@ -1,677 +1,687 @@
1
- import os
2
- import json
3
- import time
4
- import asyncio
5
- import aiohttp
6
- import zipfile
7
- import shutil
8
- from typing import Dict, List, Set, Optional, Tuple, Any
9
- from urllib.parse import quote
10
- from datetime import datetime
11
- from pathlib import Path
12
- import io
13
-
14
- from fastapi import FastAPI, BackgroundTasks, HTTPException, status
15
- from pydantic import BaseModel, Field
16
- from huggingface_hub import HfApi, hf_hub_download
17
-
18
# --- Configuration ---
AUTO_START_INDEX = 0  # Hardcoded default start index if no progress is found
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")  # Source dataset for zip files
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/data")  # Target dataset for captions

# Progress and State Tracking
PROGRESS_FILE = Path("processing_progress.json")
HF_STATE_FILE = "processing_state_cursors.json"  # Shared state file in the output dataset
LOCAL_STATE_FOLDER = Path(".state")  # Local folder for the state file
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)

# Directory within the HF dataset where the zip files are located
ZIP_FILE_PREFIX = "frames_zips/"

# Caption endpoints: twenty numbered HF Spaces for each owner prefix, all
# exposing the same /track_cursor route.  Generated programmatically
# instead of hand-listing all 60 URLs.
_SERVER_OWNERS = ("Son4live-ajax", "jirehlove-jaypq", "lovyone-ones")
CAPTION_SERVERS = [
    f"https://{owner}-{idx}.hf.space/track_cursor"
    for owner in _SERVER_OWNERS
    for idx in range(1, 21)
]
MODEL_TYPE = "Florence-2-large"

# Temporary storage for extracted images, namespaced per flow instance
TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)
103
-
104
# --- Models ---
class ProcessStartRequest(BaseModel):
    """Request body for /start_processing.

    FIX: the original default was AUTO_START_INDEX (0), which violates the
    field's own ge=1 constraint and, when accepted, produced a negative
    0-indexed offset in process_dataset_task.  The default is now clamped
    to at least 1, keeping the 1-indexed contract intact.
    """

    start_index: int = Field(
        max(AUTO_START_INDEX, 1),
        ge=1,
        description="The index number of the zip file to start processing from (1-indexed).",
    )
107
-
108
class CaptionServer:
    """Book-keeping wrapper around one remote captioning endpoint."""

    def __init__(self, url):
        self.url = url
        self.busy = False         # True while a request is in flight on this server
        self.total_processed = 0  # number of images attempted on this server
        self.total_time = 0       # cumulative seconds spent on requests
        self.model = MODEL_TYPE

    @property
    def fps(self):
        """Average images per second; 0 before any work has been timed."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0
119
-
120
# Global pool of caption servers, plus the round-robin cursor into it.
servers = list(map(CaptionServer, CAPTION_SERVERS))
server_index = 0
123
-
124
- # --- Progress and State Management Functions ---
125
-
126
def load_progress() -> Dict:
    """Return the locally persisted progress, or a fresh default structure.

    A corrupted (non-JSON) progress file is treated the same as a missing
    one, so processing can restart from scratch.
    """
    if PROGRESS_FILE.exists():
        try:
            return json.loads(PROGRESS_FILE.read_text())
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
            # fall through to the default structure below

    return {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],        # Full list of all zip files found in the dataset
    }
142
-
143
def save_progress(progress_data: Dict):
    """Persist progress to PROGRESS_FILE, logging (never raising) on failure."""
    try:
        with PROGRESS_FILE.open('w') as fh:
            json.dump(progress_data, fh, indent=4)
    except Exception as exc:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {exc}")
150
-
151
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load a state dict from *file_path*, migrating older layouts.

    Guarantees the returned dict carries a 'file_states' mapping and a
    'next_download_index' counter.  A missing or corrupted file yields
    *default_value* unchanged.
    """
    if not os.path.exists(file_path):
        return default_value

    try:
        with open(file_path, "r") as fh:
            data = json.load(fh)
    except json.JSONDecodeError:
        print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        return default_value

    # Migration: older state files predate these two keys.
    if not isinstance(data.get("file_states"), dict):
        print(f"[{FLOW_ID}] Initializing 'file_states' dictionary.")
        data["file_states"] = {}
    data.setdefault("next_download_index", 0)
    return data
170
-
171
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Write *data* to *file_path* as pretty-printed JSON."""
    with open(file_path, "w") as fh:
        json.dump(data, fh, indent=2)
175
-
176
async def download_hf_state() -> Dict[str, Any]:
    """Fetch the shared processing-state file from the output dataset.

    Returns a default (empty) state when the file does not exist yet or
    any step of the download fails, so callers can always proceed.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}

    try:
        # Probe the output repo first so a missing file is not an error.
        repo_files = HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
        )
        if HF_STATE_FILE not in repo_files:
            print(f"[{FLOW_ID}] State file not found in {HF_OUTPUT_DATASET_ID}. Starting fresh.")
            return default_state

        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        print(f"[{FLOW_ID}] Successfully downloaded state file.")
        return load_json_state(str(local_path), default_state)
    except Exception as exc:
        print(f"[{FLOW_ID}] Failed to download state file: {str(exc)}. Starting fresh.")
        return default_state
208
-
209
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Persist *state* locally, then push it to the output dataset.

    Returns True on success, False (after logging) on any failure.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        # Write locally first so the upload has a concrete file to send.
        save_json_state(str(local_path), state)

        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Update caption processing state: next_index={state['next_download_index']}"
        )
    except Exception as exc:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(exc)}")
        return False

    print(f"[{FLOW_ID}] Successfully uploaded state file.")
    return True
230
-
231
async def lock_file_for_processing(zip_filename: str, state: Dict[str, Any]) -> bool:
    """Claim *zip_filename* by marking it 'processing' and pushing the state.

    The uploaded state file acts as the cross-worker lock.  On upload
    failure the local mark is reverted and False is returned.
    """
    print(f"[{FLOW_ID}] 🔒 Attempting to lock file: {zip_filename}")
    state["file_states"][zip_filename] = "processing"

    if await upload_hf_state(state):
        print(f"[{FLOW_ID}] ✅ Successfully locked file: {zip_filename}")
        return True

    print(f"[{FLOW_ID}] ❌ Failed to lock file: {zip_filename}")
    # Revert the local mark so the in-memory state matches the remote one.
    state["file_states"].pop(zip_filename, None)
    return False
248
-
249
async def unlock_file_as_processed(zip_filename: str, state: Dict[str, Any], next_index: int) -> bool:
    """Mark *zip_filename* 'processed', advance the cursor, and push the state."""
    print(f"[{FLOW_ID}] 🔓 Marking file as processed: {zip_filename}")
    state["file_states"][zip_filename] = "processed"
    state["next_download_index"] = next_index

    uploaded = await upload_hf_state(state)
    if uploaded:
        print(f"[{FLOW_ID}] ✅ Successfully marked as processed: {zip_filename}")
    else:
        print(f"[{FLOW_ID}] ❌ Failed to update state for: {zip_filename}")
    return uploaded
264
-
265
- # --- Hugging Face Utility Functions ---
266
-
267
async def get_zip_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted list of zip files in the source dataset.

    Uses the list cached in *progress_data* when present; otherwise
    fetches it from the Hub, caches it via save_progress, and returns it.
    An empty list signals failure to the caller.
    """
    cached = progress_data['file_list']
    if cached:
        print(f"[{FLOW_ID}] Using cached file list with {len(cached)} files.")
        return cached

    print(f"[{FLOW_ID}] Fetching full list of zip files from {HF_DATASET_ID}...")
    try:
        repo_files = HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_DATASET_ID,
            repo_type="dataset",
        )

        # Sorted for stable, reproducible indexing across runs/workers.
        zip_files = sorted(
            name for name in repo_files
            if name.startswith(ZIP_FILE_PREFIX) and name.endswith('.zip')
        )
        if not zip_files:
            raise FileNotFoundError(f"No zip files found in '{ZIP_FILE_PREFIX}' directory of dataset '{HF_DATASET_ID}'.")

        print(f"[{FLOW_ID}] Found {len(zip_files)} zip files.")

        # Cache the listing so future runs skip the Hub round-trip.
        progress_data['file_list'] = zip_files
        save_progress(progress_data)
        return zip_files
    except Exception as exc:
        print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {exc}")
        return []
304
-
305
async def download_and_extract_zip_by_index(file_index: int, repo_file_full_path: str) -> Optional[Path]:
    """Download one zip from the source dataset and extract it under TEMP_DIR.

    Returns the extraction directory on success, None on any failure.
    """
    zip_full_name = Path(repo_file_full_path).name
    course_name = zip_full_name.replace('.zip', '')  # file name doubles as the job name

    print(f"[{FLOW_ID}] Processing file #{file_index}: {repo_file_full_path}. Full name: {zip_full_name}")

    try:
        zip_path = hf_hub_download(
            repo_id=HF_DATASET_ID,
            filename=repo_file_full_path,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")

        # Always start from a clean extraction directory.
        extract_dir = TEMP_DIR / course_name
        if extract_dir.exists():
            shutil.rmtree(extract_dir)
        extract_dir.mkdir(exist_ok=True)

        # NOTE(review): extractall() trusts archive member paths; acceptable
        # for our own dataset, unsafe for untrusted zips (path traversal).
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_dir)
        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")

        # Drop the cached zip immediately to save disk space.
        os.remove(zip_path)
        return extract_dir
    except Exception as exc:
        print(f"[{FLOW_ID}] Error downloading or extracting zip for {repo_file_full_path}: {exc}")
        return None
345
-
346
async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
    """Upload the captions for one zip as '<zip-stem>.json' in the output dataset."""
    caption_filename = Path(zip_full_name).with_suffix('.json').name

    try:
        print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")

        # Serialize in memory; no temporary file needed for the upload.
        payload = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')

        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo=caption_filename,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"[{FLOW_ID}] Captions for {zip_full_name}"
        )
        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
        return True
    except Exception as exc:
        print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {exc}")
        return False
372
-
373
- # --- Core Processing Functions (Modified) ---
374
-
375
async def get_available_server(timeout: float = 300.0) -> CaptionServer:
    """Round-robin pick of the next idle caption server.

    Polls the pool every 0.5s while all servers are busy and raises
    TimeoutError once *timeout* seconds have elapsed without success.
    """
    global server_index
    deadline = time.time() + timeout
    while True:
        # One full sweep of the pool, advancing the shared cursor.
        for _ in servers:
            candidate = servers[server_index]
            server_index = (server_index + 1) % len(servers)
            if not candidate.busy:
                return candidate

        # Everyone is busy: back off briefly, then re-check the deadline.
        await asyncio.sleep(0.5)
        if time.time() > deadline:
            raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")
393
-
394
async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
    """Send one image to a caption server and return the detection record.

    Handles server selection and up to MAX_RETRIES retries internally;
    returns None when every attempt fails.  progress_tracker['completed']
    is incremented on each success.

    FIX over the original: the file object handed to aiohttp.FormData was
    never closed, leaking one file descriptor per request; it is now opened
    in a context manager spanning the HTTP call.
    """
    MAX_RETRIES = 3
    for attempt in range(MAX_RETRIES):
        server = None
        start_time = time.time()  # pre-set so `finally` can never see it unbound
        try:
            # 1. Wait (bounded) for an idle server, then claim it.
            server = await get_available_server()
            server.busy = True
            start_time = time.time()

            # Keep logging terse: announce only the first attempt.
            if attempt == 0:
                print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")

            # 2+3. Build the multipart request and send it.  The image file
            # stays open only for the lifetime of the request.
            with image_path.open('rb') as image_file:
                form_data = aiohttp.FormData()
                form_data.add_field('file',
                                    image_file,
                                    filename=image_path.name,
                                    content_type='image/jpeg')
                form_data.add_field('model_choice', MODEL_TYPE)

                async with aiohttp.ClientSession() as session:
                    # Long (10-minute) timeout: remote Spaces can cold-start.
                    async with session.post(server.url, data=form_data, timeout=600) as resp:
                        if resp.status != 200:
                            error_text = await resp.text()
                            print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
                            continue  # retry with a different server

                        result = await resp.json()

                        # A valid cursor-detection response always carries
                        # the 'cursor_active' key.
                        if result.get('cursor_active') is None:
                            print(f"[{FLOW_ID}] Server {server.url} returned invalid response format for {image_path.name}. Response: {result}")
                            continue  # retry with a different server

                        progress_tracker['completed'] += 1
                        if progress_tracker['completed'] % 50 == 0:
                            print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} detections completed.")
                        else:
                            print(f"[{FLOW_ID}] Success: {image_path.name} processed by {server.url}")

                        # Full cursor-detection result for this image.
                        return {
                            "course": course_name,
                            "image_path": image_path.name,
                            "cursor_active": result.get('cursor_active', False),
                            "x": result.get('x'),
                            "y": result.get('y'),
                            "confidence": result.get('confidence'),
                            "template": result.get('template'),
                            "image_shape": result.get('image_shape'),
                            "server_url": server.url,
                            "timestamp": datetime.now().isoformat()
                        }

        except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
            print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
            continue  # retry with a different server
        except Exception as e:
            print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
            continue  # retry with a different server
        finally:
            # Release the server and record timing stats for this attempt.
            if server:
                server.busy = False
                server.total_processed += 1
                server.total_time += (time.time() - start_time)

    print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
    return None
472
-
473
async def process_dataset_task(start_index: int):
    """Sequentially process zip files starting from *start_index* (1-indexed).

    For each file: lock it in the shared HF state, download + extract,
    caption every image in parallel, upload the captions, then mark the
    file processed (or failed) and clean up the temp directory.

    Returns True when every attempted file succeeded, False otherwise.

    FIX over the original: the detection-rate statistic divided by
    len(all_captions) without a zero guard, so a file whose every image
    failed raised ZeroDivisionError instead of being reported normally.
    The statistics block is now skipped when there are no captions.
    """
    progress = load_progress()
    current_state = await download_hf_state()
    file_list = await get_zip_file_list(progress)

    if not file_list:
        print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
        return False

    if start_index > len(file_list):
        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
        return True

    start_list_index = start_index - 1  # convert to 0-indexed
    print(f"[{FLOW_ID}] Starting dataset processing from file index: {start_index} out of {len(file_list)}.")

    global_success = True

    for i in range(start_list_index, len(file_list)):
        file_index = i + 1  # 1-indexed for display and progress tracking
        repo_file_full_path = file_list[i]
        zip_full_name = Path(repo_file_full_path).name
        course_name = zip_full_name.replace('.zip', '')  # file name doubles as the job name

        # Skip anything another worker already finished or currently holds.
        file_state = current_state["file_states"].get(zip_full_name)
        if file_state == "processed":
            print(f"[{FLOW_ID}] Skipping {zip_full_name}: Already processed in global state.")
            continue
        elif file_state == "processing":
            print(f"[{FLOW_ID}] Skipping {zip_full_name}: Currently being processed by another worker.")
            continue

        if not await lock_file_for_processing(zip_full_name, current_state):
            print(f"[{FLOW_ID}] Failed to lock {zip_full_name}. Skipping.")
            continue

        extract_dir = None
        current_file_success = False

        try:
            # 1. Download and extract
            extract_dir = await download_and_extract_zip_by_index(file_index, repo_file_full_path)
            if not extract_dir:
                raise Exception("Failed to download or extract zip file.")

            # 2. Find images (recursively; zips may contain subfolders)
            image_paths = [p for p in extract_dir.glob("**/*")
                           if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

            if not image_paths:
                print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
                current_file_success = True
            else:
                # 3. Caption all images concurrently, capped at pool size
                progress_tracker = {'total': len(image_paths), 'completed': 0}
                print(f"[{FLOW_ID}] Starting captioning for {progress_tracker['total']} images in {zip_full_name}...")

                semaphore = asyncio.Semaphore(len(servers))

                async def bounded_caption(image_path):
                    # Cap in-flight requests at the number of servers.
                    async with semaphore:
                        return await send_image_for_captioning(image_path, course_name, progress_tracker)

                results = await asyncio.gather(*(bounded_caption(p) for p in image_paths))
                all_captions = [r for r in results if r is not None]

                if len(all_captions) == len(image_paths):
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully processed all {len(all_captions)} images.")
                else:
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} images.")

                # Statistics — guarded against the empty-result case, which
                # previously raised ZeroDivisionError.
                if all_captions:
                    cursor_detected = sum(1 for c in all_captions if c.get('cursor_active', False))
                    print(f"[{FLOW_ID}] Detection Statistics:")
                    print(f"- Total processed: {len(all_captions)}")
                    print(f"- Cursors detected: {cursor_detected}")
                    print(f"- Detection rate: {(cursor_detected/len(all_captions)*100):.2f}%")

                # Any captions at all count as (provisional) success.
                current_file_success = len(all_captions) > 0

                # 4. Upload results
                if all_captions:
                    print(f"[{FLOW_ID}] Uploading {len(all_captions)} captions for {zip_full_name}...")
                    if await upload_captions_to_hf(zip_full_name, all_captions):
                        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
                        current_file_success = True
                    else:
                        print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}.")
                        current_file_success = False
                else:
                    print(f"[{FLOW_ID}] No captions generated. Skipping upload for {zip_full_name}.")
                    current_file_success = False

        except Exception as e:
            print(f"[{FLOW_ID}] Critical error in process_dataset_task for file #{file_index} ({zip_full_name}): {e}")
            current_file_success = False
            global_success = False  # any critical failure fails the overall run

        finally:
            # 5. Cleanup and progress/state bookkeeping
            if extract_dir and extract_dir.exists():
                shutil.rmtree(extract_dir, ignore_errors=True)
                print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")

            if current_file_success:
                # Update both local progress and the shared HF state.
                progress['last_processed_index'] = file_index
                progress['processed_files'][str(file_index)] = repo_file_full_path
                save_progress(progress)

                if await unlock_file_as_processed(zip_full_name, current_state, file_index + 1):
                    print(f"[{FLOW_ID}] Progress saved and file unlocked: {zip_full_name}")
                else:
                    print(f"[{FLOW_ID}] Warning: File processed but state update failed: {zip_full_name}")
            else:
                # Record the failure so other workers do not retry it blindly.
                current_state["file_states"][zip_full_name] = "failed"
                await upload_hf_state(current_state)
                print(f"[{FLOW_ID}] File {zip_full_name} marked as failed. Continuing with next file.")
                global_success = False

    print(f"[{FLOW_ID}] All processing loops complete. Overall success: {global_success}")
    return global_success
620
-
621
# --- FastAPI App and Endpoints ---

app = FastAPI(
    title=f"Flow Server {FLOW_ID} API",
    description="Sequentially processes zip files from a dataset, captions images, and tracks progress.",
    version="1.0.0",
)
628
-
629
@app.on_event("startup")
async def startup_event():
    """Kick off background processing as soon as the server boots."""
    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.")

    # Resume after the last processed index, clamped to AUTO_START_INDEX
    # when the progress file is new/empty.
    progress = load_progress()
    start_index = max(progress.get('last_processed_index', 0) + 1, AUTO_START_INDEX)

    # FastAPI startup hooks cannot use BackgroundTasks, so schedule the
    # long-running job on the event loop; server startup is not blocked.
    print(f"[{FLOW_ID}] Auto-starting processing from index: {start_index}...")
    asyncio.create_task(process_dataset_task(start_index))
645
-
646
@app.get("/")
async def root():
    """Health/status endpoint summarizing progress and server-pool load.

    FIX over the original: progress keys are read with .get() defaults so
    a legacy or partially written progress file (missing 'file_list' or
    'processed_files') cannot raise KeyError and 500 the endpoint.
    """
    progress = load_progress()
    return {
        "flow_id": FLOW_ID,
        "status": "ready",
        "last_processed_index": progress.get('last_processed_index', 0),
        "total_files_in_list": len(progress.get('file_list', [])),
        "processed_files_count": len(progress.get('processed_files', {})),
        "total_servers": len(servers),
        "busy_servers": sum(1 for s in servers if s.busy),
    }
658
-
659
@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest, background_tasks: BackgroundTasks):
    """Manually (re)start sequential processing from a given 1-indexed file.

    Processing already auto-starts at boot; this endpoint allows an
    explicit restart/override.  The heavy work runs as a background task
    so the API call returns immediately.
    """
    start_index = request.start_index
    print(f"[{FLOW_ID}] Received request to start processing from index: {start_index}. Starting background task.")
    background_tasks.add_task(process_dataset_task, start_index)
    return {"status": "processing", "start_index": start_index, "message": "Dataset processing started in background."}
673
-
674
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces so the port is reachable from outside the
    # container/sandbox.
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
 
1
+ import os
2
+ import json
3
+ import time
4
+ import asyncio
5
+ import aiohttp
6
+ import zipfile
7
+ import shutil
8
+ from typing import Dict, List, Set, Optional, Tuple, Any
9
+ from urllib.parse import quote
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ import io
13
+
14
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, status
15
+ from pydantic import BaseModel, Field
16
+ from huggingface_hub import HfApi, hf_hub_download
17
+
18
+ # --- Configuration ---
19
+ AUTO_START_INDEX = 0# Hardcoded default start index if no progress is found
20
+ FLOW_ID = os.getenv("FLOW_ID", "flow_default")
21
+ FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
22
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
23
+ HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3") # Source dataset for zip files
24
+ HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/data") # Target dataset for captions
25
+
26
+ # Progress and State Tracking
27
+ PROGRESS_FILE = Path("processing_progress.json")
28
+ HF_STATE_FILE = "processing_state_cursors.json" # State file in helium dataset
29
+ LOCAL_STATE_FOLDER = Path(".state") # Local folder for state file
30
+ LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
31
+
32
# Directory inside the source HF dataset that holds the frame zip archives.
ZIP_FILE_PREFIX = "frames_zips/"

# Cursor-detection endpoints: three Space families, 20 replicas each.
# The generated list is element-for-element identical to enumerating every
# URL by hand (family order, then replica 1..20 within each family).
CAPTION_SERVERS = [
    f"https://{family}-{replica}.hf.space/track_cursor"
    for family in ("Son4live-ajax", "jirehlove-jaypq", "lovyone-ones")
    for replica in range(1, 21)
]

# Model identifier sent along with every captioning request.
MODEL_TYPE = "Florence-2-large"

# Per-flow scratch directory for downloaded and extracted images.
TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)
104
# --- Models ---
class ProcessStartRequest(BaseModel):
    """Request body for /start_processing (1-indexed zip-file start index)."""

    # FIX: the original default was AUTO_START_INDEX (0), which violates the
    # field's own ge=1 constraint and would make process_dataset_task start
    # from a negative list offset. Clamp the default to a legal value.
    start_index: int = Field(
        max(AUTO_START_INDEX, 1),
        ge=1,
        description="The index number of the zip file to start processing from (1-indexed).",
    )
108
class CaptionServer:
    """Tracks one remote captioning endpoint and its throughput statistics."""

    def __init__(self, url):
        self.url = url               # POST target for detection requests
        self.busy = False            # True while a request is in flight
        self.total_processed = 0     # requests attributed to this server
        self.total_time = 0          # cumulative seconds spent on those requests
        self.model = MODEL_TYPE      # model name advertised with requests

    @property
    def fps(self):
        """Average requests per second; 0 until any time has been recorded."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0
119
+
120
# Global pool of caption servers plus the round-robin cursor into it.
servers = [CaptionServer(endpoint) for endpoint in CAPTION_SERVERS]
server_index = 0
123
+
124
+ # --- Progress and State Management Functions ---
125
+
126
def load_progress() -> Dict:
    """Load local processing progress from PROGRESS_FILE.

    Returns a fresh default structure when the file is missing or corrupted.
    """
    default = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],        # full list of zip files found in the dataset
    }
    if not PROGRESS_FILE.exists():
        return default
    try:
        return json.loads(PROGRESS_FILE.read_text())
    except json.JSONDecodeError:
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
        return default
142
+
143
def save_progress(progress_data: Dict):
    """Persist local processing progress to PROGRESS_FILE (best effort)."""
    try:
        PROGRESS_FILE.write_text(json.dumps(progress_data, indent=4))
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {e}")
150
+
151
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load state from a JSON file, migrating older files to the current layout.

    Returns default_value when the file is missing or unparseable. Guarantees
    the returned dict has a 'file_states' dict and a 'next_download_index' key.
    """
    if not os.path.exists(file_path):
        return default_value
    try:
        with open(file_path, "r") as fh:
            data = json.load(fh)
    except json.JSONDecodeError:
        print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        return default_value

    # Migration: older state files may lack either key entirely.
    if not isinstance(data.get("file_states"), dict):
        print(f"[{FLOW_ID}] Initializing 'file_states' dictionary.")
        data["file_states"] = {}
    data.setdefault("next_download_index", 0)
    return data
170
+
171
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Write `data` to `file_path` as pretty-printed JSON (2-space indent)."""
    with open(file_path, "w") as fh:
        fh.write(json.dumps(data, indent=2))
175
+
176
async def download_hf_state() -> Dict[str, Any]:
    """Download the shared state file from the output dataset.

    Returns a default empty state when the file does not exist yet or any
    Hub operation fails — a missing file is the normal first-run case.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}

    try:
        # List first so a missing file is handled as "fresh start", not error.
        repo_files = HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
        )
        if HF_STATE_FILE not in repo_files:
            print(f"[{FLOW_ID}] State file not found in {HF_OUTPUT_DATASET_ID}. Starting fresh.")
            return default_state

        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        print(f"[{FLOW_ID}] Successfully downloaded state file.")
        return load_json_state(str(local_path), default_state)
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}. Starting fresh.")
        return default_state
208
+
209
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Upload the state file to the output dataset. Returns True on success."""
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        # Persist locally first so the upload reads a complete file.
        save_json_state(str(local_path), state)
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Update caption processing state: next_index={state['next_download_index']}",
        )
        print(f"[{FLOW_ID}] Successfully uploaded state file.")
        return True
    except Exception as err:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(err)}")
        return False
230
+
231
async def lock_file_for_processing(zip_filename: str, state: Dict[str, Any]) -> bool:
    """Mark a file as 'processing' and publish the state to act as a lock.

    NOTE(review): this check-then-upload is not atomic across workers — two
    flows could still race for the same file; confirm this is acceptable.
    """
    print(f"[{FLOW_ID}] 🔒 Attempting to lock file: {zip_filename}")
    state["file_states"][zip_filename] = "processing"

    # The lock only exists once the shared state file is visible remotely.
    if not await upload_hf_state(state):
        print(f"[{FLOW_ID}] ❌ Failed to lock file: {zip_filename}")
        state["file_states"].pop(zip_filename, None)  # roll back the local mark
        return False

    print(f"[{FLOW_ID}] ✅ Successfully locked file: {zip_filename}")
    return True
248
+
249
async def unlock_file_as_processed(zip_filename: str, state: Dict[str, Any], next_index: int) -> bool:
    """Mark a file as 'processed', advance the download cursor, publish state."""
    print(f"[{FLOW_ID}] 🔓 Marking file as processed: {zip_filename}")
    state["file_states"][zip_filename] = "processed"
    state["next_download_index"] = next_index

    uploaded = await upload_hf_state(state)
    if uploaded:
        print(f"[{FLOW_ID}] ✅ Successfully marked as processed: {zip_filename}")
    else:
        print(f"[{FLOW_ID}] ❌ Failed to update state for: {zip_filename}")
    return uploaded
264
+
265
+ # --- Hugging Face Utility Functions ---
266
+
267
async def get_zip_file_list(progress_data: Dict) -> List[str]:
    """
    Return the list of all zip archives under ZIP_FILE_PREFIX in the source
    dataset.

    Uses the list cached in progress_data when available; otherwise queries
    the Hub, caches the sorted result into progress_data (persisted via
    save_progress), and returns it. Returns [] when the listing fails.
    """
    cached = progress_data['file_list']
    if cached:
        print(f"[{FLOW_ID}] Using cached file list with {len(cached)} files.")
        return cached

    print(f"[{FLOW_ID}] Fetching full list of zip files from {HF_DATASET_ID}...")
    try:
        repo_files = HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_DATASET_ID,
            repo_type="dataset",
        )

        # Sorted alphabetically so the 1-based index is stable across runs.
        zip_files = sorted(
            name for name in repo_files
            if name.startswith(ZIP_FILE_PREFIX) and name.endswith('.zip')
        )
        if not zip_files:
            raise FileNotFoundError(
                f"No zip files found in '{ZIP_FILE_PREFIX}' directory of dataset '{HF_DATASET_ID}'."
            )

        print(f"[{FLOW_ID}] Found {len(zip_files)} zip files.")
        progress_data['file_list'] = zip_files
        save_progress(progress_data)
        return zip_files
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {e}")
        return []
304
+
305
async def download_and_extract_zip_by_index(file_index: int, repo_file_full_path: str) -> Optional[Path]:
    """Download one zip archive and extract it under TEMP_DIR.

    Returns the extraction directory, or None on any failure.
    """
    zip_full_name = Path(repo_file_full_path).name
    course_name = zip_full_name.replace('.zip', '')  # archive name doubles as the job name

    print(f"[{FLOW_ID}] Processing file #{file_index}: {repo_file_full_path}. Full name: {zip_full_name}")
    try:
        zip_path = hf_hub_download(
            repo_id=HF_DATASET_ID,
            filename=repo_file_full_path,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")

        # Start from an empty directory so stale files from a previous
        # attempt cannot leak into this run.
        extract_dir = TEMP_DIR / course_name
        if extract_dir.exists():
            shutil.rmtree(extract_dir)
        extract_dir.mkdir(exist_ok=True)

        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_dir)
        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")

        # Reclaim disk space. NOTE(review): zip_path points into the HF cache;
        # removing it manually leaves a dangling cache entry — confirm intended.
        os.remove(zip_path)
        return extract_dir
    except Exception as e:
        print(f"[{FLOW_ID}] Error downloading or extracting zip for {repo_file_full_path}: {e}")
        return None
345
+
346
async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
    """Upload the captions for one archive as <archive>.json in the output dataset.

    Returns True on success, False on any upload failure.
    """
    caption_filename = Path(zip_full_name).with_suffix('.json').name
    try:
        print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")

        # Serialize in memory — no temporary file needed for the upload.
        payload = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo=caption_filename,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"[{FLOW_ID}] Captions for {zip_full_name}",
        )
        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
        return True
    except Exception as e:
        print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
        return False
372
+
373
+ # --- Core Processing Functions (Modified) ---
374
+
375
async def get_available_server(timeout: float = 300.0) -> CaptionServer:
    """Round-robin selection of an available caption server.

    Polls the pool every 0.5s while all servers are busy; raises TimeoutError
    once `timeout` seconds have elapsed without finding a free one.
    """
    global server_index
    deadline = time.time() + timeout
    while True:
        # One full round-robin pass over the pool.
        for _ in range(len(servers)):
            candidate = servers[server_index]
            server_index = (server_index + 1) % len(servers)
            if not candidate.busy:
                return candidate

        # Everyone is busy: back off briefly, then re-check the deadline.
        await asyncio.sleep(0.5)
        if time.time() > deadline:
            raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")
393
+
394
async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
    """Send a single image to a caption server for cursor detection.

    Handles server selection (waiting for a free slot) and up to MAX_RETRIES
    attempts internally. On success returns the detection record and
    increments progress_tracker['completed']; returns None after all retries
    fail. Server busy/timing stats are updated even for failed attempts.
    """
    MAX_RETRIES = 3
    for attempt in range(MAX_RETRIES):
        server = None
        try:
            # 1. Get an available server (waits if all are busy, with a timeout)
            server = await get_available_server()
            server.busy = True
            start_time = time.time()

            # Less verbose logging: announce only the first attempt
            if attempt == 0:
                print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")

            # 2. Prepare request data.
            # FIX: read the file into memory instead of passing an open file
            # object to FormData — the original opened the file and never
            # closed it, leaking one file descriptor per attempt.
            form_data = aiohttp.FormData()
            form_data.add_field('file',
                                image_path.read_bytes(),
                                filename=image_path.name,
                                content_type='image/jpeg')
            form_data.add_field('model_choice', MODEL_TYPE)

            # 3. Send request (generous 10-minute timeout for slow model Spaces)
            async with aiohttp.ClientSession() as session:
                async with session.post(server.url, data=form_data, timeout=600) as resp:
                    if resp.status == 200:
                        result = await resp.json()

                        # Valid cursor-detection responses always carry 'cursor_active'
                        if result.get('cursor_active') is not None:
                            progress_tracker['completed'] += 1
                            if progress_tracker['completed'] % 50 == 0:
                                print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} detections completed.")
                            else:
                                print(f"[{FLOW_ID}] Success: {image_path.name} processed by {server.url}")

                            # Full cursor-detection record for this frame
                            return {
                                "course": course_name,
                                "image_path": image_path.name,
                                "cursor_active": result.get('cursor_active', False),
                                "x": result.get('x'),
                                "y": result.get('y'),
                                "confidence": result.get('confidence'),
                                "template": result.get('template'),
                                "image_shape": result.get('image_shape'),
                                "server_url": server.url,
                                "timestamp": datetime.now().isoformat()
                            }
                        else:
                            print(f"[{FLOW_ID}] Server {server.url} returned invalid response format for {image_path.name}. Response: {result}")
                            continue  # Retry with a different server
                    else:
                        error_text = await resp.text()
                        print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
                        continue  # Retry with a different server

        except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
            print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
            continue  # Retry with a different server
        except Exception as e:
            print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
            continue  # Retry with a different server
        finally:
            # Release the server and record timing, even for failed attempts
            if server:
                server.busy = False
                server.total_processed += 1
                server.total_time += (time.time() - start_time)

    print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
    return None
472
+
473
async def process_dataset_task(start_index: int):
    """Process the dataset sequentially starting from a given 1-based index.

    For each zip from start_index onward: skip it if the shared state says
    another worker owns or finished it, otherwise lock it, download and
    extract, fan out cursor detection across the server pool, upload the
    captions JSON, then mark the file processed (or failed) and clean up.
    Returns True only if every attempted file succeeded.
    """
    # Load both local progress and the shared HF state
    progress = load_progress()
    current_state = await download_hf_state()
    file_list = await get_zip_file_list(progress)

    if not file_list:
        print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
        return False

    if start_index > len(file_list):
        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
        return True

    start_list_index = start_index - 1  # convert 1-based index to list offset
    print(f"[{FLOW_ID}] Starting dataset processing from file index: {start_index} out of {len(file_list)}.")

    global_success = True

    for i in range(start_list_index, len(file_list)):
        file_index = i + 1  # 1-indexed for display and progress tracking
        repo_file_full_path = file_list[i]
        zip_full_name = Path(repo_file_full_path).name
        course_name = zip_full_name.replace('.zip', '')  # archive name doubles as job name

        # Skip files another worker already finished or currently owns
        file_state = current_state["file_states"].get(zip_full_name)
        if file_state == "processed":
            print(f"[{FLOW_ID}] Skipping {zip_full_name}: Already processed in global state.")
            continue
        elif file_state == "processing":
            print(f"[{FLOW_ID}] Skipping {zip_full_name}: Currently being processed by another worker.")
            continue

        # Try to lock the file in the shared state
        if not await lock_file_for_processing(zip_full_name, current_state):
            print(f"[{FLOW_ID}] Failed to lock {zip_full_name}. Skipping.")
            continue

        extract_dir = None
        current_file_success = False

        try:
            # 1. Download and Extract
            extract_dir = await download_and_extract_zip_by_index(file_index, repo_file_full_path)
            if not extract_dir:
                raise Exception("Failed to download or extract zip file.")

            # 2. Find Images (recursive, common raster extensions)
            image_paths = [p for p in extract_dir.glob("**/*")
                           if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

            if not image_paths:
                print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
                current_file_success = True
            else:
                # 3. Process Images, bounded by the number of servers
                progress_tracker = {'total': len(image_paths), 'completed': 0}
                print(f"[{FLOW_ID}] Starting captioning for {progress_tracker['total']} images in {zip_full_name}...")

                semaphore = asyncio.Semaphore(len(servers))

                async def limited_send_image_for_captioning(image_path, course_name, progress_tracker):
                    async with semaphore:
                        return await send_image_for_captioning(image_path, course_name, progress_tracker)

                caption_tasks = [limited_send_image_for_captioning(p, course_name, progress_tracker)
                                 for p in image_paths]
                results = await asyncio.gather(*caption_tasks)
                all_captions = [r for r in results if r is not None]

                # Final progress report for the current file
                if len(all_captions) == len(image_paths):
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully processed all {len(all_captions)} images.")
                else:
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} images.")

                # Detection statistics.
                # FIX: guard the rate computation — the original divided by
                # len(all_captions) and crashed with ZeroDivisionError when
                # every image failed, wrongly marking the file as a critical error.
                cursor_detected = sum(1 for c in all_captions if c.get('cursor_active', False))
                print(f"[{FLOW_ID}] Detection Statistics:")
                print(f"- Total processed: {len(all_captions)}")
                print(f"- Cursors detected: {cursor_detected}")
                if all_captions:
                    print(f"- Detection rate: {(cursor_detected/len(all_captions)*100):.2f}%")
                else:
                    print("- Detection rate: n/a (no successful detections)")

                # Successful if at least one caption was produced
                current_file_success = len(all_captions) > 0

                # 4. Upload Results
                if all_captions:
                    print(f"[{FLOW_ID}] Uploading {len(all_captions)} captions for {zip_full_name}...")
                    if await upload_captions_to_hf(zip_full_name, all_captions):
                        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
                        current_file_success = True
                    else:
                        print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}.")
                        current_file_success = False
                else:
                    print(f"[{FLOW_ID}] No captions generated. Skipping upload for {zip_full_name}.")
                    current_file_success = False

        except Exception as e:
            print(f"[{FLOW_ID}] Critical error in process_dataset_task for file #{file_index} ({zip_full_name}): {e}")
            current_file_success = False
            global_success = False  # any critical failure fails the overall task

        finally:
            # 5. Cleanup and Update Progress
            if extract_dir and extract_dir.exists():
                # FIX: actually remove the directory before announcing cleanup
                # (the original printed "Cleaned up" first).
                shutil.rmtree(extract_dir, ignore_errors=True)
                print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")

            if current_file_success:
                # Update both local progress and the shared HF state
                progress['last_processed_index'] = file_index
                progress['processed_files'][str(file_index)] = repo_file_full_path
                save_progress(progress)

                if await unlock_file_as_processed(zip_full_name, current_state, file_index + 1):
                    print(f"[{FLOW_ID}] Progress saved and file unlocked: {zip_full_name}")
                else:
                    print(f"[{FLOW_ID}] Warning: File processed but state update failed: {zip_full_name}")
            else:
                # Mark as failed in the shared state and continue with next file
                current_state["file_states"][zip_full_name] = "failed"
                await upload_hf_state(current_state)
                print(f"[{FLOW_ID}] File {zip_full_name} marked as failed. Continuing with next file.")
                global_success = False

    print(f"[{FLOW_ID}] All processing loops complete. Overall success: {global_success}")
    return global_success
620
+
621
+ # --- FastAPI App and Endpoints ---
622
+
623
# Application object; endpoints below are registered against this instance.
app = FastAPI(
    title=f"Flow Server {FLOW_ID} API",
    description="Sequentially processes zip files from a dataset, captions images, and tracks progress.",
    version="1.0.0",
)
628
+
629
+ @app.on_event("startup")
630
+ async def startup_event():
631
+ print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.")
632
+
633
+ # Get both local progress and HF state
634
+ progress = load_progress()
635
+ current_state = await download_hf_state()
636
+
637
+ # Get the next_download_index from HF state if available
638
+ hf_next_index = current_state.get("next_download_index", 0)
639
+
640
+ # If HF state has a higher index, use that instead of local progress
641
+ if hf_next_index > 0:
642
+ start_index = hf_next_index
643
+ print(f"[{FLOW_ID}] Using next_download_index from HF state: {start_index}")
644
+ else:
645
+ # Fall back to local progress if HF state doesn't have a meaningful index
646
+ start_index = progress.get('last_processed_index', 0) + 1
647
+ if start_index < AUTO_START_INDEX:
648
+ start_index = AUTO_START_INDEX
649
+
650
+ # Use a dummy BackgroundTasks object for the startup task
651
+ # Note: FastAPI's startup events can't directly use BackgroundTasks, but we can use asyncio.create_task
652
+ # to run the long-running process in the background without blocking the server startup.
653
+ print(f"[{FLOW_ID}] Auto-starting processing from index: {start_index}...")
654
+ asyncio.create_task(process_dataset_task(start_index))
655
+
656
+ @app.get("/")
657
+ async def root():
658
+ progress = load_progress()
659
+ return {
660
+ "flow_id": FLOW_ID,
661
+ "status": "ready",
662
+ "last_processed_index": progress['last_processed_index'],
663
+ "total_files_in_list": len(progress['file_list']),
664
+ "processed_files_count": len(progress['processed_files']),
665
+ "total_servers": len(servers),
666
+ "busy_servers": sum(1 for s in servers if s.busy),
667
+ }
668
+
669
+ @app.post("/start_processing")
670
+ async def start_processing(request: ProcessStartRequest, background_tasks: BackgroundTasks):
671
+ """
672
+ Starts the sequential processing of zip files from the given index in the background.
673
+ """
674
+ start_index = request.start_index
675
+
676
+ print(f"[{FLOW_ID}] Received request to start processing from index: {start_index}. Starting background task.")
677
+
678
+ # Start the heavy processing in a background task so the API call returns immediately
679
+ # Note: The server is already auto-starting, but this allows for manual restart/override.
680
+ background_tasks.add_task(process_dataset_task, start_index)
681
+
682
+ return {"status": "processing", "start_index": start_index, "message": "Dataset processing started in background."}
683
+
684
+ if __name__ == "__main__":
685
+ import uvicorn
686
+ # Note: When running in the sandbox, we need to use 0.0.0.0 to expose the port.
687
  uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)