factorstudios committed on
Commit a1d4a74 · verified · 1 Parent(s): 019259b

Upload 3 files

Files changed (3)
  1. Dockerfile +45 -0
  2. app.py +704 -0
  3. requirements.txt +18 -0
Dockerfile ADDED
@@ -0,0 +1,45 @@
+ FROM python:3.11-slim-bullseye
+
+ # Install system dependencies
+ RUN sed -i 's/main/main contrib non-free/' /etc/apt/sources.list && \
+     apt-get update && \
+     apt-get install -y --no-install-recommends \
+         unrar \
+         libgl1 \
+         libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ # Upgrade pip and install core dependencies first
+ RUN pip install --no-cache-dir --upgrade pip setuptools wheel packaging
+
+ # Install CPU-only PyTorch first
+
+ # Copy requirements and install with special handling for flash_attn
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir \
+     -r requirements.txt \
+     --find-links https://download.pytorch.org/whl/cpu \
+     --extra-index-url https://pypi.org/simple && \
+     # Install remaining packages that might have been skipped
+     pip install --no-cache-dir \
+         accelerate \
+         transformers==4.36.2 \
+         timm==0.9.12 \
+         einops==0.7.0
+
+ # Copy application code
+ COPY . .
+
+ # Create non-root user
+ RUN useradd -m -u 1000 user && \
+     chown -R user:user /app
+
+ USER user
+
+ # Environment variables to suppress warnings
+ ENV HF_HUB_DISABLE_PROGRESS=1
+ ENV TF_CPP_MIN_LOG_LEVEL=3
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,704 @@
+ import os
+ import json
+ import time
+ import asyncio
+ import aiohttp
+ import zipfile
+ import shutil
+ from typing import Dict, List, Set, Optional, Tuple, Any
+ from urllib.parse import quote
+ from datetime import datetime
+ from pathlib import Path
+ import io
+
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, status
+ from pydantic import BaseModel, Field
+ from huggingface_hub import HfApi, hf_hub_download
+
+ # --- Configuration ---
+ AUTO_START_INDEX = 1  # Hardcoded default start index if no progress is found
+ FLOW_ID = os.getenv("FLOW_ID", "flow_default")
+ FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
+ HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
+ HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")
+
+ # Progress and State Tracking
+ PROGRESS_FILE = Path("processing_progress.json")
+ HF_STATE_FILE = "processing_state_transcriptions.json"
+ LOCAL_STATE_FOLDER = Path(".state")
+ LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
+
+ # Processing configuration
+ MAX_UPLOADS_BEFORE_PAUSE = 120  # Pause uploading after 120 files
+ UPLOAD_PAUSE_ENABLED = True
+
+ # Directory within the HF dataset where the audio files are located
+ AUDIO_FILE_PREFIX = "audio/"
+
+ WHISPER_SERVERS = [
+     "https://makeitfr-mineo-1.hf.space/transcribe",
+     "https://makeitfr-mineo-2.hf.space/transcribe",
+     "https://makeitfr-mineo-3.hf.space/transcribe",
+     "https://makeitfr-mineo-4.hf.space/transcribe",
+     "https://makeitfr-mineo-5.hf.space/transcribe",
+     "https://makeitfr-mineo-6.hf.space/transcribe",
+     "https://makeitfr-mineo-7.hf.space/transcribe",
+     "https://makeitfr-mineo-8.hf.space/transcribe",
+     "https://makeitfr-mineo-9.hf.space/transcribe",
+     "https://makeitfr-mineo-10.hf.space/transcribe",
+     "https://makeitfr-mineo-11.hf.space/transcribe",
+     "https://makeitfr-mineo-12.hf.space/transcribe",
+     "https://makeitfr-mineo-13.hf.space/transcribe",
+     "https://makeitfr-mineo-14.hf.space/transcribe",
+     "https://makeitfr-mineo-15.hf.space/transcribe",
+     "https://makeitfr-mineo-16.hf.space/transcribe",
+     "https://makeitfr-mineo-17.hf.space/transcribe",
+     "https://makeitfr-mineo-18.hf.space/transcribe",
+     "https://makeitfr-mineo-19.hf.space/transcribe",
+     "https://makeitfr-mineo-20.hf.space/transcribe"
+ ]
+
+ # Temporary storage for audio files
+ TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
+ TEMP_DIR.mkdir(exist_ok=True)
+
+ # --- Models ---
+ class ProcessStartRequest(BaseModel):
+     start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the audio file to start processing from (1-indexed).")
+
+ class WhisperServer:
+     def __init__(self, url: str):
+         self.url = url
+         self.is_processing = False
+         self.current_file_index: Optional[int] = None
+         self.total_processed = 0
+         self.total_time = 0.0
+
+     @property
+     def fps(self):
+         """Files per second"""
+         return self.total_processed / self.total_time if self.total_time > 0 else 0
+
+     def assign_file(self, file_index: int):
+         """Assign a file index to this server"""
+         self.is_processing = True
+         self.current_file_index = file_index
+
+     def release(self):
+         """Release the server for a new file"""
+         self.is_processing = False
+         self.current_file_index = None
+
+ # Global state for whisper servers
+ servers = [WhisperServer(url) for url in WHISPER_SERVERS]
+ server_lock = asyncio.Lock()  # Lock for thread-safe server state access
+
+ # --- Progress and State Management Functions ---
+
+ def load_progress() -> Dict:
+     """Loads the local processing progress from the JSON file."""
+     if PROGRESS_FILE.exists():
+         try:
+             with PROGRESS_FILE.open('r') as f:
+                 return json.load(f)
+         except json.JSONDecodeError:
+             print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
+             # Fall through to return default structure
+
+     # Default structure
+     return {
+         "last_processed_index": 0,
+         "processed_files": {},  # {index: repo_path}
+         "file_list": []  # Full list of all zip files found in the dataset
+     }
+
+ def save_progress(progress_data: Dict):
+     """Saves the local processing progress to the JSON file."""
+     try:
+         with PROGRESS_FILE.open('w') as f:
+             json.dump(progress_data, f, indent=4)
+     except Exception as e:
+         print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {e}")
+
+ def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
+     """Load state from JSON file with migration logic for new structure."""
+     if os.path.exists(file_path):
+         try:
+             with open(file_path, "r") as f:
+                 data = json.load(f)
+
+             # Migration Logic
+             if "file_states" not in data or not isinstance(data["file_states"], dict):
+                 print(f"[{FLOW_ID}] Initializing 'file_states' dictionary.")
+                 data["file_states"] = {}
+
+             if "next_download_index" not in data:
+                 data["next_download_index"] = 0
+
+             return data
+         except json.JSONDecodeError:
+             print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
+     return default_value
+
+ def save_json_state(file_path: str, data: Dict[str, Any]):
+     """Save state to JSON file"""
+     with open(file_path, "w") as f:
+         json.dump(data, f, indent=2)
+
+ async def download_hf_state() -> Dict[str, Any]:
+     """Downloads the state file from Hugging Face or returns a default state."""
+     local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
+     default_state = {"next_download_index": 0, "file_states": {}}
+
+     try:
+         # Check if the file exists in the helium repo
+         files = HfApi(token=HF_TOKEN).list_repo_files(
+             repo_id=HF_OUTPUT_DATASET_ID,
+             repo_type="dataset"
+         )
+
+         if HF_STATE_FILE not in files:
+             print(f"[{FLOW_ID}] State file not found in {HF_OUTPUT_DATASET_ID}. Starting fresh.")
+             return default_state
+
+         # Download the file
+         hf_hub_download(
+             repo_id=HF_OUTPUT_DATASET_ID,
+             filename=HF_STATE_FILE,
+             repo_type="dataset",
+             local_dir=LOCAL_STATE_FOLDER,
+             local_dir_use_symlinks=False,
+             token=HF_TOKEN
+         )
+
+         print(f"[{FLOW_ID}] Successfully downloaded state file.")
+         return load_json_state(str(local_path), default_state)
+
+     except Exception as e:
+         print(f"[{FLOW_ID}] Failed to download state file: {str(e)}. Starting fresh.")
+         return default_state
+
+ async def upload_hf_state(state: Dict[str, Any]) -> bool:
+     """Uploads the state file to Hugging Face."""
+     local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
+
+     try:
+         # Save state locally first
+         save_json_state(str(local_path), state)
+
+         # Upload to helium dataset
+         HfApi(token=HF_TOKEN).upload_file(
+             path_or_fileobj=str(local_path),
+             path_in_repo=HF_STATE_FILE,
+             repo_id=HF_OUTPUT_DATASET_ID,
+             repo_type="dataset",
+             commit_message=f"Update caption processing state: next_index={state['next_download_index']}"
+         )
+         print(f"[{FLOW_ID}] Successfully uploaded state file.")
+         return True
+     except Exception as e:
+         print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
+         return False
+
+ async def lock_file_for_processing(zip_filename: str, state: Dict[str, Any]) -> bool:
+     """Marks a file as 'processing' in the state file and uploads the lock."""
+     print(f"[{FLOW_ID}] 🔒 Attempting to lock file: {zip_filename}")
+
+     # Update state locally
+     state["file_states"][zip_filename] = "processing"
+
+     # Upload the updated state file immediately to establish the lock
+     if await upload_hf_state(state):
+         print(f"[{FLOW_ID}] ✅ Successfully locked file: {zip_filename}")
+         return True
+     else:
+         print(f"[{FLOW_ID}] ❌ Failed to lock file: {zip_filename}")
+         # Revert local state
+         if zip_filename in state["file_states"]:
+             del state["file_states"][zip_filename]
+         return False
+
+ async def unlock_file_as_processed(zip_filename: str, state: Dict[str, Any], next_index: int) -> bool:
+     """Marks a file as 'processed', updates the index, and uploads the state."""
+     print(f"[{FLOW_ID}] 🔓 Marking file as processed: {zip_filename}")
+
+     # Update state locally
+     state["file_states"][zip_filename] = "processed"
+     state["next_download_index"] = next_index
+
+     # Upload the updated state
+     if await upload_hf_state(state):
+         print(f"[{FLOW_ID}] ✅ Successfully marked as processed: {zip_filename}")
+         return True
+     else:
+         print(f"[{FLOW_ID}] ❌ Failed to update state for: {zip_filename}")
+         return False
+
+ # --- Hugging Face Utility Functions ---
+
+ async def get_audio_file_list(progress_data: Dict) -> List[str]:
+     """
+     Fetches the list of all WAV files from the dataset, or uses the cached list.
+     Updates the progress_data with the file list if a new list is fetched.
+     """
+     if progress_data['file_list']:
+         print(f"[{FLOW_ID}] Using cached file list with {len(progress_data['file_list'])} files.")
+         return progress_data['file_list']
+
+     print(f"[{FLOW_ID}] Fetching full list of WAV files from {HF_AUDIO_DATASET_ID}...")
+     try:
+         api = HfApi(token=HF_TOKEN)
+         repo_files = api.list_repo_files(
+             repo_id=HF_AUDIO_DATASET_ID,
+             repo_type="dataset"
+         )
+
+         # Filter for WAV files and sort them alphabetically for consistent indexing
+         wav_files = sorted([
+             f for f in repo_files
+             if f.endswith('.wav')
+         ])
+
+         if not wav_files:
+             raise FileNotFoundError(f"No WAV files found in dataset '{HF_AUDIO_DATASET_ID}'.")
+
+         print(f"[{FLOW_ID}] Found {len(wav_files)} WAV files.")
+
+         # Update and save the progress data
+         progress_data['file_list'] = wav_files
+         save_progress(progress_data)
+
+         return wav_files
+
+     except Exception as e:
+         print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {e}")
+         return []
+
+ async def download_wav_file_by_index(file_index: int, repo_file_full_path: str) -> Optional[Path]:
+     """Downloads a WAV file from the repository."""
+
+     wav_filename = Path(repo_file_full_path).name
+
+     print(f"[{FLOW_ID}] Downloading file #{file_index}: {repo_file_full_path}")
+
+     try:
+         # Download the file into our TEMP_DIR (so we can safely delete it later)
+         wav_path = hf_hub_download(
+             repo_id=HF_AUDIO_DATASET_ID,
+             filename=repo_file_full_path,
+             repo_type="dataset",
+             token=HF_TOKEN,
+             local_dir=str(TEMP_DIR),
+             local_dir_use_symlinks=False,
+         )
+
+         print(f"[{FLOW_ID}] Downloaded WAV file to {wav_path}")
+         return Path(wav_path)
+
+     except Exception as e:
+         print(f"[{FLOW_ID}] Error downloading WAV file {repo_file_full_path}: {e}")
+         return None
+
+ async def upload_transcription_to_hf(wav_filename: str, transcription_data: Dict) -> bool:
+     """Uploads the transcription JSON file to the output dataset."""
+     # Use the full WAV path, replacing slashes with underscores and extension with .json
+     json_filename = wav_filename.replace('/', '_').replace('\\', '_').rsplit('.', 1)[0] + '.json'
+
+     try:
+         print(f"[{FLOW_ID}] Uploading transcription for {wav_filename} as {json_filename} to {HF_OUTPUT_DATASET_ID}...")
+
+         # Create JSON content in memory
+         json_content = json.dumps(transcription_data, indent=2, ensure_ascii=False).encode('utf-8')
+
+         api = HfApi(token=HF_TOKEN)
+         api.upload_file(
+             path_or_fileobj=io.BytesIO(json_content),
+             path_in_repo=json_filename,
+             repo_id=HF_OUTPUT_DATASET_ID,
+             repo_type="dataset",
+             commit_message=f"[{FLOW_ID}] Transcription for {wav_filename}"
+         )
+
+         print(f"[{FLOW_ID}] Successfully uploaded transcription for {wav_filename}.")
+         return True
+
+     except Exception as e:
+         print(f"[{FLOW_ID}] Error uploading transcription for {wav_filename}: {e}")
+         return False
+
+ # --- Core Processing Functions ---
+
+ async def send_audio_to_whisper(wav_path: Path, server: WhisperServer) -> Optional[Dict]:
+     """Sends a WAV file to a Whisper server for transcription."""
+     try:
+         print(f"[{FLOW_ID}] Sending {wav_path.name} to {server.url}...")
+
+         start_time = time.time()
+
+         # Prepare multipart form data
+         form_data = aiohttp.FormData()
+         # Open the file in a context manager so the descriptor is closed after the request
+         with wav_path.open('rb') as f:
+             form_data.add_field('file', f, filename=wav_path.name, content_type='audio/wav')
+
+             async with aiohttp.ClientSession() as session:
+                 # 10 minute timeout for transcription
+                 async with session.post(server.url, data=form_data, timeout=600) as resp:
+                     if resp.status == 200:
+                         result = await resp.json()
+                         end_time = time.time()
+
+                         # Update server stats
+                         server.total_processed += 1
+                         server.total_time += (end_time - start_time)
+
+                         print(f"[{FLOW_ID}] ✓ {wav_path.name} transcribed successfully by {server.url}")
+
+                         return {
+                             "file": wav_path.name,
+                             "transcription": result,
+                             "timestamp": datetime.now().isoformat(),
+                             "processing_time_seconds": end_time - start_time
+                         }
+                     else:
+                         error_text = await resp.text()
+                         print(f"[{FLOW_ID}] ✗ Error from {server.url}: {resp.status} - {error_text}")
+                         return None
+
+     except asyncio.TimeoutError:
+         print(f"[{FLOW_ID}] ✗ Timeout from {server.url} for {wav_path.name}")
+         return None
+     except Exception as e:
+         print(f"[{FLOW_ID}] ✗ Exception on {server.url} for {wav_path.name}: {e}")
+         return None
+
+ async def get_available_servers() -> List[WhisperServer]:
+     """
+     Returns a list of servers that are not currently processing.
+     Dynamically assigns new files to available servers.
+     """
+     async with server_lock:
+         available = [s for s in servers if not s.is_processing]
+         return available
+
+ async def assign_file_to_server(file_index: int, server: WhisperServer):
+     """Safely assign a file to a server"""
+     async with server_lock:
+         server.assign_file(file_index)
+
+ async def release_server(server: WhisperServer):
+     """Safely release a server for new work"""
+     async with server_lock:
+         server.release()
+
+ async def process_batch_dynamic(wav_files: List[str], start_batch_index: int, batch_size: int, state: Dict[str, Any], progress: Dict) -> Tuple[int, int]:
+     """
+     Processes a batch of WAV files in parallel using available servers.
+     Batch size = number of servers. Each server gets one file, processes it, then gets the next.
+     Includes retry mechanism for failed files.
+     Returns (next_batch_index, uploaded_count)
+     """
+     batch_end = min(start_batch_index + batch_size, len(wav_files))
+     uploaded_count = progress.get('uploaded_count', 0)
+     max_retries = 3
+     failed_files = []  # Track files that failed for retry
+
+     print(f"[{FLOW_ID}] Processing batch from index {start_batch_index} to {batch_end - 1} ({batch_end - start_batch_index} files)")
+
+     # --- Batch-level locking: mark all files in this batch as 'processing' and upload state
+     try:
+         state.setdefault("file_states", {})
+         for idx in range(start_batch_index, batch_end):
+             wav_file = wav_files[idx]
+             state["file_states"][wav_file] = "processing"
+
+         # Update next_download_index to the end of this batch (0-based)
+         state["next_download_index"] = batch_end
+
+         # Upload HF state to establish locks for this batch
+         if await upload_hf_state(state):
+             print(f"[{FLOW_ID}] ✅ Batch locked: files {start_batch_index}-{batch_end - 1} marked 'processing'")
+         else:
+             print(f"[{FLOW_ID}] ❌ Failed to upload batch lock")
+             return start_batch_index, uploaded_count
+     except Exception as e:
+         print(f"[{FLOW_ID}] Error while setting up batch locks: {e}")
+         return start_batch_index, uploaded_count
+
+     # Create a queue of files to process with retry support
+     files_to_process = [(idx, wav_files[idx], 0) for idx in range(start_batch_index, batch_end)]  # (idx, wav_file, retry_count)
+
+     # --- Assign files to servers and create tasks
+     pending_tasks: Dict[asyncio.Task, Tuple[int, Path, WhisperServer, str, int]] = {}
+
+     try:
+         while files_to_process or pending_tasks:
+             # Assign new files to available servers
+             while files_to_process:
+                 available = await get_available_servers()
+                 if not available:
+                     break
+
+                 file_idx, wav_file, retry_count = files_to_process.pop(0)
+                 wav_filename = Path(wav_file).name
+                 server = available[0]
+
+                 # Download the WAV file
+                 wav_path = await download_wav_file_by_index(file_idx + 1, wav_file)
+                 if not wav_path:
+                     if retry_count < max_retries:
+                         print(f"[{FLOW_ID}] ⚠️ Download failed for {wav_filename} (retry {retry_count + 1}/{max_retries}), re-queueing...")
+                         files_to_process.append((file_idx, wav_file, retry_count + 1))
+                     else:
+                         state["file_states"][wav_file] = "failed_download"
+                         print(f"[{FLOW_ID}] ❌ Download failed permanently for {wav_filename} after {max_retries} retries")
+                     continue
+
+                 # Assign to server and create task
+                 await assign_file_to_server(file_idx, server)
+                 task = asyncio.create_task(send_audio_to_whisper(wav_path, server))
+                 pending_tasks[task] = (file_idx, wav_path, server, wav_file, retry_count)
+                 print(f"[{FLOW_ID}] Assigned {wav_filename} to server {servers.index(server) + 1}")
+
+             # Wait for at least one task to complete if there are pending tasks
+             if not pending_tasks:
+                 break
+
+             done, pending = await asyncio.wait(
+                 pending_tasks.keys(),
+                 return_when=asyncio.FIRST_COMPLETED
+             )
+
+             for task in done:
+                 file_idx, wav_path, server, wav_file, retry_count = pending_tasks.pop(task)
+                 wav_filename = Path(wav_file).name
+
+                 try:
+                     transcription_result = task.result()
+
+                     if transcription_result:
+                         # Upload transcription immediately with full path
+                         uploaded_ok = await upload_transcription_to_hf(wav_file, transcription_result)
+                         if uploaded_ok:
+                             # Update state locally but do NOT upload to HF yet
+                             state["file_states"][wav_file] = "processed"
+                             uploaded_count += 1
+                             progress['uploaded_count'] = uploaded_count
+                             save_progress(progress)
+                             print(f"[{FLOW_ID}] ✅ {wav_filename} uploaded (#{uploaded_count})")
+                         else:
+                             # Retry failed upload
+                             if retry_count < max_retries:
+                                 print(f"[{FLOW_ID}] ⚠️ Upload failed for {wav_filename} (retry {retry_count + 1}/{max_retries}), re-queueing...")
+                                 files_to_process.append((file_idx, wav_file, retry_count + 1))
+                             else:
+                                 state["file_states"][wav_file] = "failed_upload"
+                                 print(f"[{FLOW_ID}] ❌ Upload failed permanently for {wav_filename} after {max_retries} retries")
+                     else:
+                         # Retry failed transcription
+                         if retry_count < max_retries:
+                             print(f"[{FLOW_ID}] ⚠️ Transcription failed for {wav_filename} (retry {retry_count + 1}/{max_retries}), re-queueing...")
+                             files_to_process.append((file_idx, wav_file, retry_count + 1))
+                         else:
+                             state["file_states"][wav_file] = "failed_transcription"
+                             print(f"[{FLOW_ID}] ❌ Transcription failed permanently for {wav_filename} after {max_retries} retries")
+
+                 except Exception as e:
+                     if retry_count < max_retries:
+                         print(f"[{FLOW_ID}] ⚠️ Error processing {wav_filename}: {e} (retry {retry_count + 1}/{max_retries}), re-queueing...")
+                         files_to_process.append((file_idx, wav_file, retry_count + 1))
+                     else:
+                         print(f"[{FLOW_ID}] ❌ Error processing {wav_filename}: {e} (failed after {max_retries} retries)")
+                         state["file_states"][wav_file] = "failed_error"
+                 finally:
+                     # Release the server
+                     await release_server(server)
+                     # Clean up the WAV file
+                     if wav_path.exists():
+                         wav_path.unlink()
+
+         # --- After all files in this batch are uploaded, update HF state once
+         if await upload_hf_state(state):
+             print(f"[{FLOW_ID}] ✅ Batch state updated on HF: files {start_batch_index}-{batch_end - 1} marked processed")
+         else:
+             print(f"[{FLOW_ID}] ❌ Failed to update batch state on HF")
+     except Exception as e:
+         print(f"[{FLOW_ID}] Error in process_batch_dynamic: {e}")
+
+     return batch_end, uploaded_count
+
+ async def process_dataset_task(start_index: int):
+     """Main task to process the dataset using dynamic server assignment."""
+
+     # Load both local progress and HF state
+     progress = load_progress()
+     current_state = await download_hf_state()
+     file_list = await get_audio_file_list(progress)
+
+     if not file_list:
+         print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
+         return False
+
+     # Ensure start_index is within bounds
+     if start_index > len(file_list):
+         print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
+         return True
+
+     # Determine the actual starting index in the 0-indexed list
+     start_list_index = start_index - 1
+
+     print(f"[{FLOW_ID}] Starting audio transcription from file index: {start_index} out of {len(file_list)}.")
+     print(f"[{FLOW_ID}] Using {len(servers)} Whisper servers for dynamic processing.")
+     print(f"[{FLOW_ID}] Upload pause enabled: {UPLOAD_PAUSE_ENABLED}, Max uploads before pause: {MAX_UPLOADS_BEFORE_PAUSE}")
+
+     # Initialize progress tracking
+     if 'uploaded_count' not in progress:
+         progress['uploaded_count'] = 0
+
+     # If there was no HF state in the repo, upload a fresh initial state file
+     try:
+         if not current_state.get("file_states") and current_state.get("next_download_index", 0) == 0:
+             print(f"[{FLOW_ID}] No HF state detected; uploading initial state file to {HF_OUTPUT_DATASET_ID}...")
+             # Ensure structure
+             current_state.setdefault("file_states", {})
+             current_state.setdefault("next_download_index", 0)
+             if await upload_hf_state(current_state):
+                 print(f"[{FLOW_ID}] ✅ Initial HF state uploaded.")
+             else:
+                 print(f"[{FLOW_ID}] ❌ Failed to upload initial HF state.")
+     except Exception as e:
+         print(f"[{FLOW_ID}] Error while uploading initial HF state: {e}")
+
+     global_success = True
+     current_batch_index = start_list_index
+     batch_size = len(servers)  # Batch size = number of servers (20 files per batch)
+     batch_interval_seconds = 600  # 600 seconds = 10 minutes (enforces max 6 batches per hour)
+
+     try:
+         batch_count = 0
+         while current_batch_index < len(file_list):
+             batch_start_time = time.time()
+
+             # Process a batch dynamically
+             next_index, uploaded_count = await process_batch_dynamic(
+                 file_list,
+                 current_batch_index,
+                 batch_size,
+                 current_state,
+                 progress
+             )
+
+             batch_end_time = time.time()
+             batch_elapsed = batch_end_time - batch_start_time
+
+             # Update progress
+             progress['last_processed_index'] = next_index
+             progress['uploaded_count'] = uploaded_count
+             save_progress(progress)
+
+             # Update current batch index
+             current_batch_index = next_index
+             batch_count += 1
+
+             # Log statistics
+             print(f"[{FLOW_ID}] Batch complete. Progress: {current_batch_index}/{len(file_list)}, Uploaded: {uploaded_count}")
+
+             # Print server statistics
+             print(f"[{FLOW_ID}] Server Statistics:")
+             for i, server in enumerate(servers):
+                 print(f"  Server {i+1}: {server.total_processed} files, {server.total_time:.2f}s total, {server.fps:.2f} files/sec")
+
+             # Rate limiting: enforce minimum 10 minutes between batch starts (max 6 batches per hour)
+             if current_batch_index < len(file_list):  # Don't wait after the last batch
+                 wait_time = batch_interval_seconds - batch_elapsed
+                 if wait_time > 0:
+                     print(f"[{FLOW_ID}] Rate limit: batch took {batch_elapsed:.1f}s. Waiting {wait_time:.1f}s before next batch (min 10 min interval)...")
+                     await asyncio.sleep(wait_time)
+                 else:
+                     print(f"[{FLOW_ID}] Batch took {batch_elapsed:.1f}s (exceeded 10 min interval). Proceeding immediately to next batch.")
+
+         print(f"[{FLOW_ID}] All files processed successfully! Total batches: {batch_count}")
+         return True
+
+     except Exception as e:
+         print(f"[{FLOW_ID}] Critical error in process_dataset_task: {e}")
+         global_success = False
+         return global_success
+
+ # --- FastAPI App and Endpoints ---
+
+ app = FastAPI(
+     title=f"Flow Server {FLOW_ID} API",
+     description="Sequentially processes zip files from a dataset, captions images, and tracks progress.",
+     version="1.0.0"
+ )
+
+ @app.on_event("startup")
+ async def startup_event():
+     print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.")
+
+     # Get both local progress and HF state
+     progress = load_progress()
+     current_state = await download_hf_state()
+
+     # Get the next_download_index from HF state if available
+     hf_next_index = current_state.get("next_download_index", 0)
+
+     # If HF state has a higher index, use that instead of local progress
+     if hf_next_index > 0:
+         start_index = hf_next_index
+         print(f"[{FLOW_ID}] Using next_download_index from HF state: {start_index}")
+     else:
+         # Fall back to local progress if HF state doesn't have a meaningful index
+         start_index = progress.get('last_processed_index', 0) + 1
+         if start_index < AUTO_START_INDEX:
+             start_index = AUTO_START_INDEX
+
+     # Use a dummy BackgroundTasks object for the startup task
+     # Note: FastAPI's startup events can't directly use BackgroundTasks, but we can use asyncio.create_task
+     # to run the long-running process in the background without blocking the server startup.
+     print(f"[{FLOW_ID}] Auto-starting processing from index: {start_index}...")
+     asyncio.create_task(process_dataset_task(start_index))
+
+ @app.get("/")
+ async def root():
+     progress = load_progress()
+
+     # Calculate server stats
+     total_processed = sum(s.total_processed for s in servers)
+     total_time = sum(s.total_time for s in servers)
+     avg_fps = total_processed / total_time if total_time > 0 else 0
+
+     return {
+         "flow_id": FLOW_ID,
+         "status": "ready",
+         "last_processed_index": progress.get('last_processed_index', 0),
+         "total_files_in_list": len(progress['file_list']),
+         "uploaded_count": progress.get('uploaded_count', 0),
+         "total_servers": len(servers),
+         "processing_servers": sum(1 for s in servers if s.is_processing),
+         "total_files_processed_by_servers": total_processed,
+         "avg_files_per_second": avg_fps,
+         "upload_limit_paused": progress.get('uploaded_count', 0) >= MAX_UPLOADS_BEFORE_PAUSE
+     }
+
+ @app.post("/start_processing")
+ async def start_processing(request: ProcessStartRequest, background_tasks: BackgroundTasks):
+     """
+     Starts the sequential processing of zip files from the given index in the background.
+     """
+     start_index = request.start_index
+
+     print(f"[{FLOW_ID}] Received request to start processing from index: {start_index}. Starting background task.")
+
+     # Start the heavy processing in a background task so the API call returns immediately
+     # Note: The server is already auto-starting, but this allows for manual restart/override.
+     background_tasks.add_task(process_dataset_task, start_index)
+
+     return {"status": "processing", "start_index": start_index, "message": "Dataset processing started in background."}
+
+ if __name__ == "__main__":
+     import uvicorn
+     # Note: When running in the sandbox, we need to use 0.0.0.0 to expose the port.
+     uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
requirements.txt ADDED
@@ -0,0 +1,18 @@
+
+ accelerate
+ fastapi
+ uvicorn
+ opencv-python-headless
+ numpy
+ pathlib
+ huggingface_hub
+ pillow
+ rarfile
+ python-multipart
+ openai-whisper
+ ffmpeg-python
+ transformers
+ librosa
+ torch
+ torchaudio
+ aiohttp