Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -394,202 +394,186 @@ async def release_server(server: WhisperServer):
|
|
| 394 |
|
| 395 |
async def process_batch_dynamic(wav_files: List[str], start_batch_index: int, batch_size: int, state: Dict[str, Any], progress: Dict) -> Tuple[int, int]:
|
| 396 |
"""
|
| 397 |
-
|
|
|
|
| 398 |
Returns (next_batch_index, uploaded_count)
|
| 399 |
"""
|
| 400 |
batch_end = min(start_batch_index + batch_size, len(wav_files))
|
| 401 |
-
current_index = start_batch_index
|
| 402 |
uploaded_count = progress.get('uploaded_count', 0)
|
| 403 |
|
| 404 |
-
|
| 405 |
-
pending_tasks: Dict[asyncio.Task, Tuple[int, Path, WhisperServer]] = {}
|
| 406 |
-
|
| 407 |
-
print(f"[{FLOW_ID}] Processing batch from index {start_batch_index} to {batch_end}")
|
| 408 |
|
| 409 |
# --- Batch-level locking: mark all files in this batch as 'processing' and upload state
|
| 410 |
try:
|
|
|
|
| 411 |
for idx in range(start_batch_index, batch_end):
|
| 412 |
wav_file = wav_files[idx]
|
| 413 |
wav_name = Path(wav_file).name
|
| 414 |
-
state
|
| 415 |
-
# Only set to processing if it's not already processed/processing
|
| 416 |
-
if state["file_states"].get(wav_name) not in ("processing", "processed"):
|
| 417 |
-
state["file_states"][wav_name] = "processing"
|
| 418 |
|
| 419 |
-
#
|
| 420 |
state["next_download_index"] = batch_end
|
| 421 |
|
| 422 |
# Upload HF state to establish locks for this batch
|
| 423 |
if await upload_hf_state(state):
|
| 424 |
-
print(f"[{FLOW_ID}] ✅ Batch
|
| 425 |
else:
|
| 426 |
-
print(f"[{FLOW_ID}] ❌ Failed to upload batch lock
|
|
|
|
| 427 |
except Exception as e:
|
| 428 |
print(f"[{FLOW_ID}] Error while setting up batch locks: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
try:
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
if not available_servers:
|
| 437 |
-
# All servers busy, wait a bit
|
| 438 |
-
await asyncio.sleep(0.5)
|
| 439 |
-
continue
|
| 440 |
-
|
| 441 |
-
server = available_servers[0]
|
| 442 |
-
file_index = current_index
|
| 443 |
-
wav_file = wav_files[file_index]
|
| 444 |
-
wav_filename = Path(wav_file).name
|
| 445 |
-
|
| 446 |
-
# Download the WAV file
|
| 447 |
-
wav_path = await download_wav_file_by_index(file_index + 1, wav_file)
|
| 448 |
-
if not wav_path:
|
| 449 |
-
state["file_states"][wav_filename] = "failed"
|
| 450 |
-
# Persist failure to HF
|
| 451 |
-
await upload_hf_state(state)
|
| 452 |
-
current_index += 1
|
| 453 |
-
continue
|
| 454 |
-
|
| 455 |
-
# Assign to server and create task
|
| 456 |
-
await assign_file_to_server(file_index, server)
|
| 457 |
-
task = asyncio.create_task(send_audio_to_whisper(wav_path, server))
|
| 458 |
-
pending_tasks[task] = (file_index, wav_path, server)
|
| 459 |
-
|
| 460 |
-
current_index += 1
|
| 461 |
|
| 462 |
-
#
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
file_index, wav_path, server = pending_tasks.pop(task)
|
| 472 |
-
wav_filename = Path(wav_path).name
|
| 473 |
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
if
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
progress['uploaded_count'] = uploaded_count
|
| 484 |
-
save_progress(progress)
|
| 485 |
-
else:
|
| 486 |
-
state["file_states"][wav_filename] = "failed"
|
| 487 |
-
# Persist state change for this file immediately
|
| 488 |
-
await upload_hf_state(state)
|
| 489 |
else:
|
| 490 |
state["file_states"][wav_filename] = "failed"
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
except Exception as e:
|
| 494 |
-
print(f"[{FLOW_ID}] Error processing result for {wav_filename}: {e}")
|
| 495 |
state["file_states"][wav_filename] = "failed"
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
|
|
|
|
|
|
| 509 |
except Exception as e:
|
| 510 |
print(f"[{FLOW_ID}] Error in process_batch_dynamic: {e}")
|
|
|
|
|
|
|
| 511 |
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
async def process_dataset_task(start_index: int):
|
| 515 |
-
"""Main task to process the dataset using dynamic server assignment."""
|
| 516 |
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
|
| 531 |
-
|
| 532 |
-
|
| 533 |
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
|
| 578 |
-
|
| 579 |
-
|
| 580 |
|
| 581 |
-
|
| 582 |
-
print(f"[{FLOW_ID}]
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
print(f"[{FLOW_ID}] All files processed successfully!")
|
| 587 |
-
return True
|
| 588 |
-
|
| 589 |
-
except Exception as e:
|
| 590 |
-
print(f"[{FLOW_ID}] Critical error in process_dataset_task: {e}")
|
| 591 |
-
global_success = False
|
| 592 |
-
return global_success
|
| 593 |
|
| 594 |
# --- FastAPI App and Endpoints ---
|
| 595 |
|
|
|
|
| 394 |
|
| 395 |
async def process_batch_dynamic(wav_files: List[str], start_batch_index: int, batch_size: int, state: Dict[str, Any], progress: Dict) -> Tuple[int, int]:
    """
    Process one batch of WAV files in parallel using the available servers.

    Batch size = number of servers: each server is handed exactly one file of
    the batch, all transcriptions run concurrently, and results are persisted
    as they complete.

    Args:
        wav_files: Full ordered list of WAV file paths/names in the dataset.
        start_batch_index: 0-based index of the first file of this batch.
        batch_size: Maximum number of files in the batch (== len(servers)).
        state: Shared HF state dict; mutated in place ('file_states',
            'next_download_index') and uploaded via upload_hf_state().
        progress: Local progress dict; 'uploaded_count' is read and updated.

    Returns:
        (next_batch_index, uploaded_count). On a lock failure the original
        start_batch_index is returned so the caller can detect no progress.
    """
    batch_end = min(start_batch_index + batch_size, len(wav_files))
    uploaded_count = progress.get('uploaded_count', 0)

    print(f"[{FLOW_ID}] Processing batch from index {start_batch_index} to {batch_end - 1} ({batch_end - start_batch_index} files)")

    # --- Batch-level locking: mark all files in this batch as 'processing' and upload state
    try:
        state.setdefault("file_states", {})
        for idx in range(start_batch_index, batch_end):
            wav_file = wav_files[idx]
            wav_name = Path(wav_file).name
            # Only claim files not already locked or finished — unconditionally
            # overwriting would clobber 'processed' marks written by a
            # concurrent worker (this guard existed before and was dropped).
            if state["file_states"].get(wav_name) not in ("processing", "processed"):
                state["file_states"][wav_name] = "processing"

        # Update next_download_index to the end of this batch (0-based)
        state["next_download_index"] = batch_end

        # Upload HF state to establish locks for this batch
        if await upload_hf_state(state):
            print(f"[{FLOW_ID}] ✅ Batch locked: files {start_batch_index}-{batch_end - 1} marked 'processing'")
        else:
            print(f"[{FLOW_ID}] ❌ Failed to upload batch lock")
            return start_batch_index, uploaded_count
    except Exception as e:
        print(f"[{FLOW_ID}] Error while setting up batch locks: {e}")
        return start_batch_index, uploaded_count

    # --- Assign files to servers and create tasks
    pending_tasks: Dict[asyncio.Task, Tuple[int, Path, WhisperServer, str]] = {}

    try:
        for idx in range(start_batch_index, batch_end):
            file_index = idx
            wav_file = wav_files[file_index]
            wav_filename = Path(wav_file).name
            server = servers[idx - start_batch_index]  # Assign server in order (server 0 -> file 0, server 1 -> file 1, etc.)

            # Download the WAV file (download helper is 1-indexed)
            wav_path = await download_wav_file_by_index(file_index + 1, wav_file)
            if not wav_path:
                state["file_states"][wav_filename] = "failed"
                await upload_hf_state(state)
                print(f"[{FLOW_ID}] ❌ Failed to download {wav_filename}")
                continue

            # Assign to server and create task
            await assign_file_to_server(file_index, server)
            task = asyncio.create_task(send_audio_to_whisper(wav_path, server))
            pending_tasks[task] = (file_index, wav_path, server, wav_filename)
            print(f"[{FLOW_ID}] Assigned {wav_filename} to server {servers.index(server) + 1}")

        # --- Wait for all tasks in this batch to complete and process results
        while pending_tasks:
            # The still-pending set is tracked by pending_tasks itself,
            # so the second element of asyncio.wait()'s result is unused.
            done, _ = await asyncio.wait(
                pending_tasks.keys(),
                return_when=asyncio.FIRST_COMPLETED
            )

            for task in done:
                file_index, wav_path, server, wav_filename = pending_tasks.pop(task)

                try:
                    transcription_result = task.result()

                    if transcription_result:
                        # Upload transcription immediately
                        uploaded_ok = await upload_transcription_to_hf(wav_filename, transcription_result)
                        if uploaded_ok:
                            state["file_states"][wav_filename] = "processed"
                            uploaded_count += 1
                            progress['uploaded_count'] = uploaded_count
                            save_progress(progress)
                            print(f"[{FLOW_ID}] ✅ {wav_filename} uploaded (#{uploaded_count})")
                        else:
                            state["file_states"][wav_filename] = "failed"
                            print(f"[{FLOW_ID}] ❌ Failed to upload {wav_filename}")
                    else:
                        state["file_states"][wav_filename] = "failed"
                        print(f"[{FLOW_ID}] ❌ Transcription failed for {wav_filename}")

                    # Persist state change for this file to HF
                    await upload_hf_state(state)

                except Exception as e:
                    print(f"[{FLOW_ID}] Error processing result for {wav_filename}: {e}")
                    state["file_states"][wav_filename] = "failed"
                    await upload_hf_state(state)
                finally:
                    # Release the server
                    await release_server(server)
                    # Clean up the WAV file; missing_ok avoids a
                    # check-then-act race with exists().
                    wav_path.unlink(missing_ok=True)
    except Exception as e:
        print(f"[{FLOW_ID}] Error in process_batch_dynamic: {e}")

    return batch_end, uploaded_count
|
| 497 |
|
| 498 |
+
async def process_dataset_task(start_index: int):
    """Main task to process the dataset using dynamic server assignment.

    Args:
        start_index: 1-based index of the first file to process.

    Returns:
        True when all files were processed (or start_index is past the end),
        False on an empty file list, a stalled batch, or a critical error.
    """
    # Load both local progress and HF state
    progress = load_progress()
    current_state = await download_hf_state()
    file_list = await get_audio_file_list(progress)

    if not file_list:
        print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
        return False

    # Ensure start_index is within bounds
    if start_index > len(file_list):
        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
        return True

    # Determine the actual starting index in the 0-indexed list
    start_list_index = start_index - 1

    print(f"[{FLOW_ID}] Starting audio transcription from file index: {start_index} out of {len(file_list)}.")
    print(f"[{FLOW_ID}] Using {len(servers)} Whisper servers for dynamic processing.")
    print(f"[{FLOW_ID}] Upload pause enabled: {UPLOAD_PAUSE_ENABLED}, Max uploads before pause: {MAX_UPLOADS_BEFORE_PAUSE}")

    # Initialize progress tracking
    if 'uploaded_count' not in progress:
        progress['uploaded_count'] = 0

    # If there was no HF state in the repo, upload a fresh initial state file
    try:
        if not current_state.get("file_states") and current_state.get("next_download_index", 0) == 0:
            print(f"[{FLOW_ID}] No HF state detected; uploading initial state file to {HF_OUTPUT_DATASET_ID}...")
            # Ensure structure
            current_state.setdefault("file_states", {})
            current_state.setdefault("next_download_index", 0)
            if await upload_hf_state(current_state):
                print(f"[{FLOW_ID}] ✅ Initial HF state uploaded.")
            else:
                print(f"[{FLOW_ID}] ❌ Failed to upload initial HF state.")
    except Exception as e:
        print(f"[{FLOW_ID}] Error while uploading initial HF state: {e}")

    current_batch_index = start_list_index
    batch_size = len(servers)  # Batch size = number of servers

    try:
        while current_batch_index < len(file_list):
            # Process a batch dynamically
            next_index, uploaded_count = await process_batch_dynamic(
                file_list,
                current_batch_index,
                batch_size,
                current_state,
                progress
            )

            # process_batch_dynamic returns the unchanged start index when it
            # cannot acquire the batch lock; without this guard the loop
            # would retry the same batch forever.
            if next_index <= current_batch_index:
                print(f"[{FLOW_ID}] ERROR: No progress made at index {current_batch_index}; aborting.")
                return False

            # Update progress
            progress['last_processed_index'] = next_index
            progress['uploaded_count'] = uploaded_count
            save_progress(progress)

            # Update current batch index
            current_batch_index = next_index

            # Log statistics
            print(f"[{FLOW_ID}] Batch complete. Progress: {current_batch_index}/{len(file_list)}, Uploaded: {uploaded_count}")

            # Print server statistics
            print(f"[{FLOW_ID}] Server Statistics:")
            for i, server in enumerate(servers):
                print(f"  Server {i+1}: {server.total_processed} files, {server.total_time:.2f}s total, {server.fps:.2f} files/sec")

        print(f"[{FLOW_ID}] All files processed successfully!")
        return True

    except Exception as e:
        print(f"[{FLOW_ID}] Critical error in process_dataset_task: {e}")
        # Return failure directly; the old 'global_success' flag was dead
        # scaffolding (set True, only ever read right after being set False).
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
|
| 578 |
# --- FastAPI App and Endpoints ---
|
| 579 |
|