Spaces:
Paused
Paused
| import os | |
| import json | |
| import time | |
| import asyncio | |
| import aiohttp | |
| import zipfile | |
| import shutil | |
| from typing import Dict, List, Set, Optional, Tuple, Any | |
| from urllib.parse import quote | |
| from datetime import datetime | |
| from pathlib import Path | |
| import io | |
| from fastapi import FastAPI, BackgroundTasks, HTTPException, status | |
| from pydantic import BaseModel, Field | |
| from huggingface_hub import HfApi, hf_hub_download | |
# --- Configuration ---
# 1-based index used as the starting point when no saved progress is found.
AUTO_START_INDEX = 1

# Environment-driven settings (each falls back to the default shown).
FLOW_ID = os.environ.get("FLOW_ID", "flow_default")
FLOW_PORT = int(os.environ.get("FLOW_PORT", 8001))
HF_TOKEN = os.environ.get("HF_TOKEN", "")
HF_AUDIO_DATASET_ID = os.environ.get("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
HF_OUTPUT_DATASET_ID = os.environ.get("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")

# Progress and state tracking: a local progress file plus a state file that is
# mirrored to the output dataset.
PROGRESS_FILE = Path("processing_progress.json")
HF_STATE_FILE = "processing_state_transcriptions.json"
LOCAL_STATE_FOLDER = Path(".state")
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)

# Processing configuration
MAX_UPLOADS_BEFORE_PAUSE = 120  # Upload count reported as the pause threshold
UPLOAD_PAUSE_ENABLED = True

# Directory within the HF dataset where the audio files are located
AUDIO_FILE_PREFIX = "audio/"

# Twenty transcription endpoints, numbered 1 through 20.
WHISPER_SERVERS = [
    f"https://makeitfr-mineo-{server_number}.hf.space/transcribe"
    for server_number in range(1, 21)
]

# Temporary storage for downloaded audio files (one folder per flow).
TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)
| # --- Models --- | |
class ProcessStartRequest(BaseModel):
    # Request body for a manual processing start. `start_index` is validated
    # to be >= 1 and defaults to AUTO_START_INDEX.
    start_index: int = Field(
        default=AUTO_START_INDEX,
        ge=1,
        description="The index number of the audio file to start processing from (1-indexed).",
    )
class WhisperServer:
    """Tracks the busy/idle status and throughput of one Whisper endpoint.

    Attributes:
        url: Endpoint that audio files are posted to.
        is_processing: True while a file is assigned to this server.
        current_file_index: Index of the assigned file, or None when idle.
        total_processed: Number of files this server has transcribed.
        total_time: Accumulated transcription time in seconds.
    """

    def __init__(self, url: str):
        self.url = url
        self.is_processing = False
        self.current_file_index: Optional[int] = None
        self.total_processed = 0
        self.total_time = 0.0

    @property
    def fps(self) -> float:
        """Files per second (0.0 until some processing time is recorded).

        Exposed as a property because the statistics report reads it as an
        attribute (`server.fps`) inside a numeric format spec; a bound method
        there raises TypeError.
        """
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0.0

    def assign_file(self, file_index: int):
        """Mark the server busy with the given file index."""
        self.is_processing = True
        self.current_file_index = file_index

    def release(self):
        """Mark the server idle so it can accept a new file."""
        self.is_processing = False
        self.current_file_index = None
# Global state for whisper servers: one tracker object per configured endpoint.
servers = list(map(WhisperServer, WHISPER_SERVERS))
# Lock that guards reads/writes of the trackers' busy flags between coroutines.
server_lock = asyncio.Lock()
| # --- Progress and State Management Functions --- | |
def load_progress() -> Dict:
    """Load the local processing progress from PROGRESS_FILE.

    Returns:
        A dict that always contains the keys ``last_processed_index``,
        ``processed_files`` and ``file_list``. Values found in the file
        override the defaults; keys missing from the file are filled in so
        callers can index the result without a KeyError.

    A missing, unreadable, corrupted or non-object file yields the default
    structure (a warning is printed for everything except a missing file).
    """
    defaults = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],  # Cached list of audio files found in the dataset
    }

    if not PROGRESS_FILE.exists():
        return defaults

    try:
        with PROGRESS_FILE.open('r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
        return defaults

    if not isinstance(loaded, dict):
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
        return defaults

    # Saved values win; defaults only fill the gaps.
    return {**defaults, **loaded}
def save_progress(progress_data: Dict):
    """Save the local processing progress to PROGRESS_FILE.

    The data is written to a sibling ``.tmp`` file first and then moved over
    the real file with ``os.replace``, so an interrupted write cannot leave a
    truncated progress file behind.

    Failures are reported on stdout and otherwise swallowed (best effort):
    losing one progress checkpoint must not abort processing.
    """
    tmp_path = PROGRESS_FILE.with_name(PROGRESS_FILE.name + ".tmp")
    try:
        with tmp_path.open('w', encoding='utf-8') as f:
            json.dump(progress_data, f, indent=4)
        os.replace(tmp_path, PROGRESS_FILE)
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError: data that json cannot serialise.
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {e}")
        try:
            tmp_path.unlink()
        except OSError:
            # Nothing to clean up (the temp file was never created or is gone).
            pass
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load the shared state from a JSON file, migrating older layouts.

    Args:
        file_path: Path of the JSON state file.
        default_value: Returned as-is when the file is missing, unreadable,
            not valid JSON, or does not contain a JSON object.

    Returns:
        The loaded dict, guaranteed to contain a dict under ``file_states``
        and a ``next_download_index`` key (0 when absent from the file).
    """
    if not os.path.exists(file_path):
        return default_value

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        return default_value

    if not isinstance(data, dict):
        # Valid JSON but not an object (e.g. a list): cannot be migrated.
        print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        return default_value

    # Migration logic for state files written before these keys existed.
    if not isinstance(data.get("file_states"), dict):
        print(f"[{FLOW_ID}] Initializing 'file_states' dictionary.")
        data["file_states"] = {}
    data.setdefault("next_download_index", 0)
    return data
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Serialise `data` as 2-space-indented JSON into `file_path` (overwriting it)."""
    with open(file_path, mode="w") as sink:
        json.dump(data, sink, indent=2)
async def download_hf_state() -> Dict[str, Any]:
    """Download the shared state file from the output dataset.

    Returns:
        The parsed state, or a fresh default state
        (``{"next_download_index": 0, "file_states": {}}``) when the file is
        not present in the repository or anything goes wrong.

    The Hugging Face client calls are blocking, so they are run in the
    default executor to keep the event loop (and the API) responsive.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}
    loop = asyncio.get_running_loop()

    try:
        # Check whether the state file exists in the output dataset.
        files = await loop.run_in_executor(
            None,
            lambda: HfApi(token=HF_TOKEN).list_repo_files(
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
            ),
        )
        if HF_STATE_FILE not in files:
            print(f"[{FLOW_ID}] State file not found in {HF_OUTPUT_DATASET_ID}. Starting fresh.")
            return default_state

        # Download the file into the local state folder.
        await loop.run_in_executor(
            None,
            lambda: hf_hub_download(
                repo_id=HF_OUTPUT_DATASET_ID,
                filename=HF_STATE_FILE,
                repo_type="dataset",
                local_dir=LOCAL_STATE_FOLDER,
                local_dir_use_symlinks=False,
                token=HF_TOKEN,
            ),
        )
        print(f"[{FLOW_ID}] Successfully downloaded state file.")
        return load_json_state(str(local_path), default_state)
    except Exception as e:
        # Network/client boundary: any failure means "start from scratch".
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}. Starting fresh.")
        return default_state
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Write the state to the local state folder and upload it to the output dataset.

    Args:
        state: State dict to persist.

    Returns:
        True when the upload succeeded, False on any failure (the error is
        printed, never raised).

    The blocking upload runs in the default executor so the event loop is
    not stalled for the duration of the network call.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        # Save state locally first; the upload reads from this file.
        save_json_state(str(local_path), state)

        # .get(): a state without the key should not fail the upload.
        commit_message = (
            f"Update caption processing state: next_index={state.get('next_download_index')}"
        )
        await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: HfApi(token=HF_TOKEN).upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=HF_STATE_FILE,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=commit_message,
            ),
        )
        print(f"[{FLOW_ID}] Successfully uploaded state file.")
        return True
    except Exception as e:
        # Network/client boundary: report and signal failure to the caller.
        print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
        return False
async def lock_file_for_processing(zip_filename: str, state: Dict[str, Any]) -> bool:
    """Mark a file as 'processing' in the state and upload the state as a lock.

    Args:
        zip_filename: Key of the file inside ``state["file_states"]``.
        state: Shared state dict; mutated in place.

    Returns:
        True when the updated state was uploaded. On upload failure the
        file's previous status (or its absence) is restored locally and
        False is returned.
    """
    print(f"[{FLOW_ID}] 🔒 Attempting to lock file: {zip_filename}")

    file_states = state.setdefault("file_states", {})
    # Remember what was there so a failed lock does not erase earlier status.
    not_set = object()
    previous_status = file_states.get(zip_filename, not_set)

    # Update state locally
    file_states[zip_filename] = "processing"

    # Upload the updated state file immediately to establish the lock
    if await upload_hf_state(state):
        print(f"[{FLOW_ID}] ✅ Successfully locked file: {zip_filename}")
        return True

    print(f"[{FLOW_ID}] ❌ Failed to lock file: {zip_filename}")
    # Revert local state to exactly what it was before the attempt.
    if previous_status is not_set:
        file_states.pop(zip_filename, None)
    else:
        file_states[zip_filename] = previous_status
    return False
async def unlock_file_as_processed(zip_filename: str, state: Dict[str, Any], next_index: int) -> bool:
    """Record a file as 'processed', store the next index, and upload the state.

    The local state is updated before the upload is attempted and is left
    updated even when the upload fails. Returns whether the upload succeeded.
    """
    print(f"[{FLOW_ID}] 🔓 Marking file as processed: {zip_filename}")

    # Local bookkeeping first.
    state["file_states"][zip_filename] = "processed"
    state["next_download_index"] = next_index

    # Then persist it remotely and report the outcome.
    upload_succeeded = await upload_hf_state(state)
    if not upload_succeeded:
        print(f"[{FLOW_ID}] ❌ Failed to update state for: {zip_filename}")
        return False

    print(f"[{FLOW_ID}] ✅ Successfully marked as processed: {zip_filename}")
    return True
| # --- Hugging Face Utility Functions --- | |
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted list of WAV files in the audio dataset.

    Args:
        progress_data: Local progress dict. A non-empty ``file_list`` entry
            is used as a cache and returned directly; otherwise the list is
            fetched, stored under ``file_list`` and saved via save_progress.

    Returns:
        Repository paths ending in ``.wav``, sorted alphabetically so indices
        stay stable between runs. An empty list when the listing fails or the
        dataset contains no WAV files (the reason is printed).
    """
    cached_list = progress_data.get('file_list')
    if cached_list:
        print(f"[{FLOW_ID}] Using cached file list with {len(cached_list)} files.")
        return cached_list

    print(f"[{FLOW_ID}] Fetching full list of WAV files from {HF_AUDIO_DATASET_ID}...")
    try:
        # The listing call blocks on the network; run it off the event loop.
        repo_files = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: HfApi(token=HF_TOKEN).list_repo_files(
                repo_id=HF_AUDIO_DATASET_ID,
                repo_type="dataset",
            ),
        )
    except Exception as e:
        # Network/client boundary: report and return "nothing to do".
        print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {e}")
        return []

    # Filter for WAV files and sort them alphabetically for consistent indexing
    wav_files = sorted(f for f in repo_files if f.endswith('.wav'))
    if not wav_files:
        print(
            f"[{FLOW_ID}] Error fetching file list from Hugging Face: "
            f"No WAV files found in dataset '{HF_AUDIO_DATASET_ID}'."
        )
        return []

    print(f"[{FLOW_ID}] Found {len(wav_files)} WAV files.")

    # Cache the list in the progress data and persist it.
    progress_data['file_list'] = wav_files
    save_progress(progress_data)
    return wav_files
async def download_wav_file_by_index(file_index: int, repo_file_full_path: str) -> Optional[Path]:
    """Download one WAV file from the audio dataset into TEMP_DIR.

    Args:
        file_index: Position of the file, used only in log output.
        repo_file_full_path: Path of the file inside the dataset repository.

    Returns:
        Local path of the downloaded file, or None when the download failed
        (the error is printed, never raised).

    The blocking download runs in the default executor so transcription
    requests that are already in flight keep making progress meanwhile.
    """
    print(f"[{FLOW_ID}] Downloading file #{file_index}: {repo_file_full_path}")

    try:
        # Download the file into our TEMP_DIR (so we can safely delete it later)
        wav_path = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: hf_hub_download(
                repo_id=HF_AUDIO_DATASET_ID,
                filename=repo_file_full_path,
                repo_type="dataset",
                token=HF_TOKEN,
                local_dir=str(TEMP_DIR),
                local_dir_use_symlinks=False,
            ),
        )
        print(f"[{FLOW_ID}] Downloaded WAV file to {wav_path}")
        return Path(wav_path)
    except Exception as e:
        # Network/client boundary: the caller decides whether to retry.
        print(f"[{FLOW_ID}] Error downloading WAV file {repo_file_full_path}: {e}")
        return None
async def upload_transcription_to_hf(wav_filename: str, transcription_data: Dict) -> bool:
    """Upload a transcription as a JSON file to the output dataset.

    Args:
        wav_filename: Repository path of the source WAV file. The target
            name is this path with slashes/backslashes replaced by
            underscores and the last extension replaced by ``.json``.
        transcription_data: JSON-serialisable transcription payload.

    Returns:
        True when the upload succeeded, False otherwise (the error is
        printed, never raised).

    The blocking upload runs in the default executor so the event loop is
    not stalled for the duration of the network call.
    """
    # Use the full WAV path, replacing slashes with underscores and extension with .json
    json_filename = wav_filename.replace('/', '_').replace('\\', '_').rsplit('.', 1)[0] + '.json'

    try:
        print(f"[{FLOW_ID}] Uploading transcription for {wav_filename} as {json_filename} to {HF_OUTPUT_DATASET_ID}...")

        # Create JSON content in memory
        json_content = json.dumps(transcription_data, indent=2, ensure_ascii=False).encode('utf-8')

        await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: HfApi(token=HF_TOKEN).upload_file(
                path_or_fileobj=io.BytesIO(json_content),
                path_in_repo=json_filename,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=f"[{FLOW_ID}] Transcription for {wav_filename}",
            ),
        )
        print(f"[{FLOW_ID}] Successfully uploaded transcription for {wav_filename}.")
        return True
    except Exception as e:
        # Network/client boundary: the caller decides whether to retry.
        print(f"[{FLOW_ID}] Error uploading transcription for {wav_filename}: {e}")
        return False
| # --- Core Processing Functions --- | |
async def send_audio_to_whisper(wav_path: Path, server: WhisperServer) -> Optional[Dict]:
    """Send a WAV file to a Whisper server and return the transcription record.

    Args:
        wav_path: Local path of the WAV file to upload.
        server: Target server; its ``total_processed`` / ``total_time``
            counters are updated on success.

    Returns:
        A dict with keys ``file``, ``transcription``, ``timestamp`` and
        ``processing_time_seconds`` on HTTP 200; None on any other status,
        on timeout, or on any other error (the reason is printed).
    """
    try:
        print(f"[{FLOW_ID}] Sending {wav_path.name} to {server.url}...")
        # Monotonic clock: the elapsed time is unaffected by wall-clock changes.
        started = time.monotonic()

        # 10 minute overall timeout for transcription. A ClientTimeout object
        # is used because passing a bare number as `timeout` is deprecated.
        request_timeout = aiohttp.ClientTimeout(total=600)

        # Keep the file open for the whole request so the descriptor is
        # closed only after the body has been sent.
        with wav_path.open('rb') as f:
            # Prepare multipart form data
            form_data = aiohttp.FormData()
            form_data.add_field('file', f, filename=wav_path.name, content_type='audio/wav')

            async with aiohttp.ClientSession(timeout=request_timeout) as session:
                async with session.post(server.url, data=form_data) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        print(f"[{FLOW_ID}] ✗ Error from {server.url}: {resp.status} - {error_text}")
                        return None

                    result = await resp.json()
                    elapsed = time.monotonic() - started

                    # Update server stats
                    server.total_processed += 1
                    server.total_time += elapsed

                    print(f"[{FLOW_ID}] ✓ {wav_path.name} transcribed successfully by {server.url}")
                    return {
                        "file": wav_path.name,
                        "transcription": result,
                        "timestamp": datetime.now().isoformat(),
                        "processing_time_seconds": elapsed,
                    }
    except asyncio.TimeoutError:
        print(f"[{FLOW_ID}] ✗ Timeout from {server.url} for {wav_path.name}")
        return None
    except Exception as e:
        # Network boundary: any failure is reported and turned into None so
        # the caller's retry logic can take over.
        print(f"[{FLOW_ID}] ✗ Exception on {server.url} for {wav_path.name}: {e}")
        return None
async def get_available_servers() -> List[WhisperServer]:
    """
    Take a snapshot, under the server lock, of every server whose
    `is_processing` flag is currently False, preserving the order of `servers`.
    """
    idle_servers: List[WhisperServer] = []
    async with server_lock:
        for candidate in servers:
            if candidate.is_processing:
                continue
            idle_servers.append(candidate)
    return idle_servers
async def assign_file_to_server(file_index: int, server: WhisperServer):
    """Mark `server` as busy with `file_index` while holding the server lock."""
    await server_lock.acquire()
    try:
        server.assign_file(file_index)
    finally:
        server_lock.release()
async def release_server(server: WhisperServer):
    """Mark `server` as idle again while holding the server lock."""
    await server_lock.acquire()
    try:
        server.release()
    finally:
        server_lock.release()
async def process_batch_dynamic(wav_files: List[str], start_batch_index: int, batch_size: int, state: Dict[str, Any], progress: Dict) -> Tuple[int, int]:
    """
    Process one batch of WAV files in parallel on the available servers.

    Steps:
      1. Mark every file of the batch as 'processing', set
         ``next_download_index`` to the end of the batch and upload the state
         (batch-level lock). If that fails, nothing is processed.
      2. Feed files to idle servers; whenever a transcription finishes its
         result is uploaded and the freed server receives the next file.
         Failed downloads, transcriptions and uploads are re-queued until
         ``max_retries`` is exceeded, after which the file is marked with a
         ``failed_*`` status.
      3. Upload the state once more when the batch is drained.

    Args:
        wav_files: Full, ordered list of repository paths.
        start_batch_index: 0-based index of the first file of the batch.
        batch_size: Maximum number of files in the batch.
        state: Shared state dict; mutated in place.
        progress: Local progress dict; ``uploaded_count`` is updated and saved.

    Returns:
        ``(next_batch_index, uploaded_count)``. ``next_batch_index`` equals
        ``start_batch_index`` when the batch lock could not be established
        and the end of the batch otherwise.
    """
    batch_end = min(start_batch_index + batch_size, len(wav_files))
    uploaded_count = progress.get('uploaded_count', 0)
    max_retries = 3

    print(f"[{FLOW_ID}] Processing batch from index {start_batch_index} to {batch_end - 1} ({batch_end - start_batch_index} files)")

    # --- Batch-level locking: mark all files in this batch as 'processing' and upload state
    try:
        state.setdefault("file_states", {})
        for idx in range(start_batch_index, batch_end):
            state["file_states"][wav_files[idx]] = "processing"
        # Update next_download_index to the end of this batch (0-based)
        state["next_download_index"] = batch_end

        # Upload HF state to establish locks for this batch
        if await upload_hf_state(state):
            print(f"[{FLOW_ID}] ✅ Batch locked: files {start_batch_index}-{batch_end - 1} marked 'processing'")
        else:
            print(f"[{FLOW_ID}] ❌ Failed to upload batch lock")
            return start_batch_index, uploaded_count
    except Exception as e:
        print(f"[{FLOW_ID}] Error while setting up batch locks: {e}")
        return start_batch_index, uploaded_count

    # Work queue entries: (idx, wav_file, retry_count)
    files_to_process = [(idx, wav_files[idx], 0) for idx in range(start_batch_index, batch_end)]

    # In-flight transcriptions: task -> (idx, local path, server, repo path, retry_count)
    pending_tasks: Dict[asyncio.Task, Tuple[int, Path, WhisperServer, str, int]] = {}

    def requeue_or_fail(file_idx: int, wav_file: str, retry_count: int, what: str, failed_state: str) -> None:
        """Re-queue a failed file, or record `failed_state` once retries are used up."""
        wav_filename = Path(wav_file).name
        if retry_count < max_retries:
            print(f"[{FLOW_ID}] ⚠️ {what} for {wav_filename} (retry {retry_count + 1}/{max_retries}), re-queueing...")
            files_to_process.append((file_idx, wav_file, retry_count + 1))
        else:
            state["file_states"][wav_file] = failed_state
            print(f"[{FLOW_ID}] ❌ {what} permanently for {wav_filename} after {max_retries} retries")

    try:
        while files_to_process or pending_tasks:
            # Assign new files to available servers
            while files_to_process:
                available = await get_available_servers()
                if not available:
                    break

                file_idx, wav_file, retry_count = files_to_process.pop(0)
                wav_filename = Path(wav_file).name
                server = available[0]

                # Download the WAV file
                wav_path = await download_wav_file_by_index(file_idx + 1, wav_file)
                if not wav_path:
                    requeue_or_fail(file_idx, wav_file, retry_count, "Download failed", "failed_download")
                    continue

                # Assign to server and create task
                await assign_file_to_server(file_idx, server)
                task = asyncio.create_task(send_audio_to_whisper(wav_path, server))
                pending_tasks[task] = (file_idx, wav_path, server, wav_file, retry_count)
                print(f"[{FLOW_ID}] Assigned {wav_filename} to server {servers.index(server) + 1}")

            # Nothing in flight: either the queue is drained, or no server is
            # free and none of ours will ever finish, so stop instead of spinning.
            if not pending_tasks:
                if files_to_process:
                    print(f"[{FLOW_ID}] ⚠️ No server available; {len(files_to_process)} file(s) of this batch left unprocessed")
                break

            # Wait for at least one task to complete
            done, _ = await asyncio.wait(
                list(pending_tasks),
                return_when=asyncio.FIRST_COMPLETED
            )

            for task in done:
                file_idx, wav_path, server, wav_file, retry_count = pending_tasks.pop(task)
                wav_filename = Path(wav_file).name

                try:
                    transcription_result = task.result()

                    if not transcription_result:
                        requeue_or_fail(file_idx, wav_file, retry_count, "Transcription failed", "failed_transcription")
                    # Upload transcription immediately with full path
                    elif await upload_transcription_to_hf(wav_file, transcription_result):
                        # Update state locally but do NOT upload to HF yet
                        state["file_states"][wav_file] = "processed"
                        uploaded_count += 1
                        progress['uploaded_count'] = uploaded_count
                        save_progress(progress)
                        print(f"[{FLOW_ID}] ✅ {wav_filename} uploaded (#{uploaded_count})")
                    else:
                        requeue_or_fail(file_idx, wav_file, retry_count, "Upload failed", "failed_upload")

                except Exception as e:
                    if retry_count < max_retries:
                        print(f"[{FLOW_ID}] ⚠️ Error processing {wav_filename}: {e} (retry {retry_count + 1}/{max_retries}), re-queueing...")
                        files_to_process.append((file_idx, wav_file, retry_count + 1))
                    else:
                        print(f"[{FLOW_ID}] ❌ Error processing {wav_filename}: {e} (failed after {max_retries} retries)")
                        state["file_states"][wav_file] = "failed_error"

                finally:
                    # Release the server
                    await release_server(server)
                    # Clean up the WAV file (a retry downloads it again)
                    if wav_path.exists():
                        wav_path.unlink()

        # --- After all files in this batch are handled, update HF state once
        if await upload_hf_state(state):
            print(f"[{FLOW_ID}] ✅ Batch state updated on HF: files {start_batch_index}-{batch_end - 1} marked processed")
        else:
            print(f"[{FLOW_ID}] ❌ Failed to update batch state on HF")

    except Exception as e:
        print(f"[{FLOW_ID}] Error in process_batch_dynamic: {e}")

    finally:
        # Only non-empty when the loop above was interrupted: stop the
        # orphaned requests and free their servers, otherwise those servers
        # would stay marked busy forever.
        if pending_tasks:
            for task in pending_tasks:
                task.cancel()
            await asyncio.gather(*pending_tasks, return_exceptions=True)
            for _, leftover_path, leftover_server, _, _ in pending_tasks.values():
                await release_server(leftover_server)
                if leftover_path.exists():
                    leftover_path.unlink()
            pending_tasks.clear()

    return batch_end, uploaded_count
async def process_dataset_task(start_index: int):
    """Main task: transcribe the dataset batch by batch, starting at `start_index`.

    Args:
        start_index: 1-based index of the first file to process. Values
            below 1 are treated as 1.

    Returns:
        True when the loop ran to the end of the file list (or there was
        nothing to do because `start_index` is past the end); False when the
        file list is empty or an unexpected error aborted the loop.

    Batches contain one file per configured server, and consecutive batch
    starts are kept at least ``batch_interval_seconds`` apart.
    """
    # Load both local progress and HF state
    progress = load_progress()
    current_state = await download_hf_state()

    file_list = await get_audio_file_list(progress)
    if not file_list:
        print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
        return False

    # A start index below 1 would turn into a negative list index and wrap
    # around to the end of the list.
    if start_index < 1:
        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is below 1. Using 1 instead.")
        start_index = 1

    # Ensure start_index is within bounds
    if start_index > len(file_list):
        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
        return True

    # Determine the actual starting index in the 0-indexed list
    start_list_index = start_index - 1

    print(f"[{FLOW_ID}] Starting audio transcription from file index: {start_index} out of {len(file_list)}.")
    print(f"[{FLOW_ID}] Using {len(servers)} Whisper servers for dynamic processing.")
    print(f"[{FLOW_ID}] Upload pause enabled: {UPLOAD_PAUSE_ENABLED}, Max uploads before pause: {MAX_UPLOADS_BEFORE_PAUSE}")

    # Initialize progress tracking
    progress.setdefault('uploaded_count', 0)

    # If there was no HF state in the repo, upload a fresh initial state file
    try:
        if not current_state.get("file_states") and current_state.get("next_download_index", 0) == 0:
            print(f"[{FLOW_ID}] No HF state detected; uploading initial state file to {HF_OUTPUT_DATASET_ID}...")
            # Ensure structure
            current_state.setdefault("file_states", {})
            current_state.setdefault("next_download_index", 0)
            if await upload_hf_state(current_state):
                print(f"[{FLOW_ID}] ✅ Initial HF state uploaded.")
            else:
                print(f"[{FLOW_ID}] ❌ Failed to upload initial HF state.")
    except Exception as e:
        print(f"[{FLOW_ID}] Error while uploading initial HF state: {e}")

    current_batch_index = start_list_index
    batch_size = len(servers)  # One file per server in each batch
    batch_interval_seconds = 600  # 600 seconds = 10 minutes (enforces max 6 batches per hour)

    try:
        batch_count = 0
        while current_batch_index < len(file_list):
            batch_start_time = time.time()

            # Process a batch dynamically
            next_index, uploaded_count = await process_batch_dynamic(
                file_list,
                current_batch_index,
                batch_size,
                current_state,
                progress
            )
            batch_elapsed = time.time() - batch_start_time

            # Update progress
            progress['last_processed_index'] = next_index
            progress['uploaded_count'] = uploaded_count
            save_progress(progress)

            # Update current batch index
            current_batch_index = next_index
            batch_count += 1

            # Log statistics
            print(f"[{FLOW_ID}] Batch complete. Progress: {current_batch_index}/{len(file_list)}, Uploaded: {uploaded_count}")

            # Print server statistics
            print(f"[{FLOW_ID}] Server Statistics:")
            for i, server in enumerate(servers):
                # `fps` may be exposed as a method or as a property; formatting
                # a bound method with ':.2f' raises TypeError and would abort
                # the whole run, so resolve it to a number first.
                rate = server.fps() if callable(server.fps) else server.fps
                print(f" Server {i+1}: {server.total_processed} files, {server.total_time:.2f}s total, {rate:.2f} files/sec")

            # Rate limiting: enforce minimum 10 minutes between batch starts (max 6 batches per hour)
            if current_batch_index < len(file_list):  # Don't wait after the last batch
                wait_time = batch_interval_seconds - batch_elapsed
                if wait_time > 0:
                    print(f"[{FLOW_ID}] Rate limit: batch took {batch_elapsed:.1f}s. Waiting {wait_time:.1f}s before next batch (min 10 min interval)...")
                    await asyncio.sleep(wait_time)
                else:
                    print(f"[{FLOW_ID}] Batch took {batch_elapsed:.1f}s (exceeded 10 min interval). Proceeding immediately to next batch.")

        print(f"[{FLOW_ID}] All files processed successfully! Total batches: {batch_count}")
        return True

    except Exception as e:
        # Top-level boundary of the background task: report and signal failure.
        print(f"[{FLOW_ID}] Critical error in process_dataset_task: {e}")
        return False
| # --- FastAPI App and Endpoints --- | |
# The description is shown in the generated API docs, so it describes what
# this service does: WAV transcription via the Whisper server pool.
app = FastAPI(
    title=f"Flow Server {FLOW_ID} API",
    description="Transcribes WAV audio files from a Hugging Face dataset using a pool of Whisper servers, uploads the transcriptions, and tracks progress.",
    version="1.0.0",
)
@app.on_event("startup")
async def startup_event():
    """Resume processing automatically when the server starts.

    The start position comes from the HF state when it holds a positive
    ``next_download_index``; otherwise from the local progress file, never
    lower than AUTO_START_INDEX. Processing runs as a background task so
    server startup is not blocked.
    """
    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.")

    # Get both local progress and HF state
    progress = load_progress()
    current_state = await download_hf_state()

    # Get the next_download_index from HF state if available
    hf_next_index = current_state.get("next_download_index", 0)

    if hf_next_index > 0:
        # next_download_index is the 0-based index of the next file, while
        # process_dataset_task expects a 1-based index. Without the +1 the
        # last file of the previous batch would be processed again.
        start_index = hf_next_index + 1
        print(f"[{FLOW_ID}] Using next_download_index from HF state: {start_index}")
    else:
        # Fall back to local progress if HF state doesn't have a meaningful index
        start_index = max(progress.get('last_processed_index', 0) + 1, AUTO_START_INDEX)

    # Startup handlers cannot use BackgroundTasks, so the long-running job is
    # scheduled directly on the event loop. The task is stored on app.state
    # because the loop only keeps weak references to tasks; an unreferenced
    # task may be garbage-collected before it finishes.
    print(f"[{FLOW_ID}] Auto-starting processing from index: {start_index}...")
    app.state.processing_task = asyncio.create_task(process_dataset_task(start_index))
@app.get("/")
async def root():
    """Return a status summary: saved progress plus live server statistics."""
    progress = load_progress()

    # Calculate server stats
    total_processed = sum(s.total_processed for s in servers)
    total_time = sum(s.total_time for s in servers)
    avg_fps = total_processed / total_time if total_time > 0 else 0

    uploaded_count = progress.get('uploaded_count', 0)

    return {
        "flow_id": FLOW_ID,
        "status": "ready",
        "last_processed_index": progress.get('last_processed_index', 0),
        # .get(): a progress file without a cached list must not cause a 500.
        "total_files_in_list": len(progress.get('file_list', [])),
        "uploaded_count": uploaded_count,
        "total_servers": len(servers),
        "processing_servers": sum(1 for s in servers if s.is_processing),
        "total_files_processed_by_servers": total_processed,
        "avg_files_per_second": avg_fps,
        "upload_limit_paused": uploaded_count >= MAX_UPLOADS_BEFORE_PAUSE
    }
# NOTE(review): the route path is an assumption derived from the function
# name; no registration is visible in this file. Confirm against API clients.
@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest, background_tasks: BackgroundTasks):
    """
    Start processing the audio files from the given 1-based index in the background.

    Returns immediately with a confirmation payload; the work itself runs
    as a background task after the response is sent.
    """
    start_index = request.start_index

    print(f"[{FLOW_ID}] Received request to start processing from index: {start_index}. Starting background task.")

    # Start the heavy processing in a background task so the API call returns immediately.
    # NOTE(review): the server also auto-starts processing; a manual start does
    # not stop a run that is already active, so both would share the same
    # servers and state. Confirm this is the intended override behaviour.
    background_tasks.add_task(process_dataset_task, start_index)

    return {"status": "processing", "start_index": start_index, "message": "Dataset processing started in background."}
if __name__ == "__main__":
    # Imported here so the module can be loaded without uvicorn installed.
    import uvicorn

    # Listen on every interface ("0.0.0.0") so the port is reachable from
    # outside the sandbox the server runs in.
    uvicorn.run(app, port=FLOW_PORT, host="0.0.0.0")