"""Flow server: pulls WAV files from a Hugging Face dataset and sends them to
a pool of Whisper transcription servers, tracking progress locally and in a
state file stored in the output dataset."""

import os
import json
import time
import asyncio
import zipfile
import shutil
import io
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from urllib.parse import quote

import aiohttp
from fastapi import FastAPI, BackgroundTasks, HTTPException, status
from pydantic import BaseModel, Field
from huggingface_hub import HfApi, hf_hub_download

# --- Configuration ---
AUTO_START_INDEX = 1  # Hardcoded default start index if no progress is found
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")

# Progress and State Tracking
PROGRESS_FILE = Path("processing_progress.json")
HF_STATE_FILE = "processing_state_transcriptions.json"
LOCAL_STATE_FOLDER = Path(".state")
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)

# Processing configuration
MAX_UPLOADS_BEFORE_PAUSE = 120  # Pause uploading after 120 files
UPLOAD_PAUSE_ENABLED = True
NEW_FILES_CHUNK_SIZE = 100  # How many not-yet-seen files to consider per loop pass
TRANSCRIBE_TIMEOUT_SECONDS = 600  # Total time allowed for one transcription request

# Directory within the HF dataset where the audio files are located
# NOTE(review): this prefix is not referenced anywhere in this file; the file
# list is filtered by the '.wav' suffix only - confirm whether that is intended.
AUDIO_FILE_PREFIX = "audio/"

WHISPER_SERVERS = [
    f"https://makeitfr-mineo-{i}.hf.space/transcribe" for i in range(1, 21)
]

# Temporary storage for audio files
TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)


# --- Models ---
class ProcessStartRequest(BaseModel):
    """Request body for POST /start_processing."""

    start_index: int = Field(
        AUTO_START_INDEX,
        ge=1,
        description="The index number of the audio file to start processing from (1-indexed).",
    )


class WhisperServer:
    """Bookkeeping for one remote transcription endpoint."""

    def __init__(self, url: str):
        self.url = url
        self.is_processing = False
        self.current_file_index: Optional[int] = None
        self.total_processed = 0
        self.total_time = 0.0

    @property
    def fps(self) -> float:
        """Files per second (0.0 until at least one timed request completed)."""
        return self.total_processed / self.total_time if self.total_time > 0 else 0.0

    def assign_file(self, file_index: int):
        """Assign a file index to this server"""
        self.is_processing = True
        self.current_file_index = file_index

    def release(self):
        """Release the server for a new file"""
        self.is_processing = False
        self.current_file_index = None


# Global state for whisper servers
servers = [WhisperServer(url) for url in WHISPER_SERVERS]
# Serialises server selection between coroutines on the event loop
# (an asyncio.Lock does not protect against access from other threads).
server_lock = asyncio.Lock()

# Strong references to fire-and-forget tasks; the event loop itself only keeps
# weak references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks: Set[asyncio.Task] = set()


async def _run_blocking(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run a blocking callable in the default executor and return its result.

    The huggingface_hub calls used below are synchronous; calling them directly
    inside a coroutine would stall every other task (and the HTTP API) until
    they return.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(func, *args, **kwargs))


def _coerce_index(value: Any) -> int:
    """Return ``value`` as a non-negative int, or 0 when it is not usable."""
    try:
        return max(0, int(value))
    except (TypeError, ValueError):
        return 0


# --- Progress and State Management Functions ---
def load_progress() -> Dict:
    """Loads the local processing progress from the JSON file.

    Missing keys are filled in from the default structure. A missing file,
    corrupted JSON, or JSON that is not an object yields the default structure.
    """
    default_structure = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],  # Full list of all .wav files found in the dataset
        "uploaded_count": 0,
    }
    if PROGRESS_FILE.exists():
        try:
            with PROGRESS_FILE.open('r') as f:
                data = json.load(f)
            if isinstance(data, dict):
                # Ensure all keys exist
                for key, value in default_structure.items():
                    data.setdefault(key, value)
                return data
            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
    return default_structure


def save_progress(progress_data: Dict):
    """Saves the local processing progress to the JSON file.

    Errors are reported on stdout and not re-raised (best effort).
    """
    try:
        with PROGRESS_FILE.open('w') as f:
            json.dump(progress_data, f, indent=4)
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {e}")


def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load state from JSON file with migration logic for new structure.

    Guarantees that a loaded state has a dict under "file_states" and a
    "next_download_index" key. Returns ``default_value`` when the file is
    missing, is not valid JSON, or does not contain a JSON object.
    """
    if os.path.exists(file_path):
        try:
            with open(file_path, "r") as f:
                data = json.load(f)
            if isinstance(data, dict):
                if not isinstance(data.get("file_states"), dict):
                    data["file_states"] = {}
                data.setdefault("next_download_index", 0)
                return data
            print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
    return default_value


def save_json_state(file_path: str, data: Dict[str, Any]):
    """Save state to JSON file"""
    with open(file_path, "w") as f:
        json.dump(data, f, indent=2)


async def download_hf_state() -> Dict[str, Any]:
    """Downloads the state file from Hugging Face or returns a default state.

    When the download fails, whatever copy already exists locally is used; if
    there is none, the default state is returned.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}
    try:
        await _run_blocking(
            hf_hub_download,
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}. Using local/default.")
    return load_json_state(str(local_path), default_state)


async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Uploads the state file to Hugging Face.

    The state is first written to the local state folder and that file is then
    uploaded. Returns True on success, False if saving or uploading failed.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        save_json_state(str(local_path), state)
        await _run_blocking(
            HfApi(token=HF_TOKEN).upload_file,
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Update transcription state: next_index={state.get('next_download_index')}",
        )
        return True
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
        return False


# --- Hugging Face Utility Functions ---
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted list of .wav paths in the audio dataset.

    A list already cached in ``progress_data['file_list']`` is returned as is;
    otherwise the repository is listed, the result is cached in
    ``progress_data`` and saved. Returns an empty list when listing fails.
    """
    if progress_data.get('file_list'):
        return progress_data['file_list']
    try:
        api = HfApi(token=HF_TOKEN)
        repo_files = await _run_blocking(
            api.list_repo_files, repo_id=HF_AUDIO_DATASET_ID, repo_type="dataset"
        )
        wav_files = sorted(f for f in repo_files if f.endswith('.wav'))
        progress_data['file_list'] = wav_files
        save_progress(progress_data)
        return wav_files
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list: {e}")
        return []


# --- Core Processing Logic ---
async def transcribe_with_server(server: WhisperServer, wav_path: Path) -> Optional[Dict]:
    """POST ``wav_path`` to ``server`` and return the decoded JSON response.

    On HTTP 200 the server's counters (``total_processed``, ``total_time``) are
    updated. Returns None on a non-200 status or on any exception.
    """
    start_time = time.time()
    # aiohttp expects a ClientTimeout object; a bare number is deprecated.
    timeout = aiohttp.ClientTimeout(total=TRANSCRIBE_TIMEOUT_SECONDS)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            with open(wav_path, 'rb') as f:
                form = aiohttp.FormData()
                form.add_field('file', f, filename=wav_path.name)
                async with session.post(server.url, data=form) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        elapsed = time.time() - start_time
                        server.total_processed += 1
                        server.total_time += elapsed
                        return result
                    print(f"[{FLOW_ID}] Server {server.url} returned status {resp.status}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error transcribing with {server.url}: {e}")
    return None


async def _acquire_server() -> WhisperServer:
    """Wait until a server is idle, mark it busy and return it."""
    while True:
        async with server_lock:
            for candidate in servers:
                if not candidate.is_processing:
                    candidate.is_processing = True
                    return candidate
        await asyncio.sleep(1)


async def process_file_task(wav_file: str, state: Dict, progress: Dict):
    """Download one audio file, transcribe it and record the outcome.

    ``state["file_states"][wav_file]`` is set to "processed" or
    "failed_transcription"; on success ``progress["uploaded_count"]`` is
    incremented. The server is released and the downloaded file removed
    whatever the outcome.
    """
    server = await _acquire_server()
    # With local_dir, hf_hub_download keeps the repo's folder structure, so the
    # file does not necessarily sit directly in TEMP_DIR under its basename.
    local_path = TEMP_DIR / wav_file
    try:
        downloaded = await _run_blocking(
            hf_hub_download,
            repo_id=HF_AUDIO_DATASET_ID,
            filename=wav_file,
            repo_type="dataset",
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        local_path = Path(downloaded)

        result = await transcribe_with_server(server, local_path)

        if result:
            # NOTE(review): the transcription payload is not stored or uploaded
            # anywhere in this file even though "uploaded_count" is bumped -
            # confirm whether an upload to HF_OUTPUT_DATASET_ID is missing.
            state["file_states"][wav_file] = "processed"
            progress["uploaded_count"] = progress.get("uploaded_count", 0) + 1
            print(f"[{FLOW_ID}] ✅ Success: {wav_file}")
        else:
            state["file_states"][wav_file] = "failed_transcription"
            print(f"[{FLOW_ID}] ❌ Failed: {wav_file}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error processing {wav_file}: {e}")
        state["file_states"][wav_file] = "failed_transcription"
    finally:
        server.release()
        # Clean up on every path so failed runs do not fill the disk.
        try:
            local_path.unlink()
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"[{FLOW_ID}] Could not remove temp file {local_path}: {e}")


async def main_processing_loop():
    """Endless worker loop.

    Each pass: fetch the shared state, retry every file marked
    "failed_transcription", take the next chunk of files starting at
    ``next_download_index``, process everything in batches of one file per
    server, and upload state / save progress after each batch.
    """
    print(f"[{FLOW_ID}] Starting main processing loop...")
    while True:
        try:
            state = await download_hf_state()
            progress = load_progress()
            file_list = await get_audio_file_list(progress)

            if not file_list:
                print(f"[{FLOW_ID}] File list empty, retrying in 60s...")
                await asyncio.sleep(60)
                continue

            file_states = state.setdefault("file_states", {})

            # 1. Handpick failed_transcription files
            failed_files = [f for f, s in file_states.items() if s == "failed_transcription"]

            # 2. Also check for new files based on next_download_index
            next_idx = _coerce_index(state.get("next_download_index", 0))
            chunk = file_list[next_idx:next_idx + NEW_FILES_CHUNK_SIZE]
            chunk_end = next_idx + len(chunk)
            position = {name: next_idx + offset for offset, name in enumerate(chunk)}
            new_files = [f for f in chunk if f not in file_states]
            new_file_set = set(new_files)

            # Combine: Prioritize failed files, then add new ones
            files_to_process = failed_files + new_files

            if not files_to_process:
                if chunk:
                    # Every file in this chunk already has a recorded state.
                    # Move past it, otherwise the index never advances and the
                    # files behind the chunk are never reached.
                    state["next_download_index"] = chunk_end
                    if not await upload_hf_state(state):
                        await asyncio.sleep(60)
                    continue
                print(f"[{FLOW_ID}] No files to process. Sleeping...")
                await asyncio.sleep(60)
                continue

            print(f"[{FLOW_ID}] Processing {len(files_to_process)} files ({len(failed_files)} failed, {len(files_to_process)-len(failed_files)} new)...")

            # Process in batches of server count
            batch_size = len(servers)
            total = len(files_to_process)
            for start in range(0, total, batch_size):
                batch = files_to_process[start:start + batch_size]
                await asyncio.gather(*(process_file_task(f, state, progress) for f in batch))

                # Update next_download_index if we processed new files
                if start + batch_size >= total:
                    # Last batch: the whole chunk has now been handled.
                    state["next_download_index"] = chunk_end
                else:
                    new_in_batch = [f for f in batch if f in new_file_set]
                    if new_in_batch:
                        state["next_download_index"] = position[new_in_batch[-1]] + 1

                # Save and upload state after each batch
                await upload_hf_state(state)
                save_progress(progress)

            await asyncio.sleep(10)

        except Exception as e:
            print(f"[{FLOW_ID}] Error in main loop: {e}")
            await asyncio.sleep(60)


# --- FastAPI App ---
app = FastAPI(title=f"Flow Server {FLOW_ID} API")


@app.on_event("startup")
async def startup_event():
    """Start the processing loop as a background task when the app starts."""
    task = asyncio.create_task(main_processing_loop())
    # Keep a reference until the task finishes (see _background_tasks).
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


@app.get("/")
async def root():
    """Report the current index, failure count and upload count."""
    progress = load_progress()
    state = await download_hf_state()
    failed_count = sum(
        1 for s in state.get("file_states", {}).values() if s == "failed_transcription"
    )
    return {
        "flow_id": FLOW_ID,
        "status": "running",
        "next_download_index": state.get("next_download_index", 0),
        "failed_transcriptions": failed_count,
        "uploaded_count": progress.get("uploaded_count", 0),
        "total_files_in_list": len(progress.get('file_list', [])),
    }


@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest):
    """Set ``next_download_index`` in the shared state to ``start_index - 1``.

    Raises:
        HTTPException: 502 when the updated state could not be uploaded.
    """
    # NOTE(review): main_processing_loop holds its own copy of the state while
    # a pass is running and uploads it after each batch, so a reset made during
    # a pass may be overwritten - confirm whether that matters.
    state = await download_hf_state()
    state["next_download_index"] = request.start_index - 1
    if not await upload_hf_state(state):
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Could not upload the updated state to Hugging Face.",
        )
    return {"status": "index_reset", "new_index": request.start_index}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)