Fred808 committed on
Commit
708ee50
·
verified ·
1 Parent(s): f3af125

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +411 -360
app.py CHANGED
@@ -4,14 +4,14 @@ import time
4
  import asyncio
5
  import aiohttp
6
  import zipfile
 
7
  import shutil
8
  from typing import Dict, List, Set, Optional, Any
9
  from urllib.parse import quote
10
  from datetime import datetime
11
  from pathlib import Path
12
- import io
13
 
14
- from fastapi import FastAPI, BackgroundTasks, HTTPException, status, Request, Form
15
  from fastapi.responses import HTMLResponse
16
  from fastapi.templating import Jinja2Templates
17
  from pydantic import BaseModel, Field
@@ -19,21 +19,17 @@ from huggingface_hub import HfApi, hf_hub_download, HfFileSystem
19
  import uvicorn
20
 
21
  # --- Configuration ---
22
- # Flow Server ID and Port will be set via environment variables for easy deployment
23
  FLOW_ID = os.getenv("FLOW_ID", "flow_default")
24
- FLOW_PORT = int(os.getenv("FLOW_PORT", 8001)) # Default to 8001 for flow1
25
-
26
- # Manager Server Configuration
27
  MANAGER_URL = os.getenv("MANAGER_URL", "https://fred808-fcord.hf.space")
28
  MANAGER_COMPLETE_TASK_URL = f"{MANAGER_URL}/task/complete"
29
-
30
- # Hugging Face Configuration
31
- HF_TOKEN = os.getenv("HF_TOKEN", "") # User provided token
32
  HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")
33
- HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium") # Target dataset for captions
 
34
 
35
- # Using the full list from the user's original code
36
- INITIAL_CAPTION_SERVERS = [
37
  "https://fred808-pil-4-1.hf.space/analyze",
38
  "https://fred808-pil-4-2.hf.space/analyze",
39
  "https://fred808-pil-4-3.hf.space/analyze",
@@ -78,126 +74,156 @@ MODEL_TYPE = "Florence-2-large"
78
  TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
79
  TEMP_DIR.mkdir(exist_ok=True)
80
 
81
- # State persistence file name in the output dataset
82
- STATE_FILENAME = f"processing_state_{FLOW_ID}.json"
83
-
84
  # --- Models ---
85
  class ProcessCourseRequest(BaseModel):
86
  course_name: Optional[str] = None
87
- start_index: int = 0 # New field for configurable start index
88
 
89
- class CaptionServer(BaseModel):
90
- url: str
91
- busy: bool = False
92
- total_processed: int = 0
93
- total_time: float = 0.0
94
- model: str = MODEL_TYPE
 
95
 
96
  @property
97
  def fps(self):
98
  return self.total_processed / self.total_time if self.total_time > 0 else 0
99
 
100
- class ProcessingState(BaseModel):
101
- # processed_files is a Set in the Pydantic model but stored as a List in JSON
102
- processed_files: Set[str] = Field(default_factory=set)
103
- last_processed_course: Optional[str] = None
104
- last_processed_index: int = 0
105
- servers: List[CaptionServer] = Field(default_factory=list)
106
-
107
- # --- Global State ---
108
- # Global state object
109
- state = ProcessingState()
110
- # Lock for safely modifying the global state (especially servers list)
111
- state_lock = asyncio.Lock()
112
- # Templates for the UI
113
- templates = Jinja2Templates(directory="templates")
114
- # Index for round-robin selection
 
 
 
 
 
 
 
115
  server_index = 0
 
 
 
 
 
116
 
117
- # --- State Management Functions ---
 
 
 
 
 
 
118
 
119
  async def load_state_from_hf():
120
- """Downloads and loads the processing state from the output dataset."""
121
  global state
122
- print(f"[{FLOW_ID}] Attempting to load state from {STATE_FILENAME} in {HF_OUTPUT_DATASET_ID}...")
123
- try:
124
- fs = HfFileSystem(token=HF_TOKEN)
125
- if fs.exists(f"{HF_OUTPUT_DATASET_ID}/{STATE_FILENAME}"):
126
- with fs.open(f"{HF_OUTPUT_DATASET_ID}/{STATE_FILENAME}", "r") as f:
127
- data = json.load(f)
128
-
129
- # Convert list back to set for processed_files
130
- if 'processed_files' in data and isinstance(data['processed_files'], list):
131
- data['processed_files'] = set(data['processed_files'])
132
-
133
- # Ensure servers are loaded, falling back to initial list if not present
134
- if not data.get('servers'):
135
- data['servers'] = [CaptionServer(url=url).dict() for url in INITIAL_CAPTION_SERVERS]
136
-
137
- # Manually parse servers to Pydantic models to handle nested structure
138
- data['servers'] = [CaptionServer(**s) for s in data['servers']]
139
-
140
- state = ProcessingState(**data)
141
- print(f"[{FLOW_ID}] State loaded successfully. Processed files: {len(state.processed_files)}")
142
- return True
143
- else:
144
- print(f"[{FLOW_ID}] State file not found. Initializing with default servers.")
145
- state.servers = [CaptionServer(url=url) for url in INITIAL_CAPTION_SERVERS]
146
- return False
147
- except Exception as e:
148
- print(f"[{FLOW_ID}] Error loading state: {e}. Initializing with default servers.")
149
- state.servers = [CaptionServer(url=url) for url in INITIAL_CAPTION_SERVERS]
150
- return False
151
 
152
  async def save_state_to_hf():
153
- """Saves the current processing state to the output dataset."""
 
 
 
 
154
  async with state_lock:
155
- print(f"[{FLOW_ID}] Saving state to {STATE_FILENAME} in {HF_OUTPUT_DATASET_ID}...")
 
 
 
 
 
 
156
  try:
157
- # Prepare data for saving, converting sets/objects to serializable types
158
- data_to_save = state.dict()
159
- data_to_save['processed_files'] = list(state.processed_files) # Convert set to list for JSON
160
-
161
- json_content = json.dumps(data_to_save, indent=2, ensure_ascii=False).encode('utf-8')
162
-
163
- api = HfApi(token=HF_TOKEN)
164
  api.upload_file(
165
  path_or_fileobj=io.BytesIO(json_content),
166
- path_in_repo=STATE_FILENAME,
167
  repo_id=HF_OUTPUT_DATASET_ID,
168
  repo_type="dataset",
169
- commit_message=f"[{FLOW_ID}] Update processing state"
170
  )
171
  print(f"[{FLOW_ID}] State saved successfully.")
172
  return True
173
  except Exception as e:
174
- print(f"[{FLOW_ID}] Error saving state: {e}")
175
  return False
176
 
177
- # --- Core Processing Functions ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  async def get_available_server(timeout: float = 300.0) -> CaptionServer:
180
- """Round-robin selection of an available caption server from the global state."""
181
  global server_index
182
  start_time = time.time()
183
-
184
  while True:
185
- async with state_lock:
186
- # Check if there are any servers configured
187
- if not state.servers:
188
- raise RuntimeError("No caption servers are configured.")
189
-
190
- # Round-robin check for an available server
191
- for _ in range(len(state.servers)):
192
- server = state.servers[server_index % len(state.servers)]
193
- server_index = (server_index + 1) % len(state.servers)
194
- if not server.busy:
195
- return server
196
 
197
- # If all servers are busy, wait for a short period and check again
198
  await asyncio.sleep(0.5)
199
 
200
- # Check if timeout has been reached
201
  if time.time() - start_time > timeout:
202
  raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")
203
 
@@ -207,21 +233,13 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress
207
  for attempt in range(MAX_RETRIES):
208
  server = None
209
  try:
210
- # 1. Get an available server (will wait if all are busy, with a timeout)
211
  server = await get_available_server()
212
-
213
- async with state_lock:
214
- # Find the server in the global list and mark it busy
215
- server_in_state = next(s for s in state.servers if s.url == server.url)
216
- server_in_state.busy = True
217
-
218
  start_time = time.time()
219
 
220
- # Print a less verbose message only on the first attempt
221
  if attempt == 0:
222
  print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")
223
 
224
- # 2. Prepare request data
225
  form_data = aiohttp.FormData()
226
  form_data.add_field('file',
227
  image_path.open('rb'),
@@ -229,24 +247,21 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress
229
  content_type='image/jpeg')
230
  form_data.add_field('model_choice', MODEL_TYPE)
231
 
232
- # 3. Send request
233
  async with aiohttp.ClientSession() as session:
234
- # Increased timeout to 10 minutes (600s)
235
  async with session.post(server.url, data=form_data, timeout=600) as resp:
236
  if resp.status == 200:
237
  result = await resp.json()
238
  caption = result.get("caption")
239
 
240
  if caption:
241
- # Update progress counter
242
  progress_tracker['completed'] += 1
 
 
 
243
  if progress_tracker['completed'] % 50 == 0:
244
  print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")
245
 
246
- # Log success only if it's not a progress report interval
247
- if progress_tracker['completed'] % 50 != 0:
248
- print(f"[{FLOW_ID}] Success: {image_path.name} captioned by {server.url}")
249
-
250
  return {
251
  "course": course_name,
252
  "image_path": image_path.name,
@@ -255,51 +270,76 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress
255
  }
256
  else:
257
  print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
258
- continue # Retry with a different server
259
  else:
260
  error_text = await resp.text()
261
  print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
262
- continue # Retry with a different server
263
 
264
- except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, RuntimeError) as e:
265
- # RuntimeError is for "No caption servers are configured."
266
- print(f"[{FLOW_ID}] Connection/Timeout/Server error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
267
- continue # Retry with a different server
268
  except Exception as e:
269
  print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
270
- continue # Retry with a different server
271
  finally:
272
  if server:
273
  end_time = time.time()
274
- async with state_lock:
275
- # Find the server in the global list and update its stats
276
- try:
277
- server_in_state = next(s for s in state.servers if s.url == server.url)
278
- server_in_state.busy = False
279
- server_in_state.total_processed += 1
280
- server_in_state.total_time += (end_time - start_time)
281
- except StopIteration:
282
- # Server might have been removed while processing
283
- print(f"[{FLOW_ID}] Warning: Completed task on a server that was likely removed: {server.url}")
284
 
285
  print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
286
  return None
287
 
288
- async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
289
- """Uploads the final captions JSON file to the output dataset.
290
 
291
- The user requested the output JSON file to be named after the full zip file name.
292
- """
293
- # Use the full zip name, replacing the extension with .json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  caption_filename = Path(zip_full_name).with_suffix('.json').name
295
 
296
  try:
297
  print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")
298
 
299
- # Create JSON content in memory
300
  json_content = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')
301
 
302
- api = HfApi(token=HF_TOKEN)
303
  api.upload_file(
304
  path_or_fileobj=io.BytesIO(json_content),
305
  path_in_repo=caption_filename,
@@ -315,225 +355,147 @@ async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> boo
315
  print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
316
  return False
317
 
318
- async def download_and_extract_zip(course_name: str) -> Optional[tuple[Path, str, str]]:
319
- """Downloads the next unprocessed zip file for the course and extracts its contents."""
320
- print(f"[{FLOW_ID}] Looking for files starting with '{course_name}' in frames/ directory...")
321
 
322
- try:
323
- api = HfApi(token=HF_TOKEN)
324
-
325
- # List all files in the frames directory
326
- repo_files = api.list_repo_files(
327
- repo_id=HF_DATASET_ID,
328
- repo_type="dataset"
329
- )
330
 
331
- # Find zip files that start with the course name
332
- matching_files = [
333
- f for f in repo_files
334
- if f.startswith(f"frames/{course_name}") and f.endswith('.zip')
335
- ]
336
 
337
- if not matching_files:
338
- print(f"[{FLOW_ID}] No zip files found starting with '{course_name}' in frames/ directory.")
339
- return None, None, None
340
-
341
  async with state_lock:
342
- # Filter out already processed files using the global state
343
- unprocessed_files = [f for f in matching_files if f not in state.processed_files]
344
-
345
- if not unprocessed_files:
346
- print(f"[{FLOW_ID}] No new zip files found for '{course_name}'.")
347
- return None, None, None
 
 
 
348
 
349
- repo_file_full_path = unprocessed_files[0] # e.g., frames/DAREEFSA_full_name.zip
350
-
351
- # Extract the full file name from the path (e.g., DAREEFSA_full_name.zip)
352
- zip_full_name = Path(repo_file_full_path).name
353
- print(f"[{FLOW_ID}] Found new matching file: {repo_file_full_path}. Full name: {zip_full_name}")
354
-
355
- # Use hf_hub_download to get the file path
356
- zip_path = hf_hub_download(
357
- repo_id=HF_DATASET_ID,
358
- filename=repo_file_full_path, # Use the full path in the repo
359
- repo_type="dataset",
360
- token=HF_TOKEN,
361
- )
362
-
363
- print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")
364
-
365
- # Create a temporary directory for extraction
366
- extract_dir = TEMP_DIR / course_name
367
- extract_dir.mkdir(exist_ok=True)
368
-
369
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
370
- zip_ref.extractall(extract_dir)
371
 
372
- print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")
373
-
374
- # Return the extraction directory, the full zip file name, and the repo path
375
- return extract_dir, zip_full_name, repo_file_full_path
376
-
377
- except Exception as e:
378
- print(f"[{FLOW_ID}] Error downloading or extracting zip for {course_name}: {e}")
379
- return None, None, None
380
-
381
- async def process_course_task(course_name: str, start_index: int = 0):
382
- """Main task to process a single course, looping until all files are processed."""
383
- print(f"[{FLOW_ID}] Starting continuous processing for course: {course_name} with start index {start_index}")
384
-
385
- global_success = True
386
-
387
- # Update state before starting the loop
388
- async with state_lock:
389
- state.last_processed_course = course_name
390
- state.last_processed_index = start_index
391
- await save_state_to_hf()
392
-
393
- # Loop to continuously check for new files matching the course_name prefix
394
- while True:
395
  extract_dir = None
396
  zip_full_name = None
397
- repo_file_full_path = None
398
 
399
  try:
400
- # download_and_extract_zip now uses global state to check for processed files
401
- download_result = await download_and_extract_zip(course_name)
402
 
403
- if download_result is None or download_result[0] is None:
404
- # No new files found, or an error occurred during search/download
405
- print(f"[{FLOW_ID}] No new files found for {course_name}. Exiting loop.")
406
- break
407
 
408
- extract_dir, zip_full_name, repo_file_full_path = download_result
409
-
410
- # --- Start Processing the single file ---
411
 
412
- # FIX: Use recursive glob to find images in subdirectories
413
  image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
414
-
415
- # Apply start_index logic
416
- if start_index > 0:
417
- original_count = len(image_paths)
418
- image_paths = image_paths[start_index:]
419
- print(f"[{FLOW_ID}] Applying start index {start_index}. Processing {len(image_paths)} images from {original_count} in {zip_full_name}.")
420
- # Reset start_index for subsequent zip files
421
- start_index = 0
422
- else:
423
- print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")
424
-
425
- current_file_success = False
426
 
427
  if not image_paths:
428
- print(f"[{FLOW_ID}] No images to process after applying start index in {zip_full_name}. Marking as complete.")
429
- current_file_success = True
430
  else:
431
  # Initialize progress tracker
432
  progress_tracker = {
433
  'total': len(image_paths),
434
  'completed': 0
435
  }
436
- print(f"[{FLOW_ID}] Starting captioning for {progress_tracker['total']} images in {zip_full_name}...")
437
-
438
- # Create a semaphore to limit concurrent tasks to the number of available servers
439
  async with state_lock:
440
- # Use the current number of servers from the global state
441
- semaphore = asyncio.Semaphore(len(state.servers) if state.servers else 1)
442
-
 
 
443
  async def limited_send_image_for_captioning(image_path, course_name, progress_tracker):
444
  async with semaphore:
445
  return await send_image_for_captioning(image_path, course_name, progress_tracker)
446
 
447
- # Create a list of tasks for parallel captioning
448
  caption_tasks = [limited_send_image_for_captioning(p, course_name, progress_tracker) for p in image_paths]
449
-
450
- # Run all captioning tasks concurrently
451
  results = await asyncio.gather(*caption_tasks)
452
-
453
- # Filter out failed results
454
  all_captions = [r for r in results if r is not None]
455
 
456
- # Final progress report for the current file
457
  if len(all_captions) == len(image_paths):
458
  print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully completed all {len(all_captions)} captions.")
459
- current_file_success = True
460
  else:
461
  print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} captions.")
462
- current_file_success = False
463
 
464
  # Upload results
465
  if all_captions and zip_full_name:
466
- # Use the full zip file name for the upload as requested
467
- print(f"[{FLOW_ID}] Uploading {len(all_captions)} captions for {zip_full_name}...")
468
  if await upload_captions_to_hf(zip_full_name, all_captions):
469
  print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
470
- # If partial success, we still upload, but the overall task is marked as failure if any file failed
471
- if not current_file_success:
472
- global_success = False
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  else:
474
- print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}.")
475
- current_file_success = False
476
- global_success = False
 
 
 
 
 
477
  else:
478
- print(f"[{FLOW_ID}] No captions generated or zip_full_name is missing. Skipping upload for {zip_full_name}.")
479
- current_file_success = False
480
- global_success = False
481
-
482
- # --- End Processing the single file ---
483
-
484
- # Mark the file as processed and save state
485
- if current_file_success:
486
- async with state_lock:
487
- state.processed_files.add(repo_file_full_path)
488
- await save_state_to_hf()
489
 
490
  except Exception as e:
491
  error_message = str(e)
492
- print(f"[{FLOW_ID}] Critical error in process_course_task for {course_name}: {error_message}")
493
- global_success = False
 
 
 
 
494
 
495
  finally:
496
- # Cleanup temporary files for the current file
497
  if extract_dir and extract_dir.exists():
498
  print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")
499
  shutil.rmtree(extract_dir, ignore_errors=True)
500
-
501
- # If an unrecoverable error occurred (e.g., during search/download), break the loop
502
- if download_result is None and extract_dir is None:
503
- break
504
-
505
- # --- Final Report after the loop is complete ---
506
- print(f"[{FLOW_ID}] All processing loops complete for {course_name}.")
507
-
508
- # Report completion to manager
509
- final_error_message = error_message if not global_success else None
510
- await report_completion(course_name, global_success, final_error_message)
511
-
512
- return global_success
513
-
514
- async def report_completion(course_name: str, success: bool, error_message: Optional[str] = None):
515
- """Reports the task result back to the Manager Server."""
516
- print(f"[{FLOW_ID}] Reporting completion for {course_name} (Success: {success})...")
517
-
518
- payload = {
519
- "flow_id": FLOW_ID,
520
- "course_name": course_name,
521
- "success": success,
522
- "error_message": error_message
523
- }
524
-
525
- try:
526
- async with aiohttp.ClientSession() as session:
527
- async with session.post(MANAGER_COMPLETE_TASK_URL, json=payload) as resp:
528
- if resp.status != 200:
529
- print(f"[{FLOW_ID}] ERROR: Manager reported non-200 status: {resp.status} - {await resp.text()}")
530
- else:
531
- print(f"[{FLOW_ID}] Successfully reported completion to Manager.")
532
-
533
- except aiohttp.ClientError as e:
534
- print(f"[{FLOW_ID}] CRITICAL ERROR: Could not connect to Manager at {MANAGER_COMPLETE_TASK_URL}. Task completion not reported. Error: {e}")
535
- except Exception as e:
536
- print(f"[{FLOW_ID}] Unexpected error during reporting: {e}")
537
 
538
  # --- FastAPI App and Endpoints ---
539
 
@@ -543,78 +505,167 @@ app = FastAPI(
543
  version="2.0.0"
544
  )
545
 
 
 
 
546
  @app.on_event("startup")
547
  async def startup_event():
548
- print(f"Flow Server {FLOW_ID} starting up...")
549
- # Load state first before starting the server
550
- await load_state_from_hf()
551
  print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}. Manager URL: {MANAGER_URL}")
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
  @app.get("/", response_class=HTMLResponse)
554
- async def root(request: Request):
555
- """The main UI dashboard."""
556
  async with state_lock:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  context = {
558
  "request": request,
559
  "flow_id": FLOW_ID,
560
- "status": "ready" if not any(s.busy for s in state.servers) else "processing",
561
- "manager_url": MANAGER_URL,
562
- "servers": state.servers,
563
- "total_servers": len(state.servers),
564
- "busy_servers": sum(1 for s in state.servers if s.busy),
565
- "processed_files_count": len(state.processed_files),
566
- "last_course": state.last_processed_course,
567
- "last_index": state.last_processed_index,
 
 
 
568
  }
569
- return templates.TemplateResponse("dashboard.html", context)
570
 
571
- @app.post("/add_server")
572
- async def add_server_endpoint(server_url: str = Form(...)):
573
- """API endpoint to add a new caption server."""
574
- if not server_url.endswith("/analyze"):
575
- server_url = server_url.rstrip("/") + "/analyze"
 
 
 
 
 
576
 
577
  async with state_lock:
578
- # Check if server already exists
579
- if any(s.url == server_url for s in state.servers):
580
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Server already exists.")
581
-
582
- new_server = CaptionServer(url=server_url)
583
- state.servers.append(new_server)
584
-
585
- await save_state_to_hf()
586
- return {"status": "success", "message": f"Server {server_url} added.", "server": new_server.dict()}
 
 
 
 
 
 
 
 
 
 
 
587
 
588
- @app.post("/remove_server")
589
- async def remove_server_endpoint(server_url: str = Form(...)):
590
- """API endpoint to remove a caption server."""
591
- async with state_lock:
592
- initial_count = len(state.servers)
593
- state.servers = [s for s in state.servers if s.url != server_url]
594
- if len(state.servers) == initial_count:
595
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Server not found.")
596
 
597
- await save_state_to_hf()
598
- return {"status": "success", "message": f"Server {server_url} removed."}
599
-
600
- @app.post("/process_course")
601
- async def process_course_api(request: ProcessCourseRequest, background_tasks: BackgroundTasks):
602
- """
603
- Receives a course name and optional start index and starts processing in the background.
604
- """
605
- course_name = request.course_name
606
- start_index = request.start_index
607
 
608
- if not course_name:
609
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Course name is required.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
 
611
- print(f"[{FLOW_ID}] Received course: {course_name} with start index {start_index}. Starting background task.")
612
-
613
- # Start the heavy processing in a background task so the API call returns immediately
614
- background_tasks.add_task(process_course_task, course_name, start_index)
615
-
616
- return {"status": "processing", "course_name": course_name, "start_index": start_index, "message": "Processing started in background."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
 
618
  if __name__ == "__main__":
619
- # Note: When running in the sandbox, we need to use 0.0.0.0 to expose the port.
620
  uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
 
4
  import asyncio
5
  import aiohttp
6
  import zipfile
7
+ import io
8
  import shutil
9
  from typing import Dict, List, Set, Optional, Any
10
  from urllib.parse import quote
11
  from datetime import datetime
12
  from pathlib import Path
 
13
 
14
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, status, Request
15
  from fastapi.responses import HTMLResponse
16
  from fastapi.templating import Jinja2Templates
17
  from pydantic import BaseModel, Field
 
19
  import uvicorn
20
 
21
  # --- Configuration ---
 
22
  FLOW_ID = os.getenv("FLOW_ID", "flow_default")
23
+ FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
 
 
24
  MANAGER_URL = os.getenv("MANAGER_URL", "https://fred808-fcord.hf.space")
25
  MANAGER_COMPLETE_TASK_URL = f"{MANAGER_URL}/task/complete"
26
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
 
 
27
  HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")
28
+ HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium")
29
+ STATE_FILE_NAME = f"{FLOW_ID}_state.json"
30
 
31
+ # Using the full list from the user's original code for actual deployment
32
+ CAPTION_SERVERS = [
33
  "https://fred808-pil-4-1.hf.space/analyze",
34
  "https://fred808-pil-4-2.hf.space/analyze",
35
  "https://fred808-pil-4-3.hf.space/analyze",
 
74
  TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
75
  TEMP_DIR.mkdir(exist_ok=True)
76
 
 
 
 
77
  # --- Models ---
78
  class ProcessCourseRequest(BaseModel):
79
  course_name: Optional[str] = None
 
80
 
81
+ class CaptionServer:
82
+ def __init__(self, url):
83
+ self.url = url
84
+ self.busy = False
85
+ self.total_processed = 0
86
+ self.total_time = 0
87
+ self.model = MODEL_TYPE
88
 
89
  @property
90
  def fps(self):
91
  return self.total_processed / self.total_time if self.total_time > 0 else 0
92
 
93
+ class ServerState(BaseModel):
94
+ # The list of all zip files in the dataset (frames/ directory)
95
+ all_zip_files: List[str] = Field(default_factory=list)
96
+ # The set of zip files that have been successfully processed and uploaded
97
+ processed_files: Set[str] = Field(default_factory=set)
98
+ # The index in all_zip_files from which the next download should start
99
+ current_index: int = 0
100
+ # Total number of files to process
101
+ total_files: int = 0
102
+ # Status of the current operation
103
+ status: str = "Idle"
104
+ # Name of the file currently being processed
105
+ current_file: Optional[str] = None
106
+ # Progress within the current file
107
+ current_file_progress: str = "0/0"
108
+ # Timestamp of the last update
109
+ last_update: str = datetime.now().isoformat()
110
+ # Flag to control the processing loop
111
+ is_running: bool = False
112
+
113
+ # Global state for caption servers and the overall server state
114
+ servers = [CaptionServer(url) for url in CAPTION_SERVERS]
115
  server_index = 0
116
+ state = ServerState()
117
+ # Lock for thread-safe access to the global state
118
+ state_lock = asyncio.Lock()
119
+
120
+ # --- Persistence Functions ---
121
 
122
+ def get_hf_api():
123
+ """Helper to get HfApi instance."""
124
+ return HfApi(token=HF_TOKEN)
125
+
126
+ def get_hf_fs():
127
+ """Helper to get HfFileSystem instance."""
128
+ return HfFileSystem(token=HF_TOKEN)
129
 
130
  async def load_state_from_hf():
131
+ """Loads the state from the Hugging Face output dataset."""
132
  global state
133
+ fs = get_hf_fs()
134
+ state_path = f"{HF_OUTPUT_DATASET_ID}/{STATE_FILE_NAME}"
135
+
136
+ async with state_lock:
137
+ try:
138
+ if fs.exists(state_path):
139
+ print(f"[{FLOW_ID}] Loading state from {state_path}...")
140
+ with fs.open(state_path, 'rb') as f:
141
+ data = json.load(f)
142
+ # Convert list of processed files back to a set
143
+ if 'processed_files' in data and isinstance(data['processed_files'], list):
144
+ data['processed_files'] = set(data['processed_files'])
145
+ state = ServerState.parse_obj(data)
146
+ print(f"[{FLOW_ID}] State loaded successfully. Current index: {state.current_index}")
147
+ else:
148
+ print(f"[{FLOW_ID}] State file {state_path} not found. Starting with default state.")
149
+ except Exception as e:
150
+ print(f"[{FLOW_ID}] Error loading state from HF: {e}. Starting with default state.")
151
+ state = ServerState()
 
 
 
 
 
 
 
 
 
 
152
 
153
  async def save_state_to_hf():
154
+ """Saves the current state to the Hugging Face output dataset."""
155
+ global state
156
+ api = get_hf_api()
157
+ state_path = STATE_FILE_NAME
158
+
159
  async with state_lock:
160
+ state.last_update = datetime.now().isoformat()
161
+ # Convert set of processed files to a list for JSON serialization
162
+ data_to_save = state.dict()
163
+ data_to_save['processed_files'] = list(state.processed_files)
164
+
165
+ json_content = json.dumps(data_to_save, indent=2, ensure_ascii=False).encode('utf-8')
166
+
167
  try:
168
+ print(f"[{FLOW_ID}] Saving state to {state_path} in {HF_OUTPUT_DATASET_ID}...")
 
 
 
 
 
 
169
  api.upload_file(
170
  path_or_fileobj=io.BytesIO(json_content),
171
+ path_in_repo=state_path,
172
  repo_id=HF_OUTPUT_DATASET_ID,
173
  repo_type="dataset",
174
+ commit_message=f"[{FLOW_ID}] Update server state. Index: {state.current_index}"
175
  )
176
  print(f"[{FLOW_ID}] State saved successfully.")
177
  return True
178
  except Exception as e:
179
+ print(f"[{FLOW_ID}] Error saving state to HF: {e}")
180
  return False
181
 
182
+ async def update_file_list():
183
+ """Fetches the list of all zip files from the BG3 dataset."""
184
+ global state
185
+ api = get_hf_api()
186
+
187
+ async with state_lock:
188
+ try:
189
+ state.status = "Updating file list..."
190
+ print(f"[{FLOW_ID}] Fetching file list from {HF_DATASET_ID}...")
191
+ repo_files = api.list_repo_files(
192
+ repo_id=HF_DATASET_ID,
193
+ repo_type="dataset"
194
+ )
195
+
196
+ # Filter for zip files in the 'frames/' directory
197
+ zip_files = sorted([
198
+ f for f in repo_files
199
+ if f.startswith("frames/") and f.endswith('.zip')
200
+ ])
201
+
202
+ state.all_zip_files = zip_files
203
+ state.total_files = len(zip_files)
204
+ state.status = "File list updated."
205
+ print(f"[{FLOW_ID}] Found {state.total_files} zip files.")
206
+ except Exception as e:
207
+ state.status = f"Error updating file list: {e}"
208
+ print(f"[{FLOW_ID}] Error updating file list: {e}")
209
+
210
+ await save_state_to_hf()
211
+
212
+ # --- Core Processing Functions (Modified) ---
213
 
214
  async def get_available_server(timeout: float = 300.0) -> CaptionServer:
215
+ """Round-robin selection of an available caption server."""
216
  global server_index
217
  start_time = time.time()
 
218
  while True:
219
+ for _ in range(len(servers)):
220
+ server = servers[server_index]
221
+ server_index = (server_index + 1) % len(servers)
222
+ if not server.busy:
223
+ return server
 
 
 
 
 
 
224
 
 
225
  await asyncio.sleep(0.5)
226
 
 
227
  if time.time() - start_time > timeout:
228
  raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")
229
 
 
233
  for attempt in range(MAX_RETRIES):
234
  server = None
235
  try:
 
236
  server = await get_available_server()
237
+ server.busy = True
 
 
 
 
 
238
  start_time = time.time()
239
 
 
240
  if attempt == 0:
241
  print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")
242
 
 
243
  form_data = aiohttp.FormData()
244
  form_data.add_field('file',
245
  image_path.open('rb'),
 
247
  content_type='image/jpeg')
248
  form_data.add_field('model_choice', MODEL_TYPE)
249
 
 
250
  async with aiohttp.ClientSession() as session:
 
251
  async with session.post(server.url, data=form_data, timeout=600) as resp:
252
  if resp.status == 200:
253
  result = await resp.json()
254
  caption = result.get("caption")
255
 
256
  if caption:
257
+ # Update progress counter and global state
258
  progress_tracker['completed'] += 1
259
+ async with state_lock:
260
+ state.current_file_progress = f"{progress_tracker['completed']}/{progress_tracker['total']}"
261
+
262
  if progress_tracker['completed'] % 50 == 0:
263
  print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")
264
 
 
 
 
 
265
  return {
266
  "course": course_name,
267
  "image_path": image_path.name,
 
270
  }
271
  else:
272
  print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
273
+ continue
274
  else:
275
  error_text = await resp.text()
276
  print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
277
+ continue
278
 
279
+ except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
280
+ print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
281
+ continue
 
282
  except Exception as e:
283
  print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
284
+ continue
285
  finally:
286
  if server:
287
  end_time = time.time()
288
+ server.busy = False
289
+ server.total_processed += 1
290
+ server.total_time += (end_time - start_time)
 
 
 
 
 
 
 
291
 
292
  print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
293
  return None
294
 
295
async def download_and_extract_zip(repo_file_full_path: str) -> Optional[tuple[Path, str]]:
    """Downloads the zip file at the given path and extracts its contents.

    Args:
        repo_file_full_path: Path of the zip inside the HF dataset repo.

    Returns:
        (extraction_directory, zip_file_name) on success, None on any
        failure. The downloaded archive is removed after extraction.
    """
    zip_full_name = Path(repo_file_full_path).name
    # Course name is assumed to be the prefix before the first underscore.
    course_name = zip_full_name.split('_')[0]

    try:
        print(f"[{FLOW_ID}] Downloading file: {repo_file_full_path}. Full name: {zip_full_name}")

        # hf_hub_download blocks on network I/O; run it in a worker thread so
        # the event loop keeps serving other coroutines (status endpoints,
        # caption progress) during large downloads.
        zip_path = await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_DATASET_ID,
            filename=repo_file_full_path,  # Full path inside the repo
            repo_type="dataset",
            token=HF_TOKEN,
        )

        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")

        # Create a temporary directory for extraction
        extract_dir = TEMP_DIR / course_name / zip_full_name.replace('.', '_')
        extract_dir.mkdir(parents=True, exist_ok=True)

        def _extract() -> None:
            # Disk/CPU-bound unzip; executed in a worker thread.
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

        await asyncio.to_thread(_extract)

        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")

        # Clean up the downloaded zip file to free disk space.
        os.remove(zip_path)

        # Return the extraction directory and the full zip file name
        return extract_dir, zip_full_name

    except Exception as e:
        print(f"[{FLOW_ID}] Error downloading or extracting zip for {repo_file_full_path}: {e}")
        return None
332
+
333
+ async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
334
+ """Uploads the final captions JSON file to the output dataset."""
335
  caption_filename = Path(zip_full_name).with_suffix('.json').name
336
 
337
  try:
338
  print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")
339
 
 
340
  json_content = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')
341
 
342
+ api = get_hf_api()
343
  api.upload_file(
344
  path_or_fileobj=io.BytesIO(json_content),
345
  path_in_repo=caption_filename,
 
355
  print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
356
  return False
357
 
358
async def process_next_file_task():
    """Main worker loop: process files sequentially from the shared index.

    Repeatedly takes the zip at ``state.current_index``, downloads and
    extracts it, captions every image inside, uploads the resulting JSON,
    and advances the index. Runs until ``state.is_running`` is cleared or
    the end of ``state.all_zip_files`` is reached. Failures leave the index
    unchanged so the same file is retried after a back-off.
    """
    global state

    if not state.is_running:
        print(f"[{FLOW_ID}] Processing loop is not running. Exiting task.")
        return

    while state.is_running:
        repo_file_full_path = None
        current_index = -1

        async with state_lock:
            # NOTE(review): save_state_to_hf() is awaited while state_lock is
            # held in this function — that is only safe if save_state_to_hf
            # never acquires state_lock itself; confirm against its definition.
            current_index = state.current_index
            if current_index >= state.total_files:
                # End of queue: stop the loop and persist the final state.
                state.status = "Finished processing all files."
                state.is_running = False
                print(f"[{FLOW_ID}] Reached end of file list. Stopping processing.")
                await save_state_to_hf()
                break

            repo_file_full_path = state.all_zip_files[current_index]

            if repo_file_full_path in state.processed_files:
                # Already done (e.g. after a restart): advance past it.
                state.current_index += 1
                state.status = f"Skipping processed file: {Path(repo_file_full_path).name}"
                state.current_file = Path(repo_file_full_path).name
                print(f"[{FLOW_ID}] Skipping already processed file: {repo_file_full_path}")
                await save_state_to_hf()
                continue

            # Mark the file as in-progress in the state
            state.status = f"Processing file {current_index + 1}/{state.total_files}"
            state.current_file = Path(repo_file_full_path).name
            state.current_file_progress = "0/0"
            await save_state_to_hf()

        # --- Start Processing ---
        extract_dir = None
        zip_full_name = None
        global_success = False

        try:
            download_result = await download_and_extract_zip(repo_file_full_path)

            if download_result is None:
                # Raising routes the failure through the except handler below,
                # which backs off 60s and retries the same index.
                raise Exception("Failed to download or extract zip file.")

            extract_dir, zip_full_name = download_result
            course_name = zip_full_name.split('_')[0]

            # Find images
            image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

            if not image_paths:
                # NOTE(review): an empty archive sets global_success, but the
                # index/processed_files advance below only happens on the
                # upload-success path — verify that empty zips cannot loop
                # forever on the same index.
                print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
                global_success = True
            else:
                # Initialize progress tracker (shared mutable dict that the
                # captioning coroutines increment for progress reporting).
                progress_tracker = {
                    'total': len(image_paths),
                    'completed': 0
                }

                async with state_lock:
                    state.current_file_progress = f"0/{len(image_paths)}"
                    await save_state_to_hf()

                # Create and run captioning tasks; the semaphore caps
                # concurrency at one in-flight request per caption server.
                semaphore = asyncio.Semaphore(len(servers))
                async def limited_send_image_for_captioning(image_path, course_name, progress_tracker):
                    async with semaphore:
                        return await send_image_for_captioning(image_path, course_name, progress_tracker)

                caption_tasks = [limited_send_image_for_captioning(p, course_name, progress_tracker) for p in image_paths]
                results = await asyncio.gather(*caption_tasks)
                # Drop failed captions (None) before upload.
                all_captions = [r for r in results if r is not None]

                # Final progress report
                if len(all_captions) == len(image_paths):
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully completed all {len(all_captions)} captions.")
                    global_success = True
                else:
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} captions.")
                    global_success = False

                # Upload results
                if all_captions and zip_full_name:
                    if await upload_captions_to_hf(zip_full_name, all_captions):
                        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
                        # If upload is successful, we mark the file as processed, regardless of partial success
                        # The uploaded JSON will reflect the actual number of captions
                        if global_success:
                            print(f"[{FLOW_ID}] Fully processed and uploaded: {zip_full_name}")
                        else:
                            print(f"[{FLOW_ID}] Partially processed but uploaded: {zip_full_name}. Needs manual review.")

                        # Mark as processed only if upload succeeded
                        async with state_lock:
                            state.processed_files.add(repo_file_full_path)
                            state.current_index += 1
                            state.current_file = None
                            state.current_file_progress = "0/0"
                            state.status = "Idle"
                            await save_state_to_hf()

                    else:
                        print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}. Will retry this file later.")
                        # Do NOT increment index or mark as processed, so it will be retried
                        async with state_lock:
                            state.status = f"Error uploading captions for {zip_full_name}. Retrying later."
                            await save_state_to_hf()
                        # Wait before retrying to avoid immediate re-attempt on a transient error
                        await asyncio.sleep(60)

                else:
                    print(f"[{FLOW_ID}] No captions generated or zip_full_name is missing. Skipping upload for {zip_full_name}. Will retry later.")
                    # Do NOT increment index or mark as processed
                    async with state_lock:
                        state.status = f"No captions generated for {zip_full_name}. Retrying later."
                        await save_state_to_hf()
                    await asyncio.sleep(60)

        except Exception as e:
            error_message = str(e)
            print(f"[{FLOW_ID}] Critical error in process_next_file_task for {repo_file_full_path}: {error_message}")
            async with state_lock:
                state.status = f"CRITICAL ERROR for {Path(repo_file_full_path).name}. Retrying later. Error: {error_message[:50]}..."
                await save_state_to_hf()
            # Wait before retrying
            await asyncio.sleep(60)

        finally:
            # Cleanup temporary files
            if extract_dir and extract_dir.exists():
                print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")
                shutil.rmtree(extract_dir, ignore_errors=True)

            # If the loop is still running, wait a short time before checking for the next file
            if state.is_running:
                await asyncio.sleep(5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
  # --- FastAPI App and Endpoints ---
501
 
 
505
  version="2.0.0"
506
  )
507
 
508
# Setup Jinja2 templates for the UI (looks for templates under ./templates,
# e.g. index.html rendered by the home endpoint).
templates = Jinja2Templates(directory="templates")
510
+
511
  @app.on_event("startup")
512
  async def startup_event():
 
 
 
513
  print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}. Manager URL: {MANAGER_URL}")
514
+ # 1. Load state from persistence (HF)
515
+ await load_state_from_hf()
516
+ # 2. Update the list of all files from the dataset
517
+ await update_file_list()
518
+ # 3. Start the continuous processing task if the index is valid
519
+ if state.current_index < state.total_files:
520
+ state.is_running = True
521
+ BackgroundTasks().add_task(process_next_file_task)
522
+ else:
523
+ state.is_running = False
524
+ print(f"[{FLOW_ID}] Index {state.current_index} is out of bounds. Starting in Idle mode.")
525
+
526
 
527
  @app.get("/", response_class=HTMLResponse)
528
+ async def home(request: Request):
529
+ """Home page with status and controls."""
530
  async with state_lock:
531
+ processed_count = len(state.processed_files)
532
+ remaining_count = state.total_files - processed_count
533
+
534
+ # Calculate server stats
535
+ server_stats = [
536
+ {
537
+ "url": s.url,
538
+ "busy": s.busy,
539
+ "processed": s.total_processed,
540
+ "fps": f"{s.fps:.2f}"
541
+ } for s in servers
542
+ ]
543
+
544
+ # Calculate overall FPS
545
+ total_processed = sum(s.total_processed for s in servers)
546
+ total_time = sum(s.total_time for s in servers)
547
+ overall_fps = total_processed / total_time if total_time > 0 else 0
548
+
549
  context = {
550
  "request": request,
551
  "flow_id": FLOW_ID,
552
+ "status": state.status,
553
+ "is_running": state.is_running,
554
+ "total_files": state.total_files,
555
+ "processed_count": processed_count,
556
+ "remaining_count": remaining_count,
557
+ "current_index": state.current_index,
558
+ "current_file": state.current_file if state.current_file else "N/A",
559
+ "current_file_progress": state.current_file_progress,
560
+ "last_update": state.last_update,
561
+ "overall_fps": f"{overall_fps:.2f}",
562
+ "server_stats": server_stats
563
  }
564
+ return templates.TemplateResponse("index.html", context)
565
 
566
+ @app.post("/set_index")
567
+ async def set_index(request: Request, background_tasks: BackgroundTasks):
568
+ """Endpoint to manually set the start index."""
569
+ global state
570
+
571
+ form = await request.form()
572
+ try:
573
+ new_index = int(form.get("start_index"))
574
+ except (TypeError, ValueError):
575
+ raise HTTPException(status_code=400, detail="Invalid index value.")
576
 
577
  async with state_lock:
578
+ if 0 <= new_index < state.total_files:
579
+ state.current_index = new_index
580
+ state.status = f"Index set to {new_index}. Restarting processing."
581
+
582
+ # If the loop is not running, start it
583
+ if not state.is_running:
584
+ state.is_running = True
585
+ background_tasks.add_task(process_next_file_task)
586
+
587
+ await save_state_to_hf()
588
+ print(f"[{FLOW_ID}] Index manually set to {new_index}.")
589
+ return {"status": "success", "message": f"Start index set to {new_index}. Processing will resume from this point."}
590
+ elif new_index == state.total_files:
591
+ state.current_index = new_index
592
+ state.is_running = False
593
+ state.status = "Finished processing all files."
594
+ await save_state_to_hf()
595
+ return {"status": "success", "message": "Index set to end of list. Processing stopped."}
596
+ else:
597
+ raise HTTPException(status_code=400, detail=f"Index {new_index} is out of bounds (0 to {state.total_files}).")
598
 
599
+ @app.post("/control_processing")
600
+ async def control_processing(request: Request, background_tasks: BackgroundTasks):
601
+ """Endpoint to start/stop the processing loop."""
602
+ global state
 
 
 
 
603
 
604
+ form = await request.form()
605
+ action = form.get("action")
 
 
 
 
 
 
 
 
606
 
607
+ async with state_lock:
608
+ if action == "start":
609
+ if not state.is_running and state.current_index < state.total_files:
610
+ state.is_running = True
611
+ state.status = "Processing started."
612
+ background_tasks.add_task(process_next_file_task)
613
+ await save_state_to_hf()
614
+ return {"status": "success", "message": "Processing loop started."}
615
+ elif state.current_index >= state.total_files:
616
+ return {"status": "error", "message": "Cannot start. All files have been processed."}
617
+ else:
618
+ return {"status": "info", "message": "Processing is already running."}
619
+ elif action == "stop":
620
+ if state.is_running:
621
+ state.is_running = False
622
+ state.status = "Processing stopped by user."
623
+ await save_state_to_hf()
624
+ return {"status": "success", "message": "Processing loop stopped."}
625
+ else:
626
+ return {"status": "info", "message": "Processing is already stopped."}
627
+ else:
628
+ raise HTTPException(status_code=400, detail="Invalid action.")
629
+
630
+ @app.get("/status")
631
+ async def get_status():
632
+ """API endpoint to get the current server status as JSON."""
633
+ async with state_lock:
634
+ processed_count = len(state.processed_files)
635
 
636
+ server_stats = [
637
+ {
638
+ "url": s.url,
639
+ "busy": s.busy,
640
+ "processed": s.total_processed,
641
+ "fps": f"{s.fps:.2f}"
642
+ } for s in servers
643
+ ]
644
+
645
+ total_processed = sum(s.total_processed for s in servers)
646
+ total_time = sum(s.total_time for s in servers)
647
+ overall_fps = total_processed / total_time if total_time > 0 else 0
648
+
649
+ return {
650
+ "flow_id": FLOW_ID,
651
+ "status": state.status,
652
+ "is_running": state.is_running,
653
+ "total_files": state.total_files,
654
+ "processed_count": processed_count,
655
+ "remaining_count": state.total_files - processed_count,
656
+ "current_index": state.current_index,
657
+ "current_file": state.current_file,
658
+ "current_file_progress": state.current_file_progress,
659
+ "last_update": state.last_update,
660
+ "overall_fps": f"{overall_fps:.2f}",
661
+ "server_stats": server_stats
662
+ }
663
+
664
+ # The original /process_course endpoint is now obsolete as the server manages its own queue
665
+ # @app.post("/process_course")
666
+ # async def process_course(request: ProcessCourseRequest, background_tasks: BackgroundTasks):
667
+ # return {"status": "obsolete", "message": "The server now manages its own processing queue based on the index."}
668
+
669
 
670
  if __name__ == "__main__":
 
671
  uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)