Samfredoly committed on
Commit
99dce0a
·
verified ·
1 Parent(s): 890b86d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +352 -539
app.py CHANGED
@@ -5,10 +5,9 @@ import asyncio
5
  import aiohttp
6
  import zipfile
7
  import shutil
8
- import threading
9
  from typing import Dict, List, Set, Optional, Tuple, Any
10
  from urllib.parse import quote
11
- from datetime import datetime, timedelta
12
  from pathlib import Path
13
  import io
14
 
@@ -17,12 +16,12 @@ from pydantic import BaseModel, Field
17
  from huggingface_hub import HfApi, hf_hub_download
18
 
19
  # --- Configuration ---
20
- AUTO_START_INDEX = 1290
21
  FLOW_ID = os.getenv("FLOW_ID", "flow_default")
22
  FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
23
  HF_TOKEN = os.getenv("HF_TOKEN", "")
24
  HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_Vid")
25
- HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/AVTF")
26
 
27
  # Progress and State Tracking
28
  PROGRESS_FILE = Path("processing_progress.json")
@@ -30,10 +29,12 @@ HF_STATE_FILE = "processing_state_transcriptions.json"
30
  LOCAL_STATE_FOLDER = Path(".state")
31
  LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
32
 
33
- AUDIO_FILE_PREFIX = "audio/"
 
 
34
 
35
- # Reference dataset for filename mapping
36
- REFERENCE_REPO_ID = os.getenv("REFERENCE_REPO_ID", "Fred808/BG3")
37
 
38
  WHISPER_SERVERS = [
39
  "https://fred1012-switch3.hf.space/transcribe",
@@ -58,98 +59,58 @@ WHISPER_SERVERS = [
58
  "https://Eliasishere-mint-20.hf.space/transcribe"
59
  ]
60
 
61
- MODEL_TYPE = "whisper-small"
62
- ZIP_UPLOAD_THRESHOLD = 100 # Upload and zip after this many transcriptions
63
-
64
  # Temporary storage for audio files
65
  TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
66
  TEMP_DIR.mkdir(exist_ok=True)
67
 
68
- # Temporary storage for transcription results
69
- RESULTS_DIR = Path(f"transcription_results_{FLOW_ID}")
70
- RESULTS_DIR.mkdir(exist_ok=True)
71
-
72
  # --- Models ---
 
 
 
73
  class WhisperServer:
74
- def __init__(self, url):
75
  self.url = url
76
- self.busy = False
 
77
  self.total_processed = 0
78
- self.total_time = 0
79
- self.model = MODEL_TYPE
80
 
81
  @property
82
  def fps(self):
 
83
  return self.total_processed / self.total_time if self.total_time > 0 else 0
84
-
85
- class RateLimiter:
86
- """Tracks uploads per hour with max limit of 120, stops at 128."""
87
- def __init__(self, max_per_hour: int = 120, stop_at: int = 128):
88
- self.max_per_hour = max_per_hour
89
- self.stop_at = stop_at
90
- self.uploads = [] # List of timestamps
91
- self.lock = asyncio.Lock()
92
-
93
- async def wait_if_needed(self) -> bool:
94
- """
95
- Returns True if upload can proceed, False if rate limit reached.
96
- Waits if needed to stay within limits.
97
- """
98
- async with self.lock:
99
- now = datetime.now()
100
- one_hour_ago = now - timedelta(hours=1)
101
-
102
- # Remove old uploads outside the 1-hour window
103
- self.uploads = [ts for ts in self.uploads if ts > one_hour_ago]
104
-
105
- # If we've reached the hard stop limit (128), return False
106
- if len(self.uploads) >= self.stop_at:
107
- print(f"[{FLOW_ID}] ⏸️ Upload limit ({self.stop_at}) reached. Waiting for next hour...")
108
- return False
109
-
110
- # If we're at the soft limit (120), add timestamp and continue
111
- if len(self.uploads) < self.max_per_hour:
112
- self.uploads.append(now)
113
- remaining = self.max_per_hour - len(self.uploads)
114
- print(f"[{FLOW_ID}] 📤 Upload #{len(self.uploads)}/120 this hour ({remaining} remaining)")
115
- return True
116
-
117
- # Between soft limit and hard stop, add and continue
118
- self.uploads.append(now)
119
- print(f"[{FLOW_ID}] ⚠️ Upload #{len(self.uploads)}/120 this hour (approaching limit)")
120
- return True
121
-
122
- async def can_upload(self) -> bool:
123
- """Check if upload is allowed without waiting."""
124
- async with self.lock:
125
- now = datetime.now()
126
- one_hour_ago = now - timedelta(hours=1)
127
- self.uploads = [ts for ts in self.uploads if ts > one_hour_ago]
128
- return len(self.uploads) < self.stop_at
129
-
130
- # Global rate limiter
131
- rate_limiter = RateLimiter(max_per_hour=120, stop_at=128)
132
 
133
  # Global state for whisper servers
134
  servers = [WhisperServer(url) for url in WHISPER_SERVERS]
135
- server_index = 0
136
 
 
137
 
138
  def load_progress() -> Dict:
 
139
  if PROGRESS_FILE.exists():
140
  try:
141
  with PROGRESS_FILE.open('r') as f:
142
  return json.load(f)
143
  except json.JSONDecodeError:
144
  print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
 
145
 
146
-
147
  return {
148
  "last_processed_index": 0,
149
- "processed_files": {},
150
- "file_list": [],
151
- "transcription_count": 0,
152
- "reference_map": {},
153
  }
154
 
155
  def save_progress(progress_data: Dict):
@@ -175,13 +136,10 @@ def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str,
175
  if "next_download_index" not in data:
176
  data["next_download_index"] = 0
177
 
178
- if "transcription_count" not in data:
179
- data["transcription_count"] = 0
180
-
181
  return data
182
  except json.JSONDecodeError:
183
  print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
184
- return default_value
185
 
186
  def save_json_state(file_path: str, data: Dict[str, Any]):
187
  """Save state to JSON file"""
@@ -191,10 +149,10 @@ def save_json_state(file_path: str, data: Dict[str, Any]):
191
  async def download_hf_state() -> Dict[str, Any]:
192
  """Downloads the state file from Hugging Face or returns a default state."""
193
  local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
194
- default_state = {"next_download_index": 0, "file_states": {}, "transcription_count": 0}
195
 
196
  try:
197
- # Check if the file exists in the output repo
198
  files = HfApi(token=HF_TOKEN).list_repo_files(
199
  repo_id=HF_OUTPUT_DATASET_ID,
200
  repo_type="dataset"
@@ -229,13 +187,13 @@ async def upload_hf_state(state: Dict[str, Any]) -> bool:
229
  # Save state locally first
230
  save_json_state(str(local_path), state)
231
 
232
- # Upload to output dataset
233
  HfApi(token=HF_TOKEN).upload_file(
234
  path_or_fileobj=str(local_path),
235
  path_in_repo=HF_STATE_FILE,
236
  repo_id=HF_OUTPUT_DATASET_ID,
237
  repo_type="dataset",
238
- commit_message=f"Update transcription processing state: next_index={state['next_download_index']}, count={state.get('transcription_count', 0)}"
239
  )
240
  print(f"[{FLOW_ID}] Successfully uploaded state file.")
241
  return True
@@ -243,84 +201,52 @@ async def upload_hf_state(state: Dict[str, Any]) -> bool:
243
  print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
244
  return False
245
 
246
- async def lock_file_for_processing(audio_filename: str, state: Dict[str, Any]) -> bool:
247
  """Marks a file as 'processing' in the state file and uploads the lock."""
248
- print(f"[{FLOW_ID}] 🔒 Attempting to lock file: {audio_filename}")
249
 
250
  # Update state locally
251
- state["file_states"][audio_filename] = "processing"
252
 
253
  # Upload the updated state file immediately to establish the lock
254
  if await upload_hf_state(state):
255
- print(f"[{FLOW_ID}] ✅ Successfully locked file: {audio_filename}")
256
  return True
257
  else:
258
- print(f"[{FLOW_ID}] ❌ Failed to lock file: {audio_filename}")
259
  # Revert local state
260
- if audio_filename in state["file_states"]:
261
- del state["file_states"][audio_filename]
262
  return False
263
 
264
- async def unlock_file_as_processed(audio_filename: str, state: Dict[str, Any], next_index: int) -> bool:
265
  """Marks a file as 'processed', updates the index, and uploads the state."""
266
- print(f"[{FLOW_ID}] 🔓 Marking file as processed: {audio_filename}")
267
 
268
  # Update state locally
269
- state["file_states"][audio_filename] = "processed"
270
  state["next_download_index"] = next_index
271
 
272
  # Upload the updated state
273
  if await upload_hf_state(state):
274
- print(f"[{FLOW_ID}] ✅ Successfully marked as processed: {audio_filename}")
275
  return True
276
  else:
277
- print(f"[{FLOW_ID}] ❌ Failed to update state for: {audio_filename}")
278
  return False
279
 
280
  # --- Hugging Face Utility Functions ---
281
 
282
- async def get_reference_map(reference_repo_id: str) -> Dict[str, str]:
283
- """
284
- Fetches the reference file list from the Hugging Face repo and creates a map
285
- from audio filename (without extension) to reference filename.
286
- """
287
- print(f"[{FLOW_ID}] Fetching reference file list from {reference_repo_id}...")
288
-
289
- try:
290
- api = HfApi(token=HF_TOKEN)
291
- repo_files = api.list_repo_files(repo_id=reference_repo_id, repo_type="dataset")
292
-
293
- reference_map = {}
294
- for file in repo_files:
295
- base_name, ext = os.path.splitext(file)
296
- if ext.lower() in ['.txt', '.json']: # Consider text/json files as reference
297
- reference_map[base_name] = file
298
-
299
- print(f"[{FLOW_ID}] ✅ Successfully created reference map with {len(reference_map)} entries.")
300
- return reference_map
301
-
302
- except Exception as e:
303
- print(f"[{FLOW_ID}] ⚠️ Failed to fetch reference map from Hugging Face: {e}")
304
- return {}
305
-
306
- def find_matching_filename(audio_filename: str, reference_map: Dict[str, str]) -> Optional[str]:
307
- """
308
- Finds the matching reference filename for a given audio filename.
309
- Returns the reference filename if found, otherwise None.
310
- """
311
- base_name, _ = os.path.splitext(audio_filename)
312
- return reference_map.get(base_name)
313
-
314
  async def get_audio_file_list(progress_data: Dict) -> List[str]:
315
  """
316
- Fetches the list of all audio files from the dataset, or uses the cached list.
317
  Updates the progress_data with the file list if a new list is fetched.
318
  """
319
  if progress_data['file_list']:
320
  print(f"[{FLOW_ID}] Using cached file list with {len(progress_data['file_list'])} files.")
321
  return progress_data['file_list']
322
 
323
- print(f"[{FLOW_ID}] Fetching full list of audio files from {HF_AUDIO_DATASET_ID}...")
324
  try:
325
  api = HfApi(token=HF_TOKEN)
326
  repo_files = api.list_repo_files(
@@ -328,497 +254,384 @@ async def get_audio_file_list(progress_data: Dict) -> List[str]:
328
  repo_type="dataset"
329
  )
330
 
331
- # Filter for audio files in the specified directory and sort them alphabetically for consistent indexing
332
- audio_extensions = ['.mp3', '.wav', '.m4a', '.flac', '.ogg', '.aac']
333
- audio_files = sorted([
334
  f for f in repo_files
335
- if f.startswith(AUDIO_FILE_PREFIX) and any(f.lower().endswith(ext) for ext in audio_extensions)
336
  ])
337
 
338
- if not audio_files:
339
- raise FileNotFoundError(f"No audio files found in '{AUDIO_FILE_PREFIX}' directory of dataset '{HF_AUDIO_DATASET_ID}'.")
340
 
341
- print(f"[{FLOW_ID}] Found {len(audio_files)} audio files.")
342
 
343
  # Update and save the progress data
344
- progress_data['file_list'] = audio_files
345
  save_progress(progress_data)
346
 
347
- return audio_files
348
 
349
  except Exception as e:
350
  print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {e}")
351
  return []
352
 
353
- async def download_audio_file(file_index: int, repo_file_full_path: str) -> Optional[Path]:
354
- """Downloads the audio file for the given index."""
355
 
356
- audio_filename = Path(repo_file_full_path).name
357
 
358
- print(f"[{FLOW_ID}] Processing audio file #{file_index}: {repo_file_full_path}")
359
 
360
  try:
361
  # Use hf_hub_download to get the file path
362
- audio_path = hf_hub_download(
363
  repo_id=HF_AUDIO_DATASET_ID,
364
  filename=repo_file_full_path,
365
  repo_type="dataset",
366
  token=HF_TOKEN,
367
  )
368
 
369
- print(f"[{FLOW_ID}] Downloaded audio to {audio_path}.")
370
-
371
- # Copy to temp directory
372
- temp_path = TEMP_DIR / audio_filename
373
- shutil.copy2(audio_path, temp_path)
374
-
375
- return temp_path
376
 
377
  except Exception as e:
378
- print(f"[{FLOW_ID}] Error downloading audio file {repo_file_full_path}: {e}")
379
  return None
380
 
381
- async def upload_json_to_dataset(json_file_path: Path, json_filename: str) -> bool:
382
- """Uploads a single JSON transcription file directly to HF dataset."""
383
- try:
384
- # Check rate limit before uploading
385
- if not await rate_limiter.wait_if_needed():
386
- print(f"[{FLOW_ID}] ⏸️ Upload rate limit reached for {json_filename}. Waiting...")
387
- return False
388
-
389
- print(f"[{FLOW_ID}] 📤 Uploading JSON file: {json_filename}...")
390
-
391
- api = HfApi(token=HF_TOKEN)
392
- api.upload_file(
393
- path_or_fileobj=str(json_file_path),
394
- path_in_repo=f"transcriptions/{json_filename}",
395
- repo_id=HF_OUTPUT_DATASET_ID,
396
- repo_type="dataset",
397
- commit_message=f"[{FLOW_ID}] Transcription: {json_filename}"
398
- )
399
-
400
- print(f"[{FLOW_ID}] ✅ Successfully uploaded: {json_filename}")
401
- return True
402
-
403
- except Exception as e:
404
- print(f"[{FLOW_ID}] ❌ Error uploading {json_filename}: {e}")
405
- return False
406
-
407
- async def zip_and_upload_transcriptions(transcription_files: List[Path], batch_number: int) -> bool:
408
- """Zips transcription JSON files and uploads to dataset with batch numbering."""
409
- if not transcription_files:
410
- print(f"[{FLOW_ID}] No transcription files to zip.")
411
- return False
412
 
413
  try:
414
- zip_filename = f"audio_json_batch_{batch_number}.zip"
415
- zip_path = RESULTS_DIR / zip_filename
416
-
417
- print(f"[{FLOW_ID}] 📦 Creating zip file: {zip_filename} with {len(transcription_files)} files...")
418
-
419
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
420
- for file_path in transcription_files:
421
- if file_path.exists():
422
- zipf.write(file_path, arcname=file_path.name)
423
 
424
- print(f"[{FLOW_ID}] 📤 Uploading zip file to {HF_OUTPUT_DATASET_ID}...")
 
425
 
426
  api = HfApi(token=HF_TOKEN)
427
  api.upload_file(
428
- path_or_fileobj=str(zip_path),
429
- path_in_repo=zip_filename,
430
  repo_id=HF_OUTPUT_DATASET_ID,
431
  repo_type="dataset",
432
- commit_message=f"[{FLOW_ID}] Batch {batch_number}: {len(transcription_files)} transcriptions"
433
  )
434
 
435
- print(f"[{FLOW_ID}] Successfully uploaded: {zip_filename}")
436
-
437
- # Cleanup
438
- os.remove(zip_path)
439
-
440
  return True
441
 
442
  except Exception as e:
443
- print(f"[{FLOW_ID}] Error zipping and uploading transcriptions: {e}")
444
  return False
445
 
446
  # --- Core Processing Functions ---
447
 
448
- async def get_available_server(timeout: float = 300.0) -> WhisperServer:
449
- """Round-robin selection of an available whisper server."""
450
- global server_index
451
- start_time = time.time()
452
- while True:
453
- # Round-robin check for an available server
454
- for _ in range(len(servers)):
455
- server = servers[server_index]
456
- server_index = (server_index + 1) % len(servers)
457
- if not server.busy:
458
- return server
459
-
460
- # If all servers are busy, wait for a short period and check again
461
- await asyncio.sleep(0.5)
462
-
463
- # Check if timeout has been reached
464
- if time.time() - start_time > timeout:
465
- raise TimeoutError(f"Timeout ({timeout}s) waiting for an available whisper server.")
466
-
467
- async def send_audio_for_transcription(audio_path: Path, progress_tracker: Dict) -> Optional[Dict]:
468
- """Sends a single audio file to a whisper server for transcription."""
469
- MAX_RETRIES = 3
470
- for attempt in range(MAX_RETRIES):
471
- server = None
472
- try:
473
- # 1. Get an available server
474
- server = await get_available_server()
475
- server.busy = True
476
- start_time = time.time()
477
-
478
- if attempt == 0:
479
- print(f"[{FLOW_ID}] Starting transcription attempt on {audio_path.name}...")
480
-
481
- # 2. Prepare request data - keep file open until request is done
482
- with audio_path.open('rb') as f:
483
- file_content = f.read()
484
-
485
- form_data = aiohttp.FormData()
486
- form_data.add_field('file',
487
- io.BytesIO(file_content),
488
- filename=audio_path.name,
489
- content_type='audio/mpeg')
490
-
491
- # 3. Send request
492
- async with aiohttp.ClientSession() as session:
493
- print(f"[{FLOW_ID}] Sending audio file to {server.url}...")
494
- async with session.post(server.url, data=form_data, timeout=aiohttp.ClientTimeout(total=600)) as resp:
495
- print(f"[{FLOW_ID}] Received response status: {resp.status}")
496
 
497
- if resp.status == 200:
498
- result = await resp.json()
499
- print(f"[{FLOW_ID}] Response data: {result}")
500
-
501
- # Check if response contains transcription data
502
- if result.get('text') or result.get('transcription'):
503
- # Update progress counter
504
- progress_tracker['completed'] += 1
505
- if progress_tracker['completed'] % 10 == 0:
506
- print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} transcriptions completed.")
507
-
508
- print(f"[{FLOW_ID}] ✅ Success: {audio_path.name} transcribed by {server.url}")
509
-
510
- # Store the full transcription result
511
- return {
512
- "audio_file": audio_path.name,
513
- "text": result.get('text', result.get('transcription', '')),
514
- "language": result.get('language', 'unknown'),
515
- "confidence": result.get('confidence'),
516
- "duration": result.get('duration'),
517
- }
518
- else:
519
- print(f"[{FLOW_ID}] ⚠️ Server {server.url} returned invalid response format for {audio_path.name}. Response: {result}")
520
- continue
521
- else:
522
- error_text = await resp.text()
523
- print(f"[{FLOW_ID}] ❌ Error from server {server.url} for {audio_path.name}: {resp.status} - {error_text}. Retrying...")
524
- continue
525
-
526
- except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
527
- print(f"[{FLOW_ID}] ❌ Connection/Timeout error for {audio_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
528
- continue
529
- except Exception as e:
530
- print(f"[{FLOW_ID}] ❌ Unexpected error during transcription for {audio_path.name}: {e}. Retrying...")
531
- import traceback
532
- traceback.print_exc()
533
- continue
534
- finally:
535
- if server:
536
- end_time = time.time()
537
- server.busy = False
538
- server.total_processed += 1
539
- server.total_time += (end_time - start_time)
540
-
541
- print(f"[{FLOW_ID}] ❌ FAILED after {MAX_RETRIES} attempts for {audio_path.name}.")
542
- return None
543
-
544
- # --- FastAPI App and Endpoints ---
545
-
546
- app = FastAPI(
547
- title=f"Flow Server {FLOW_ID} API",
548
- description="Processes audio files from a dataset, sends to whisper servers for transcription, and tracks progress.",
549
- version="1.0.0"
550
- )
551
-
552
- @app.on_event("startup")
553
- async def startup_event():
554
- print(f"[{FLOW_ID}] Flow Server started on port {FLOW_PORT}.")
555
- print(f"[{FLOW_ID}] 🚀 Auto-starting background processing...")
556
-
557
- # Create a background task to run the processing loop
558
- thread = threading.Thread(target=lambda: asyncio.run(process_audio_files_background()), daemon=True)
559
- thread.start()
560
- print(f"[{FLOW_ID}] ✅ Background processing thread started")
561
 
562
- @app.post("/process")
563
- async def process_audio_files(background_tasks: BackgroundTasks):
564
  """
565
- Manually trigger processing endpoint (in addition to auto-start on startup).
566
- Orchestrates transcription of audio files with reference file mapping.
567
  """
568
- print(f"[{FLOW_ID}] /process endpoint called, starting additional background task...")
569
- background_tasks.add_task(process_audio_files_background)
570
- return {
571
- "status": "processing_started",
572
- "flow_id": FLOW_ID,
573
- "message": "Background processing task started. Check /status for progress."
574
- }
575
 
576
- async def process_audio_files_background():
 
 
 
 
 
 
 
 
 
 
577
  """
578
- Background task that processes audio files with reference mapping.
579
- - Downloads batch of files (1 per server)
580
- - Distributes to Whisper servers in parallel
581
- - Uploads JSON results directly to HF dataset
582
- - Updates processing state after each batch round (dynamically based on actual processed count)
583
- - Respects rate limit: max 120 uploads/hour, stops at 128
584
  """
585
- progress_data = load_progress()
586
- reference_map = progress_data.get('reference_map', {})
587
-
588
- # Fetch reference map if empty
589
- if not reference_map:
590
- print(f"[{FLOW_ID}] Reference map is empty. Fetching from {REFERENCE_REPO_ID}...")
591
- reference_map = await get_reference_map(REFERENCE_REPO_ID)
592
- progress_data['reference_map'] = reference_map
593
- save_progress(progress_data)
594
-
595
- audio_files = await get_audio_file_list(progress_data)
596
- if not audio_files:
597
- print(f"[{FLOW_ID}] No audio files found. Exiting.")
598
- return
599
-
600
- # Dynamic batch size: one file per server
601
- BATCH_SIZE = len(servers)
602
- print(f"[{FLOW_ID}] 📊 Configuration: {len(servers)} Whisper server(s) → Batch size: {BATCH_SIZE} (1 file per server)")
603
 
604
- start_index = progress_data['last_processed_index']
 
605
 
606
- print(f"[{FLOW_ID}] Starting batch processing from file #{start_index} (out of {len(audio_files)})...")
607
 
608
- # Process in batches
609
- for batch_start in range(start_index, len(audio_files), BATCH_SIZE):
610
- batch_end = min(batch_start + BATCH_SIZE, len(audio_files))
611
- batch_files = audio_files[batch_start:batch_end]
612
-
613
- print(f"\n[{FLOW_ID}] 📦 BATCH ROUND: Processing files #{batch_start}-#{batch_end-1} ({len(batch_files)} files)")
614
-
615
- # Step 1: Download all files in batch in parallel
616
- print(f"[{FLOW_ID}] ⬇️ Downloading batch ({len(batch_files)} files)...")
617
- download_tasks = []
618
- for idx, repo_file_path in enumerate(batch_files):
619
- file_index = batch_start + idx
620
- download_tasks.append(download_audio_file(file_index, repo_file_path))
621
-
622
- downloaded_paths = await asyncio.gather(*download_tasks, return_exceptions=True)
623
-
624
- # Step 2: Send all downloaded files to Whisper servers in parallel
625
- print(f"[{FLOW_ID}] 🎤 Distributing to {len(servers)} Whisper server(s) ({len(batch_files)} files)...")
626
-
627
- transcription_tasks = []
628
- file_metadata = [] # Track file info for results
629
-
630
- for idx, (repo_file_path, audio_path) in enumerate(zip(batch_files, downloaded_paths)):
631
- file_index = batch_start + idx
632
- audio_filename = Path(repo_file_path).name
633
-
634
- # Skip if download failed
635
- if isinstance(audio_path, Exception):
636
- print(f"[{FLOW_ID}] ⏭️ Skipping {audio_filename} (download failed)")
637
- continue
638
-
639
- if not audio_path or not audio_path.exists():
640
- continue
641
-
642
- reference_filename = find_matching_filename(audio_filename, reference_map)
643
- file_metadata.append({
644
- 'audio_filename': audio_filename,
645
- 'audio_path': audio_path,
646
- 'reference_filename': reference_filename,
647
- 'file_index': file_index
648
- })
649
-
650
- # Create transcription task (will be awaited in parallel)
651
- transcription_tasks.append(send_audio_for_transcription_task(audio_path, audio_filename))
652
-
653
- if transcription_tasks:
654
- print(f"[{FLOW_ID}] ⏳ Waiting for {len(transcription_tasks)} transcriptions (parallel)...")
655
- transcription_results = await asyncio.gather(*transcription_tasks, return_exceptions=True)
656
-
657
- # Step 3: Upload transcriptions directly to HF dataset
658
- successful_uploads = 0
659
- uploaded_files = []
660
- state = await download_hf_state()
661
-
662
- print(f"[{FLOW_ID}] 📤 Uploading {len([r for r in transcription_results if r and not isinstance(r, Exception)])}/{len(transcription_results)} transcriptions directly to dataset...")
663
-
664
- for metadata, result in zip(file_metadata, transcription_results):
665
- if isinstance(result, Exception):
666
- print(f"[{FLOW_ID}] ❌ Transcription failed for {metadata['audio_filename']}: {result}")
667
  continue
668
 
669
- if result:
670
- # Save JSON locally first
671
- json_filename = Path(metadata['reference_filename']).stem if metadata['reference_filename'] else Path(metadata['audio_filename']).stem
672
- json_file_path = Path(RESULTS_DIR) / f"{json_filename}.json"
673
-
674
- # Write JSON to file
675
- with open(json_file_path, 'w', encoding='utf-8') as f:
676
- json.dump(result, f, indent=2, ensure_ascii=False)
677
-
678
- # Upload directly to HF dataset
679
- if await upload_json_to_dataset(json_file_path, f"{json_filename}.json"):
680
- successful_uploads += 1
681
- uploaded_files.append(json_file_path)
682
- progress_data['transcription_count'] += 1
683
-
684
- # Cleanup local JSON file after upload
685
- if json_file_path.exists():
686
- os.remove(json_file_path)
687
-
688
- # Step 4: Cleanup downloaded audio files
689
- for metadata in file_metadata:
690
- if metadata['audio_path'].exists():
691
- os.remove(metadata['audio_path'])
692
-
693
- # Step 5: Update processing state after this batch round
694
- # Update next_download_index based on actual files processed this round
695
- files_processed_this_round = len([m for m in file_metadata if m]) # Count of files actually processed
696
- new_download_index = batch_start + files_processed_this_round
697
-
698
- print(f"[{FLOW_ID}] 🔄 Batch round complete: {files_processed_this_round} files distributed and processed")
699
- print(f"[{FLOW_ID}] 📊 Updating state: next_download_index {state['next_download_index']} → {new_download_index}")
700
-
701
- state['next_download_index'] = new_download_index
702
 
703
- # Mark all files in this round as processed in the state
704
- for metadata in file_metadata:
705
- state['file_states'][metadata['audio_filename']] = "processed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
 
707
- # Upload updated state
708
  await upload_hf_state(state)
709
-
710
- # Save local progress
711
- progress_data['last_processed_index'] = batch_end
712
- save_progress(progress_data)
713
-
714
- print(f"[{FLOW_ID}] ✅ State updated. Successful uploads this round: {successful_uploads}/{len(file_metadata)}")
715
 
716
- print(f"\n[{FLOW_ID}] ALL DONE! Total transcriptions: {progress_data['transcription_count']}")
 
 
 
717
 
718
- async def send_audio_for_transcription_task(audio_path: Path, audio_filename: str) -> Optional[Dict]:
719
- """Wrapper for transcription that can be used in asyncio.gather."""
720
- MAX_RETRIES = 3
721
- for attempt in range(MAX_RETRIES):
722
- server = None
723
- try:
724
- server = await get_available_server()
725
- server.busy = True
726
- start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
 
728
- # Read file content once
729
- with audio_path.open('rb') as f:
730
- file_content = f.read()
 
731
 
732
- form_data = aiohttp.FormData()
733
- form_data.add_field('file',
734
- io.BytesIO(file_content),
735
- filename=audio_filename,
736
- content_type='audio/mpeg')
737
 
738
- async with aiohttp.ClientSession() as session:
739
- async with session.post(server.url, data=form_data, timeout=aiohttp.ClientTimeout(total=600)) as resp:
740
- if resp.status == 200:
741
- result = await resp.json()
742
-
743
- if result.get('text') or result.get('transcription'):
744
- print(f"[{FLOW_ID}] {audio_filename}")
745
-
746
- return {
747
- "audio_file": audio_filename,
748
- "text": result.get('text', result.get('transcription', '')),
749
- "language": result.get('language', 'unknown'),
750
- "confidence": result.get('confidence'),
751
- "duration": result.get('duration'),
752
- }
753
- else:
754
- print(f"[{FLOW_ID}] ⚠️ Invalid response for {audio_filename}")
755
- continue
756
- else:
757
- error_text = await resp.text()
758
- print(f"[{FLOW_ID}] Server error {resp.status}: {audio_filename}")
759
- continue
760
-
761
- except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
762
- print(f"[{FLOW_ID}] ⏱️ Timeout/Connection error for {audio_filename}")
763
- continue
764
- except Exception as e:
765
- print(f"[{FLOW_ID}] ❌ Error for {audio_filename}: {str(e)[:50]}")
766
- continue
767
- finally:
768
- if server:
769
- end_time = time.time()
770
- server.busy = False
771
- server.total_processed += 1
772
- server.total_time += (end_time - start_time)
773
-
774
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
775
 
776
  @app.get("/")
777
  async def root():
778
  progress = load_progress()
 
 
 
 
 
 
779
  return {
780
  "flow_id": FLOW_ID,
781
  "status": "ready",
782
- "last_processed_index": progress['last_processed_index'],
783
  "total_files_in_list": len(progress['file_list']),
784
- "processed_files_count": len(progress['processed_files']),
785
- "transcription_count": progress.get('transcription_count', 0),
786
  "total_servers": len(servers),
787
- "busy_servers": sum(1 for s in servers if s.busy),
 
 
 
788
  }
789
 
790
- @app.get("/status")
791
- async def get_status():
792
- """Returns detailed processing status with reference map info."""
793
- progress = load_progress()
794
- state = await download_hf_state()
 
795
 
796
- return {
797
- "flow_id": FLOW_ID,
798
- "status": "processing" if state['next_download_index'] < len(progress.get('file_list', [])) else "idle",
799
- "progress": {
800
- "current_index": state['next_download_index'],
801
- "total_files": len(progress.get('file_list', [])),
802
- "percentage": (state['next_download_index'] / len(progress.get('file_list', [])) * 100) if progress.get('file_list') else 0
803
- },
804
- "transcription_count": progress.get('transcription_count', 0),
805
- "reference_map_size": len(progress.get('reference_map', {})),
806
- "server_stats": {
807
- "total_servers": len(servers),
808
- "busy_servers": sum(1 for s in servers if s.busy),
809
- "details": [
810
- {
811
- "url": s.url,
812
- "busy": s.busy,
813
- "total_processed": s.total_processed,
814
- "avg_time_per_file": s.total_time / s.total_processed if s.total_processed > 0 else 0
815
- }
816
- for s in servers
817
- ]
818
- },
819
- "files_in_processing": list(state.get('file_states', {}).keys())
820
- }
821
 
822
  if __name__ == "__main__":
823
  import uvicorn
 
824
  uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
 
5
  import aiohttp
6
  import zipfile
7
  import shutil
 
8
  from typing import Dict, List, Set, Optional, Tuple, Any
9
  from urllib.parse import quote
10
+ from datetime import datetime
11
  from pathlib import Path
12
  import io
13
 
 
16
  from huggingface_hub import HfApi, hf_hub_download
17
 
18
  # --- Configuration ---
19
+ AUTO_START_INDEX = 1 # Hardcoded default start index if no progress is found
20
  FLOW_ID = os.getenv("FLOW_ID", "flow_default")
21
  FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
22
  HF_TOKEN = os.getenv("HF_TOKEN", "")
23
  HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_Vid")
24
+ HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO")
25
 
26
  # Progress and State Tracking
27
  PROGRESS_FILE = Path("processing_progress.json")
 
29
  LOCAL_STATE_FOLDER = Path(".state")
30
  LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
31
 
32
+ # Processing configuration
33
+ MAX_UPLOADS_BEFORE_PAUSE = 120 # Pause uploading after 120 files
34
+ UPLOAD_PAUSE_ENABLED = True
35
 
36
+ # Directory within the HF dataset where the audio files are located
37
+ AUDIO_FILE_PREFIX = "audio/"
38
 
39
  WHISPER_SERVERS = [
40
  "https://fred1012-switch3.hf.space/transcribe",
 
59
  "https://Eliasishere-mint-20.hf.space/transcribe"
60
  ]
61
 
 
 
 
62
  # Temporary storage for audio files
63
  TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
64
  TEMP_DIR.mkdir(exist_ok=True)
65
 
 
 
 
 
66
  # --- Models ---
67
+ class ProcessStartRequest(BaseModel):
68
+ start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the audio file to start processing from (1-indexed).")
69
+
70
class WhisperServer:
    """Bookkeeping record for a single Whisper transcription endpoint.

    Tracks whether the endpoint is busy, which file index it was given, and
    cumulative counters used to derive its throughput.
    """

    def __init__(self, url: str):
        self.url = url
        # Busy flag plus the file index currently assigned (None when idle).
        self.is_processing = False
        self.current_file_index: Optional[int] = None
        # Cumulative counters: number of files handled and seconds spent.
        self.total_processed = 0
        self.total_time = 0.0

    @property
    def fps(self):
        """Files per second, or 0 when no processing time has been recorded."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0

    def assign_file(self, file_index: int):
        """Mark this server busy with the given file index."""
        self.current_file_index = file_index
        self.is_processing = True

    def release(self):
        """Mark this server idle and forget its file index."""
        self.current_file_index = None
        self.is_processing = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
# Shared pool of Whisper endpoints: one WhisperServer per configured URL.
servers = list(map(WhisperServer, WHISPER_SERVERS))
# asyncio lock that serialises coroutine access to the servers' busy flags.
server_lock = asyncio.Lock()
96
 
97
+ # --- Progress and State Management Functions ---
98
 
99
def load_progress() -> Dict:
    """Load the local processing progress from ``PROGRESS_FILE``.

    Returns:
        A dict that always contains ``last_processed_index``,
        ``processed_files`` and ``file_list``. Values found in the file
        override the defaults; keys the file lacks keep their defaults, so
        callers can index these three keys without a ``KeyError``.

    A missing file, a file with invalid JSON, or a file whose top-level value
    is not an object all yield the default structure.
    """
    progress: Dict = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],  # Cached list of audio files found in the dataset
    }

    if not PROGRESS_FILE.exists():
        return progress

    try:
        with PROGRESS_FILE.open('r') as f:
            loaded = json.load(f)
    except json.JSONDecodeError:
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
        return progress

    if isinstance(loaded, dict):
        # Merge over the defaults so files written by older versions (which
        # may lack some keys) still produce a complete structure.
        progress.update(loaded)
    else:
        print(f"[{FLOW_ID}] WARNING: Progress file has an unexpected format. Starting fresh.")

    return progress
115
 
116
  def save_progress(progress_data: Dict):
 
136
  if "next_download_index" not in data:
137
  data["next_download_index"] = 0
138
 
 
 
 
139
  return data
140
  except json.JSONDecodeError:
141
  print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
142
+ return default_value
143
 
144
  def save_json_state(file_path: str, data: Dict[str, Any]):
145
  """Save state to JSON file"""
 
149
  async def download_hf_state() -> Dict[str, Any]:
150
  """Downloads the state file from Hugging Face or returns a default state."""
151
  local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
152
+ default_state = {"next_download_index": 0, "file_states": {}}
153
 
154
  try:
155
+ # Check if the file exists in the helium repo
156
  files = HfApi(token=HF_TOKEN).list_repo_files(
157
  repo_id=HF_OUTPUT_DATASET_ID,
158
  repo_type="dataset"
 
187
  # Save state locally first
188
  save_json_state(str(local_path), state)
189
 
190
+ # Upload to helium dataset
191
  HfApi(token=HF_TOKEN).upload_file(
192
  path_or_fileobj=str(local_path),
193
  path_in_repo=HF_STATE_FILE,
194
  repo_id=HF_OUTPUT_DATASET_ID,
195
  repo_type="dataset",
196
+ commit_message=f"Update caption processing state: next_index={state['next_download_index']}"
197
  )
198
  print(f"[{FLOW_ID}] Successfully uploaded state file.")
199
  return True
 
201
  print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
202
  return False
203
 
204
async def lock_file_for_processing(zip_filename: str, state: Dict[str, Any]) -> bool:
    """Mark a file as 'processing' in the state and upload the state as a lock.

    Args:
        zip_filename: Key under which the file is tracked in
            ``state["file_states"]`` (the parameter name is historical).
        state: Shared state dict; mutated in place.

    Returns:
        True when the updated state was uploaded. False when the upload
        failed, in which case the entry is restored to whatever it was before
        this call (or removed if it did not exist).
    """
    print(f"[{FLOW_ID}] 🔒 Attempting to lock file: {zip_filename}")

    file_states = state.setdefault("file_states", {})

    # Remember the prior status so a failed lock does not erase it.
    had_previous = zip_filename in file_states
    previous_status = file_states.get(zip_filename)

    # Update state locally
    file_states[zip_filename] = "processing"

    # Upload the updated state file immediately to establish the lock
    if await upload_hf_state(state):
        print(f"[{FLOW_ID}] ✅ Successfully locked file: {zip_filename}")
        return True

    print(f"[{FLOW_ID}] ❌ Failed to lock file: {zip_filename}")
    # Revert local state to exactly what it was before the attempt
    if had_previous:
        file_states[zip_filename] = previous_status
    else:
        file_states.pop(zip_filename, None)
    return False
221
 
222
async def unlock_file_as_processed(zip_filename: str, state: Dict[str, Any], next_index: int) -> bool:
    """Record a file as 'processed', advance the index, and upload the state.

    The local ``state`` dict is mutated before the upload and is left mutated
    even when the upload fails. Returns True on a successful upload, False
    otherwise.
    """
    print(f"[{FLOW_ID}] 🔓 Marking file as processed: {zip_filename}")

    # Apply both changes locally before pushing them.
    state["file_states"][zip_filename] = "processed"
    state["next_download_index"] = next_index

    uploaded = await upload_hf_state(state)
    if not uploaded:
        print(f"[{FLOW_ID}] ❌ Failed to update state for: {zip_filename}")
        return False

    print(f"[{FLOW_ID}] ✅ Successfully marked as processed: {zip_filename}")
    return True
237
 
238
  # --- Hugging Face Utility Functions ---
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  async def get_audio_file_list(progress_data: Dict) -> List[str]:
241
  """
242
+ Fetches the list of all WAV files from the dataset, or uses the cached list.
243
  Updates the progress_data with the file list if a new list is fetched.
244
  """
245
  if progress_data['file_list']:
246
  print(f"[{FLOW_ID}] Using cached file list with {len(progress_data['file_list'])} files.")
247
  return progress_data['file_list']
248
 
249
+ print(f"[{FLOW_ID}] Fetching full list of WAV files from {HF_AUDIO_DATASET_ID}...")
250
  try:
251
  api = HfApi(token=HF_TOKEN)
252
  repo_files = api.list_repo_files(
 
254
  repo_type="dataset"
255
  )
256
 
257
+ # Filter for WAV files and sort them alphabetically for consistent indexing
258
+ wav_files = sorted([
 
259
  f for f in repo_files
260
+ if f.endswith('.wav')
261
  ])
262
 
263
+ if not wav_files:
264
+ raise FileNotFoundError(f"No WAV files found in dataset '{HF_AUDIO_DATASET_ID}'.")
265
 
266
+ print(f"[{FLOW_ID}] Found {len(wav_files)} WAV files.")
267
 
268
  # Update and save the progress data
269
+ progress_data['file_list'] = wav_files
270
  save_progress(progress_data)
271
 
272
+ return wav_files
273
 
274
  except Exception as e:
275
  print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {e}")
276
  return []
277
 
278
async def download_wav_file_by_index(file_index: int, repo_file_full_path: str) -> Optional[Path]:
    """Download one WAV file from the audio dataset.

    Args:
        file_index: Index of the file, used only in the log message.
        repo_file_full_path: Path of the file inside ``HF_AUDIO_DATASET_ID``.

    Returns:
        The local path returned by ``hf_hub_download``, or None when the
        download raised.
    """
    print(f"[{FLOW_ID}] Downloading file #{file_index}: {repo_file_full_path}")

    def _download() -> str:
        # Use hf_hub_download to get the file path
        return hf_hub_download(
            repo_id=HF_AUDIO_DATASET_ID,
            filename=repo_file_full_path,
            repo_type="dataset",
            token=HF_TOKEN,
        )

    try:
        # hf_hub_download is a blocking call; run it in the default executor
        # so the event loop keeps serving in-flight transcription requests
        # and HTTP endpoints while the file downloads.
        loop = asyncio.get_running_loop()
        wav_path = await loop.run_in_executor(None, _download)

        print(f"[{FLOW_ID}] Downloaded WAV file to {wav_path}")
        return Path(wav_path)

    except Exception as e:
        print(f"[{FLOW_ID}] Error downloading WAV file {repo_file_full_path}: {e}")
        return None
300
 
301
async def upload_transcription_to_hf(wav_filename: str, transcription_data: Dict) -> bool:
    """Upload a transcription as a JSON file to the output dataset.

    Args:
        wav_filename: Name of the source WAV file; the uploaded file uses the
            same base name with a ``.json`` extension, at the repo root.
        transcription_data: JSON-serialisable transcription payload.

    Returns:
        True when the upload succeeded, False when serialisation or the
        upload raised.
    """
    # Use the WAV filename, replacing the extension with .json
    json_filename = Path(wav_filename).with_suffix('.json').name

    try:
        print(f"[{FLOW_ID}] Uploading transcription for {wav_filename} as {json_filename} to {HF_OUTPUT_DATASET_ID}...")

        # Create JSON content in memory
        json_content = json.dumps(transcription_data, indent=2, ensure_ascii=False).encode('utf-8')

        def _upload() -> None:
            HfApi(token=HF_TOKEN).upload_file(
                path_or_fileobj=io.BytesIO(json_content),
                path_in_repo=json_filename,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=f"[{FLOW_ID}] Transcription for {wav_filename}"
            )

        # upload_file is a blocking call; run it in the default executor so
        # the event loop is not stalled for the duration of the upload.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _upload)

        print(f"[{FLOW_ID}] Successfully uploaded transcription for {wav_filename}.")
        return True

    except Exception as e:
        print(f"[{FLOW_ID}] Error uploading transcription for {wav_filename}: {e}")
        return False
327
 
328
  # --- Core Processing Functions ---
329
 
330
async def send_audio_to_whisper(wav_path: Path, server: WhisperServer) -> Optional[Dict]:
    """Send a WAV file to a Whisper server and return the transcription.

    Args:
        wav_path: Local path of the WAV file to upload.
        server: Target server; its ``total_processed`` and ``total_time``
            counters are updated on success.

    Returns:
        A dict with ``file``, ``transcription``, ``timestamp`` and
        ``processing_time_seconds`` on HTTP 200, otherwise None (non-200
        response, timeout, or any other exception).
    """
    try:
        print(f"[{FLOW_ID}] Sending {wav_path.name} to {server.url}...")

        start_time = time.time()

        # 10 minute timeout for transcription
        request_timeout = aiohttp.ClientTimeout(total=600)

        # Keep the file open only for the duration of the request so the
        # handle is always closed, including on errors and timeouts.
        with wav_path.open('rb') as audio_file:
            # Prepare multipart form data
            form_data = aiohttp.FormData()
            form_data.add_field('file',
                                audio_file,
                                filename=wav_path.name,
                                content_type='audio/wav')

            async with aiohttp.ClientSession() as session:
                async with session.post(server.url, data=form_data, timeout=request_timeout) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        print(f"[{FLOW_ID}] ✗ Error from {server.url}: {resp.status} - {error_text}")
                        return None

                    result = await resp.json()

        end_time = time.time()
        elapsed = end_time - start_time

        # Update server stats
        server.total_processed += 1
        server.total_time += elapsed

        print(f"[{FLOW_ID}] {wav_path.name} transcribed successfully by {server.url}")

        return {
            "file": wav_path.name,
            "transcription": result,
            "timestamp": datetime.now().isoformat(),
            "processing_time_seconds": elapsed
        }

    except asyncio.TimeoutError:
        print(f"[{FLOW_ID}] ✗ Timeout from {server.url} for {wav_path.name}")
        return None
    except Exception as e:
        print(f"[{FLOW_ID}] ✗ Exception on {server.url} for {wav_path.name}: {e}")
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
async def get_available_servers() -> List[WhisperServer]:
    """Return the servers whose ``is_processing`` flag is currently False.

    The snapshot is taken while holding ``server_lock``; it does not reserve
    any of the returned servers.
    """
    idle_servers: List[WhisperServer] = []
    async with server_lock:
        for candidate in servers:
            if not candidate.is_processing:
                idle_servers.append(candidate)
    return idle_servers
 
 
 
 
383
 
384
async def assign_file_to_server(file_index: int, server: WhisperServer):
    """Mark ``server`` busy with ``file_index`` while holding ``server_lock``."""
    await server_lock.acquire()
    try:
        server.assign_file(file_index)
    finally:
        server_lock.release()
388
+
389
async def release_server(server: WhisperServer):
    """Mark ``server`` idle again while holding ``server_lock``."""
    await server_lock.acquire()
    try:
        server.release()
    finally:
        server_lock.release()
393
+
394
async def process_batch_dynamic(wav_files: List[str], start_batch_index: int, batch_size: int, state: Dict[str, Any], progress: Dict) -> Tuple[int, int]:
    """Process one batch of WAV files, handing each file to an idle server.

    Files ``wav_files[start_batch_index : start_batch_index + batch_size]``
    are downloaded, sent to Whisper servers as those become idle, and the
    resulting transcriptions are uploaded (until the upload limit is hit).

    Args:
        wav_files: Full list of repo paths of the audio files.
        start_batch_index: 0-based index of the first file of the batch.
        batch_size: Maximum number of files handled by this call.
        state: Shared HF state dict; ``state["file_states"]`` is updated in
            place with "processing" / "processed" / "failed" per file name.
        progress: Local progress dict; ``uploaded_count`` is read and updated.

    Returns:
        ``(next_batch_index, uploaded_count)`` where ``next_batch_index`` is
        the 0-based index of the first file this call did not consume.
    """
    batch_end = min(start_batch_index + batch_size, len(wav_files))
    current_index = start_batch_index
    uploaded_count = progress.get('uploaded_count', 0)
    file_states = state.setdefault("file_states", {})

    # In-flight transcription tasks: task -> (file_index, wav_path, server)
    pending_tasks: Dict[asyncio.Task, Tuple[int, Path, WhisperServer]] = {}

    print(f"[{FLOW_ID}] Processing batch from index {start_batch_index} to {batch_end}")

    try:
        while current_index < batch_end or pending_tasks:
            # Assign new files to available servers
            while current_index < batch_end:
                available_servers = await get_available_servers()

                if not available_servers:
                    # All servers busy. Leave the assignment loop so finished
                    # tasks get collected below; that is the only place
                    # servers are released, so waiting here would deadlock.
                    break

                server = available_servers[0]
                file_index = current_index
                wav_file = wav_files[file_index]
                wav_filename = Path(wav_file).name
                current_index += 1

                # Reserve the server before the download so nothing else can
                # pick the same server while the download is awaited.
                await assign_file_to_server(file_index, server)

                # Mark file as processing in state
                file_states[wav_filename] = "processing"

                # Download the WAV file
                wav_path = await download_wav_file_by_index(file_index + 1, wav_file)
                if not wav_path:
                    file_states[wav_filename] = "failed"
                    await release_server(server)
                    continue

                task = asyncio.create_task(send_audio_to_whisper(wav_path, server))
                pending_tasks[task] = (file_index, wav_path, server)

            if not pending_tasks:
                if current_index < batch_end:
                    # Every server is busy with work this call does not own;
                    # wait a bit before looking again.
                    await asyncio.sleep(0.5)
                continue

            # Wait for at least one task to complete
            done, _ = await asyncio.wait(
                set(pending_tasks),
                return_when=asyncio.FIRST_COMPLETED
            )

            # Process completed tasks
            for task in done:
                file_index, wav_path, server = pending_tasks.pop(task)
                wav_filename = wav_path.name

                try:
                    transcription_result = task.result()

                    if not transcription_result:
                        file_states[wav_filename] = "failed"
                    elif UPLOAD_PAUSE_ENABLED and uploaded_count >= MAX_UPLOADS_BEFORE_PAUSE:
                        print(f"[{FLOW_ID}] ⏸️ Upload limit reached ({uploaded_count}/{MAX_UPLOADS_BEFORE_PAUSE}). Pausing uploads but continuing processing...")
                        # Mark as processed but don't upload
                        # NOTE(review): the transcription is discarded here
                        # while the file is recorded as "processed" — confirm
                        # this is the intended behaviour of the pause.
                        file_states[wav_filename] = "processed"
                    elif await upload_transcription_to_hf(wav_filename, transcription_result):
                        file_states[wav_filename] = "processed"
                        uploaded_count += 1
                        progress['uploaded_count'] = uploaded_count
                        save_progress(progress)
                    else:
                        file_states[wav_filename] = "failed"

                except Exception as e:
                    print(f"[{FLOW_ID}] Error processing result for {wav_filename}: {e}")
                    file_states[wav_filename] = "failed"
                finally:
                    # Release the server
                    await release_server(server)
                    # Clean up the WAV file
                    if wav_path.exists():
                        wav_path.unlink()

            # Push the updated file states after each round of completions
            await upload_hf_state(state)

    except Exception as e:
        print(f"[{FLOW_ID}] Error in process_batch_dynamic: {e}")
    finally:
        # Only non-empty when the loop was aborted: cancel what is still in
        # flight and free its servers so they do not stay busy forever.
        for task, (_, wav_path, server) in list(pending_tasks.items()):
            task.cancel()
            file_states[wav_path.name] = "failed"
            await release_server(server)
            try:
                if wav_path.exists():
                    wav_path.unlink()
            except OSError as cleanup_error:
                print(f"[{FLOW_ID}] Could not remove {wav_path}: {cleanup_error}")
        pending_tasks.clear()

    return current_index, uploaded_count
494
 
495
async def process_dataset_task(start_index: int):
    """Transcribe the dataset from ``start_index`` using the server pool.

    Args:
        start_index: 1-based index of the first file to process; values below
            1 are treated as 1.

    Returns:
        True when the loop reached the end of the file list (or
        ``start_index`` is already past it); False when the file list is
        empty, a batch made no progress, or an unexpected error occurred.
    """
    # Load both local progress and HF state
    progress = load_progress()
    current_state = await download_hf_state()
    file_list = await get_audio_file_list(progress)

    if not file_list:
        print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
        return False

    # Ensure start_index is within bounds
    if start_index > len(file_list):
        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
        return True
    if start_index < 1:
        # A value below 1 would become a negative list index further down.
        start_index = 1

    # Determine the actual starting index in the 0-indexed list
    start_list_index = start_index - 1

    print(f"[{FLOW_ID}] Starting audio transcription from file index: {start_index} out of {len(file_list)}.")
    print(f"[{FLOW_ID}] Using {len(servers)} Whisper servers for dynamic processing.")
    print(f"[{FLOW_ID}] Upload pause enabled: {UPLOAD_PAUSE_ENABLED}, Max uploads before pause: {MAX_UPLOADS_BEFORE_PAUSE}")

    # Initialize progress tracking
    progress.setdefault('uploaded_count', 0)

    current_batch_index = start_list_index
    batch_size = len(servers) * 2  # Two files per server in each batch

    try:
        while current_batch_index < len(file_list):
            # Process a batch dynamically
            next_index, uploaded_count = await process_batch_dynamic(
                file_list,
                current_batch_index,
                batch_size,
                current_state,
                progress
            )

            if next_index <= current_batch_index:
                # The batch consumed nothing (e.g. no servers configured, or
                # it failed before the first file); retrying the same batch
                # would loop forever.
                print(f"[{FLOW_ID}] ERROR: No progress made at index {current_batch_index}. Stopping.")
                return False

            # Update progress
            progress['last_processed_index'] = next_index
            progress['uploaded_count'] = uploaded_count
            save_progress(progress)

            # Update current batch index
            current_batch_index = next_index

            # Log statistics
            print(f"[{FLOW_ID}] Batch complete. Progress: {current_batch_index}/{len(file_list)}, Uploaded: {uploaded_count}")

            # Print server statistics
            print(f"[{FLOW_ID}] Server Statistics:")
            for i, server in enumerate(servers):
                print(f"  Server {i+1}: {server.total_processed} files, {server.total_time:.2f}s total, {server.fps:.2f} files/sec")

        print(f"[{FLOW_ID}] All files processed successfully!")
        return True

    except Exception as e:
        print(f"[{FLOW_ID}] Critical error in process_dataset_task: {e}")
        return False
561
+
562
+ # --- FastAPI App and Endpoints ---
563
+
564
# FastAPI application object; the description is shown in the generated API docs.
app = FastAPI(
    title=f"Flow Server {FLOW_ID} API",
    description="Transcribes WAV files from a dataset using a pool of Whisper servers, uploads the transcriptions, and tracks progress.",
    version="1.0.0"
)
569
+
570
@app.on_event("startup")
async def startup_event():
    """Pick the start index and launch dataset processing in the background.

    The start index is ``next_download_index`` from the HF state when that is
    positive; otherwise it is the local ``last_processed_index + 1``, raised
    to ``AUTO_START_INDEX`` if lower.
    """
    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.")

    # Get both local progress and HF state
    progress = load_progress()
    current_state = await download_hf_state()

    # Get the next_download_index from HF state if available
    hf_next_index = current_state.get("next_download_index", 0)

    # A positive index in the HF state takes precedence over local progress
    if hf_next_index > 0:
        start_index = hf_next_index
        print(f"[{FLOW_ID}] Using next_download_index from HF state: {start_index}")
    else:
        # Fall back to local progress if HF state doesn't have a meaningful index
        start_index = max(progress.get('last_processed_index', 0) + 1, AUTO_START_INDEX)

    # Run the long-running job as an asyncio task so server startup is not
    # blocked. The event loop only keeps weak references to tasks, so the
    # task is stored on app.state to keep it alive until it finishes.
    print(f"[{FLOW_ID}] Auto-starting processing from index: {start_index}...")
    app.state.processing_task = asyncio.create_task(process_dataset_task(start_index))
596
 
597
@app.get("/")
async def root():
    """Return a status summary built from local progress and server counters."""
    progress = load_progress()
    uploaded_count = progress.get('uploaded_count', 0)

    # Calculate server stats
    total_processed = sum(s.total_processed for s in servers)
    total_time = sum(s.total_time for s in servers)
    avg_fps = total_processed / total_time if total_time > 0 else 0

    return {
        "flow_id": FLOW_ID,
        "status": "ready",
        "last_processed_index": progress.get('last_processed_index', 0),
        # .get: a progress file without this key must not break the endpoint
        "total_files_in_list": len(progress.get('file_list', [])),
        "uploaded_count": uploaded_count,
        "total_servers": len(servers),
        "processing_servers": sum(1 for s in servers if s.is_processing),
        "total_files_processed_by_servers": total_processed,
        "avg_files_per_second": avg_fps,
        # Uploads only pause when the feature is enabled, matching the
        # condition used by the processing loop.
        "upload_limit_paused": UPLOAD_PAUSE_ENABLED and uploaded_count >= MAX_UPLOADS_BEFORE_PAUSE
    }
618
 
619
@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest, background_tasks: BackgroundTasks):
    """
    Schedule dataset processing from the requested index and return at once.

    The work is registered as a background task, so the response does not
    wait for it.
    NOTE(review): startup already launches a processing task; this endpoint
    does not check whether one is still running — confirm overlap is intended.
    """
    requested_index = request.start_index

    print(f"[{FLOW_ID}] Received request to start processing from index: {requested_index}. Starting background task.")

    # Defer the heavy work so the API call returns immediately.
    background_tasks.add_task(process_dataset_task, requested_index)

    response = {
        "status": "processing",
        "start_index": requested_index,
        "message": "Dataset processing started in background.",
    }
    return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
 
634
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app.
    import uvicorn

    # Bind to 0.0.0.0 so the port is reachable from outside the sandbox.
    server_options = {"host": "0.0.0.0", "port": FLOW_PORT}
    uvicorn.run(app, **server_options)