# switch / app.py — Hugging Face file-viewer header captured with the source,
# commented out so the module remains valid Python:
#   factorstudios's picture / Update app.py / 14ccf3f verified
#   raw · history · blame · 16.3 kB
import os
import json
import time
import asyncio
import aiohttp
import zipfile
import shutil
from typing import Dict, List, Set, Optional, Tuple, Any
from urllib.parse import quote
from datetime import datetime
from pathlib import Path
import io
from fastapi import FastAPI, BackgroundTasks, HTTPException, status
from pydantic import BaseModel, Field
from huggingface_hub import HfApi, hf_hub_download
# --- Configuration ---
AUTO_START_INDEX = 1 # Hardcoded default start index if no progress is found (1-indexed; default for ProcessStartRequest.start_index)
FLOW_ID = os.getenv("FLOW_ID", "flow_default")  # worker identity, embedded in logs, commit messages, and TEMP_DIR
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))  # port uvicorn binds to in the __main__ guard
HF_TOKEN = os.getenv("HF_TOKEN", "")  # Hugging Face token used for all hub API calls
HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")  # source dataset holding the .wav files
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")  # destination for JSON transcripts and the shared state file
# Progress and State Tracking
PROGRESS_FILE = Path("processing_progress.json")  # local per-flow progress cache (see load_progress/save_progress)
HF_STATE_FILE = "processing_state_transcriptions.json"  # name of the shared state file inside the output dataset
LOCAL_STATE_FOLDER = Path(".state")  # local mirror directory for the downloaded state file
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
# Processing configuration
MAX_UPLOADS_BEFORE_PAUSE = 120 # Pause uploading after 120 files — NOTE(review): not referenced elsewhere in this file
UPLOAD_PAUSE_ENABLED = True  # NOTE(review): not referenced elsewhere in this file
# Directory within the HF dataset where the audio files are located
AUDIO_FILE_PREFIX = "audio/"  # NOTE(review): not referenced elsewhere in this file
# Pool of remote Whisper transcription endpoints (HF Spaces) that work is fanned out across.
WHISPER_SERVERS = [
    f"https://eliasishere-makeitfr-mineo-{i}.hf.space/transcribe" for i in range(1, 21)
]
# Temporary storage for audio files
TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")  # per-flow scratch dir for downloaded .wav files
TEMP_DIR.mkdir(exist_ok=True)
# --- Models ---
class ProcessStartRequest(BaseModel):
    """Request body for POST /start_processing: the 1-indexed position in the
    audio file list to (re)start processing from."""
    start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the audio file to start processing from (1-indexed).")
class WhisperServer:
    """Tracks availability and throughput statistics for one remote Whisper
    transcription endpoint."""

    def __init__(self, url: str):
        self.url = url
        # Busy flag plus the index of the file currently assigned (if any).
        self.is_processing = False
        self.current_file_index: Optional[int] = None
        # Cumulative counters feeding the `fps` property.
        self.total_processed = 0
        self.total_time = 0.0

    @property
    def fps(self):
        """Average files per second; 0 until any timed work has completed."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0

    def assign_file(self, file_index: int):
        """Mark this server busy with the given file index."""
        self.is_processing = True
        self.current_file_index = file_index

    def release(self):
        """Mark this server idle and clear its assignment."""
        self.is_processing = False
        self.current_file_index = None
# Global state for whisper servers: one wrapper per configured endpoint.
servers = [WhisperServer(url) for url in WHISPER_SERVERS]
# asyncio.Lock guarding the idle-server scan in process_file_task — this
# serializes cooperating asyncio tasks (not OS threads).
server_lock = asyncio.Lock()
# --- Progress and State Management Functions ---
def load_progress() -> Dict:
    """Loads the local processing progress from the JSON file.

    Always returns a dict containing the keys last_processed_index,
    processed_files, file_list and uploaded_count, falling back to defaults
    when the file is missing or corrupted.
    """
    defaults = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],        # Full list of all zip files found in the dataset
        "uploaded_count": 0,
    }
    if not PROGRESS_FILE.exists():
        return defaults
    try:
        with PROGRESS_FILE.open('r') as handle:
            loaded = json.load(handle)
    except json.JSONDecodeError:
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
        return defaults
    # Backfill keys missing from progress files written by older versions.
    for key, fallback in defaults.items():
        loaded.setdefault(key, fallback)
    return loaded
def save_progress(progress_data: Dict):
    """Saves the local processing progress to the JSON file.

    Failures are logged rather than raised so a transient disk error never
    kills the processing loop.
    """
    try:
        with PROGRESS_FILE.open('w') as handle:
            json.dump(progress_data, handle, indent=4)
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {e}")
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load state from a JSON file, migrating older files to the new shape.

    Guarantees the returned dict has a dict-valued "file_states" key and a
    "next_download_index" key. Returns `default_value` when the file is
    missing or unparseable.
    """
    if not os.path.exists(file_path):
        return default_value
    try:
        with open(file_path, "r") as handle:
            state = json.load(handle)
    except json.JSONDecodeError:
        print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        return default_value
    # Migration: older state files may lack these keys or hold wrong types.
    if not isinstance(state.get("file_states"), dict):
        state["file_states"] = {}
    state.setdefault("next_download_index", 0)
    return state
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Serialize `data` as pretty-printed JSON to `file_path` (overwrites)."""
    with open(file_path, "w") as handle:
        json.dump(data, handle, indent=2)
async def download_hf_state() -> Dict[str, Any]:
    """Fetch the shared transcription state file from the HF output dataset.

    On download failure, logs and falls back to whatever copy already exists
    locally (or a fresh default state).
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    fallback = {"next_download_index": 0, "file_states": {}}
    try:
        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        return load_json_state(str(local_path), fallback)
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}. Using local/default.")
        return load_json_state(str(local_path), fallback)
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Persist `state` locally, then push it to the HF output dataset.

    Returns True on success, False (with a log line) on any failure.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        save_json_state(str(local_path), state)
        api = HfApi(token=HF_TOKEN)
        api.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Update transcription state: next_index={state.get('next_download_index')}",
        )
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
        return False
    return True
# --- Hugging Face Utility Functions ---
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted list of .wav file paths in the audio dataset.

    The listing is cached in progress_data['file_list']; the HF API is only
    queried (and the progress file re-saved) on a cache miss. Returns an
    empty list when the listing fails.
    """
    cached = progress_data.get('file_list')
    if cached:
        return cached
    try:
        api = HfApi(token=HF_TOKEN)
        repo_files = api.list_repo_files(repo_id=HF_AUDIO_DATASET_ID, repo_type="dataset")
        wav_files = sorted(name for name in repo_files if name.endswith('.wav'))
        progress_data['file_list'] = wav_files
        save_progress(progress_data)
        return wav_files
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list: {e}")
        return []
# --- Core Processing Logic ---
async def transcribe_with_server(server: WhisperServer, wav_path: Path) -> Optional[Dict]:
    """POST the audio at `wav_path` to `server` and return the parsed JSON.

    On success the server's throughput counters are updated. Returns None on
    a non-200 response or any network/timeout error.
    """
    started = time.time()
    try:
        async with aiohttp.ClientSession() as session:
            with open(wav_path, 'rb') as audio:
                form = aiohttp.FormData()
                form.add_field('file', audio, filename=wav_path.name)
                async with session.post(server.url, data=form, timeout=600) as response:
                    if response.status == 200:
                        payload = await response.json()
                        # Count only successful transcriptions toward fps.
                        server.total_processed += 1
                        server.total_time += time.time() - started
                        return payload
                    print(f"[{FLOW_ID}] Server {server.url} returned status {response.status}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error transcribing with {server.url}: {e}")
    return None
async def process_file_task(wav_file: str, state: Dict, progress: Dict):
    """Download `wav_file` from the audio dataset, transcribe it on the first
    idle whisper server, and upload the JSON result to the output dataset.

    Side effects: sets state["file_states"][wav_file] to "processed" or
    "failed_transcription", bumps progress["uploaded_count"] on success, and
    always removes the temporary audio file and releases the server.
    """
    # Spin with a 1s backoff until a server is free; the lock prevents two
    # tasks from claiming the same server.
    server = None
    while server is None:
        async with server_lock:
            for s in servers:
                if not s.is_processing:
                    s.is_processing = True
                    server = s
                    break
        if server is None:
            await asyncio.sleep(1)
    wav_filename = Path(wav_file).name
    wav_path = TEMP_DIR / wav_filename
    # NOTE(review): hf_hub_download preserves the repo subpath under local_dir,
    # so for files like "audio/foo.wav" the download may land at
    # TEMP_DIR/audio/foo.wav rather than wav_path — confirm against the dataset layout.
    try:
        # Download the audio locally so it can be posted as multipart form data.
        hf_hub_download(
            repo_id=HF_AUDIO_DATASET_ID,
            filename=wav_file,
            repo_type="dataset",
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
            token=HF_TOKEN
        )
        # Transcribe
        result = await transcribe_with_server(server, wav_path)
        if result:
            # Upload transcription result to HF next to the audio's basename.
            json_filename = Path(wav_file).with_suffix('.json').name
            json_content = json.dumps(result, indent=2, ensure_ascii=False).encode('utf-8')
            api = HfApi(token=HF_TOKEN)
            api.upload_file(
                path_or_fileobj=io.BytesIO(json_content),
                path_in_repo=json_filename,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=f"[{FLOW_ID}] Transcription for {wav_file}"
            )
            state["file_states"][wav_file] = "processed"
            progress["uploaded_count"] = progress.get("uploaded_count", 0) + 1
            print(f"[{FLOW_ID}] ✅ Success: {wav_file}")
        else:
            state["file_states"][wav_file] = "failed_transcription"
            print(f"[{FLOW_ID}] ❌ Failed: {wav_file}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error processing {wav_file}: {e}")
        state["file_states"][wav_file] = "failed_transcription"
    finally:
        # Bug fix: cleanup previously ran inside the try body, so any
        # exception (failed download/upload) leaked the temp .wav in TEMP_DIR.
        if wav_path.exists():
            wav_path.unlink()
        server.release()
async def main_processing_loop():
    """Endless background worker loop.

    Each cycle: sync the shared HF state and local progress, list the audio
    files, skip any whose JSON output already exists in the output dataset,
    and dispatch the first unprocessed file(s) to the whisper servers.
    Errors are logged and the loop retries after a sleep; it never returns.
    """
    print(f"[{FLOW_ID}] Starting main processing loop...")
    while True:
        try:
            # Re-pull shared state each cycle so multiple flows can cooperate.
            state = await download_hf_state()
            progress = load_progress()
            file_list = await get_audio_file_list(progress)
            if not file_list:
                print(f"[{FLOW_ID}] File list empty, retrying in 60s...")
                await asyncio.sleep(60)
                continue
            # Check HF_OUTPUT_DATASET_ID for existing JSON outputs
            print(f"[{FLOW_ID}] Checking {HF_OUTPUT_DATASET_ID} for existing JSON outputs...")
            try:
                api = HfApi(token=HF_TOKEN)
                existing_files = api.list_repo_files(repo_id=HF_OUTPUT_DATASET_ID, repo_type="dataset")
                existing_json_files = {f for f in existing_files if f.endswith('.json')}
                print(f"[{FLOW_ID}] Found {len(existing_json_files)} existing JSON files.")
            except Exception as e:
                # Best-effort: with an empty set, every candidate is treated as unprocessed.
                print(f"[{FLOW_ID}] Warning: Could not fetch existing files: {e}")
                existing_json_files = set()
            # 1. Handpick failed_transcription files (retried before new work)
            failed_files = [f for f, s in state.get("file_states", {}).items() if s == "failed_transcription"]
            # 2. Also check for new files based on next_download_index
            next_idx = state.get("next_download_index", 0)
            # We take a larger chunk to allow for more skipping without re-fetching the list
            new_files_chunk = file_list[next_idx:next_idx + 500]
            # Combine: Prioritize failed files, then add new ones
            files_to_check = failed_files + [f for f in new_files_chunk if f not in state["file_states"]]
            if not files_to_check:
                print(f"[{FLOW_ID}] No files to process. Sleeping...")
                await asyncio.sleep(60)
                continue
            files_to_process = []
            state_changed_locally = False
            print(f"[{FLOW_ID}] Scanning {len(files_to_check)} files for existing results...")
            for f in files_to_check:
                # Output JSON uses the audio's basename with a .json suffix.
                expected_json_name = Path(f).with_suffix('.json').name
                if expected_json_name in existing_json_files:
                    # Mark locally but DO NOT upload yet (batched to one upload)
                    if state["file_states"].get(f) != "processed":
                        state["file_states"][f] = "processed"
                        state_changed_locally = True
                    # Update next_download_index if it's a new file
                    # NOTE(review): file_list.index(f) is O(n) per file — fine for
                    # moderate lists, worth an index map if the list grows large.
                    if f in new_files_chunk:
                        current_idx = file_list.index(f)
                        if current_idx >= state.get("next_download_index", 0):
                            state["next_download_index"] = current_idx + 1
                    continue
                # If we reach here, we found an UNPROCESSED file
                print(f"[{FLOW_ID}] Found unprocessed file: {f}")
                # Before processing, if we have local changes (skips), upload the state once
                if state_changed_locally:
                    print(f"[{FLOW_ID}] Synchronizing skipped files to HF state before processing...")
                    await upload_hf_state(state)
                    state_changed_locally = False
                files_to_process.append(f)
                # Once we find an unprocessed file, we stop the skip-scan and start processing
                # This ensures we process files as soon as we find them
                break
            # If we scanned everything and only found skips, upload the state once at the end
            if state_changed_locally and not files_to_process:
                print(f"[{FLOW_ID}] Uploading final batch of skips to HF state...")
                await upload_hf_state(state)
            if not files_to_process:
                # If we only found skips, the loop will restart and check the next chunk
                continue
            print(f"[{FLOW_ID}] Processing batch of {len(files_to_process)} unprocessed files...")
            # Process the found unprocessed file(s)
            # (In this logic, it's usually just 1 file at a time to ensure frequent state updates)
            batch_size = len(servers)
            for i in range(0, len(files_to_process), batch_size):
                batch = files_to_process[i:i + batch_size]
                tasks = [process_file_task(f, state, progress) for f in batch]
                await asyncio.gather(*tasks)
                # Update next_download_index past every file we just handled
                for f in batch:
                    if f in file_list:
                        current_idx = file_list.index(f)
                        if current_idx >= state.get("next_download_index", 0):
                            state["next_download_index"] = current_idx + 1
            # Save and upload state after processing the unprocessed file
            await upload_hf_state(state)
            save_progress(progress)
            await asyncio.sleep(2)  # Short sleep before looking for the next unprocessed file
        except Exception as e:
            print(f"[{FLOW_ID}] Error in main loop: {e}")
            await asyncio.sleep(60)
# --- FastAPI App ---
# One app per flow; the title embeds FLOW_ID so instances are distinguishable.
app = FastAPI(title=f"Flow Server {FLOW_ID} API")
@app.on_event("startup")
async def startup_event():
    # Launch the background processing loop on the running event loop.
    # NOTE(review): on_event("startup") is deprecated in recent FastAPI in
    # favor of lifespan handlers — consider migrating when upgrading.
    asyncio.create_task(main_processing_loop())
@app.get("/")
async def root():
    """Status endpoint summarizing this flow's processing progress."""
    progress = load_progress()
    state = await download_hf_state()
    file_states = state.get("file_states", {})
    failed_count = len([v for v in file_states.values() if v == "failed_transcription"])
    return {
        "flow_id": FLOW_ID,
        "status": "running",
        "next_download_index": state.get("next_download_index", 0),
        "failed_transcriptions": failed_count,
        "uploaded_count": progress.get("uploaded_count", 0),
        "total_files_in_list": len(progress.get('file_list', [])),
    }
@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest):
    """Reset the shared next_download_index so processing restarts at the
    requested 1-indexed position in the file list."""
    state = await download_hf_state()
    # Convert from the API's 1-indexed position to the internal 0-indexed one.
    state["next_download_index"] = request.start_index - 1
    await upload_hf_state(state)
    return {"status": "index_reset", "new_index": request.start_index}
if __name__ == "__main__":
    import uvicorn
    # Bind on all interfaces so the hosting container/Space can route traffic in.
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)