factorstudios committed on
Commit
0a4682f
·
verified ·
1 Parent(s): a1d4a74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -704
app.py CHANGED
@@ -1,704 +1,326 @@
1
- import os
2
- import json
3
- import time
4
- import asyncio
5
- import aiohttp
6
- import zipfile
7
- import shutil
8
- from typing import Dict, List, Set, Optional, Tuple, Any
9
- from urllib.parse import quote
10
- from datetime import datetime
11
- from pathlib import Path
12
- import io
13
-
14
- from fastapi import FastAPI, BackgroundTasks, HTTPException, status
15
- from pydantic import BaseModel, Field
16
- from huggingface_hub import HfApi, hf_hub_download
17
-
18
- # --- Configuration ---
19
- AUTO_START_INDEX = 1 # Hardcoded default start index if no progress is found
20
- FLOW_ID = os.getenv("FLOW_ID", "flow_default")
21
- FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
22
- HF_TOKEN = os.getenv("HF_TOKEN", "")
23
- HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
24
- HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")
25
-
26
- # Progress and State Tracking
27
- PROGRESS_FILE = Path("processing_progress.json")
28
- HF_STATE_FILE = "processing_state_transcriptions.json"
29
- LOCAL_STATE_FOLDER = Path(".state")
30
- LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
31
-
32
- # Processing configuration
33
- MAX_UPLOADS_BEFORE_PAUSE = 120 # Pause uploading after 120 files
34
- UPLOAD_PAUSE_ENABLED = True
35
-
36
- # Directory within the HF dataset where the audio files are located
37
- AUDIO_FILE_PREFIX = "audio/"
38
-
39
- WHISPER_SERVERS = [
40
- "https://makeitfr-mineo-1.hf.space/transcribe",
41
- "https://makeitfr-mineo-2.hf.space/transcribe",
42
- "https://makeitfr-mineo-3.hf.space/transcribe",
43
- "https://makeitfr-mineo-4.hf.space/transcribe",
44
- "https://makeitfr-mineo-5.hf.space/transcribe",
45
- "https://makeitfr-mineo-6.hf.space/transcribe",
46
- "https://makeitfr-mineo-7.hf.space/transcribe",
47
- "https://makeitfr-mineo-8.hf.space/transcribe",
48
- "https://makeitfr-mineo-9.hf.space/transcribe",
49
- "https://makeitfr-mineo-10.hf.space/transcribe",
50
- "https://makeitfr-mineo-11.hf.space/transcribe",
51
- "https://makeitfr-mineo-12.hf.space/transcribe",
52
- "https://makeitfr-mineo-13.hf.space/transcribe",
53
- "https://makeitfr-mineo-14.hf.space/transcribe",
54
- "https://makeitfr-mineo-15.hf.space/transcribe",
55
- "https://makeitfr-mineo-16.hf.space/transcribe",
56
- "https://makeitfr-mineo-17.hf.space/transcribe",
57
- "https://makeitfr-mineo-18.hf.space/transcribe",
58
- "https://makeitfr-mineo-19.hf.space/transcribe",
59
- "https://makeitfr-mineo-20.hf.space/transcribe"
60
- ]
61
-
62
- # Temporary storage for audio files
63
- TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
64
- TEMP_DIR.mkdir(exist_ok=True)
65
-
66
- # --- Models ---
67
class ProcessStartRequest(BaseModel):
    """Request body for /start_processing: where to begin in the file list."""

    # 1-indexed position in the sorted WAV file list; defaults to AUTO_START_INDEX.
    start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the audio file to start processing from (1-indexed).")
69
-
70
class WhisperServer:
    """Busy/idle status and throughput statistics for one Whisper endpoint."""

    def __init__(self, url: str):
        self.url = url
        self.is_processing = False                      # True while a file is assigned
        self.current_file_index: Optional[int] = None   # index of the assigned file, if any
        self.total_processed = 0                        # successful transcriptions so far
        self.total_time = 0.0                           # cumulative seconds spent transcribing

    @property
    def fps(self):
        """Files per second; 0 until at least one file has been timed."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0

    def assign_file(self, file_index: int):
        """Mark this server busy with the given file index."""
        self.is_processing = True
        self.current_file_index = file_index

    def release(self):
        """Mark this server idle and clear the assigned file."""
        self.is_processing = False
        self.current_file_index = None
92
-
93
- # Global state for whisper servers
94
- servers = [WhisperServer(url) for url in WHISPER_SERVERS]
95
- server_lock = asyncio.Lock() # Lock for thread-safe server state access
96
-
97
- # --- Progress and State Management Functions ---
98
-
99
def load_progress() -> Dict:
    """Load local processing progress from PROGRESS_FILE.

    Returns a dict guaranteed to contain 'last_processed_index',
    'processed_files', 'file_list' and 'uploaded_count'. Values found on
    disk are merged over the defaults, so a progress file written by an
    older version of this script (or the corrupted-file fallback) can no
    longer cause KeyError in callers that index these keys directly.
    """
    defaults: Dict = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],        # full list of audio files found in the dataset
        "uploaded_count": 0,    # transcriptions uploaded so far (read elsewhere via .get)
    }
    if PROGRESS_FILE.exists():
        try:
            with PROGRESS_FILE.open('r') as f:
                loaded = json.load(f)
            # Merge over defaults so every expected key is always present.
            defaults.update(loaded)
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
            # Fall through to return the pristine defaults.
    return defaults
115
-
116
def save_progress(progress_data: Dict):
    """Persist the local processing progress to PROGRESS_FILE as JSON."""
    try:
        serialized = json.dumps(progress_data, indent=4)
        PROGRESS_FILE.write_text(serialized)
    except Exception as exc:
        # Never let a bookkeeping failure kill the processing loop.
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {exc}")
123
-
124
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load state from a JSON file, migrating older layouts to the current one.

    Guarantees the returned dict has a 'file_states' dict and a
    'next_download_index' int; returns *default_value* when the file is
    missing or not valid JSON.
    """
    if not os.path.exists(file_path):
        return default_value
    try:
        with open(file_path, "r") as fh:
            state = json.load(fh)
    except json.JSONDecodeError:
        print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        return default_value
    # Migration: older state files may predate these keys.
    if not isinstance(state.get("file_states"), dict):
        print(f"[{FLOW_ID}] Initializing 'file_states' dictionary.")
        state["file_states"] = {}
    state.setdefault("next_download_index", 0)
    return state
143
-
144
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Write *data* to *file_path* as indented JSON."""
    serialized = json.dumps(data, indent=2)
    with open(file_path, "w") as fh:
        fh.write(serialized)
148
-
149
async def download_hf_state() -> Dict[str, Any]:
    """Fetch the shared state file from the HF output dataset.

    Returns a default empty state when the file does not exist in the repo
    or any Hugging Face call fails.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}

    try:
        # List the output repo first so a missing state file is not an error.
        repo_files = HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
        )
        if HF_STATE_FILE not in repo_files:
            print(f"[{FLOW_ID}] State file not found in {HF_OUTPUT_DATASET_ID}. Starting fresh.")
            return default_state

        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        print(f"[{FLOW_ID}] Successfully downloaded state file.")
        return load_json_state(str(local_path), default_state)
    except Exception as exc:
        print(f"[{FLOW_ID}] Failed to download state file: {str(exc)}. Starting fresh.")
        return default_state
181
-
182
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Persist *state* locally, then push it to the HF output dataset.

    Returns True on success, False on any failure (logged, never raised).
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        # Write locally first so the upload reads a consistent snapshot.
        save_json_state(str(local_path), state)
        commit_msg = f"Update caption processing state: next_index={state['next_download_index']}"
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=commit_msg,
        )
        print(f"[{FLOW_ID}] Successfully uploaded state file.")
        return True
    except Exception as exc:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(exc)}")
        return False
203
-
204
async def lock_file_for_processing(zip_filename: str, state: Dict[str, Any]) -> bool:
    """Mark *zip_filename* as 'processing' and push the state file as a lock.

    Returns True when the lock state was uploaded. On upload failure the
    local entry is restored to its previous value; the original code
    deleted the key outright, which discarded any earlier status (e.g. a
    prior 'processed' mark) for that file.
    """
    print(f"[{FLOW_ID}] 🔒 Attempting to lock file: {zip_filename}")

    # Remember the pre-lock value so a failed upload can be reverted cleanly.
    _missing = object()
    previous = state["file_states"].get(zip_filename, _missing)

    state["file_states"][zip_filename] = "processing"

    # Upload the updated state file immediately to establish the lock.
    if await upload_hf_state(state):
        print(f"[{FLOW_ID}] ✅ Successfully locked file: {zip_filename}")
        return True

    print(f"[{FLOW_ID}] ❌ Failed to lock file: {zip_filename}")
    # Revert local state to exactly what it was before the attempt.
    if previous is _missing:
        del state["file_states"][zip_filename]
    else:
        state["file_states"][zip_filename] = previous
    return False
221
-
222
async def unlock_file_as_processed(zip_filename: str, state: Dict[str, Any], next_index: int) -> bool:
    """Mark *zip_filename* as 'processed', advance the index, and push state."""
    print(f"[{FLOW_ID}] 🔓 Marking file as processed: {zip_filename}")

    state["file_states"][zip_filename] = "processed"
    state["next_download_index"] = next_index

    ok = await upload_hf_state(state)
    if ok:
        print(f"[{FLOW_ID}] ✅ Successfully marked as processed: {zip_filename}")
    else:
        print(f"[{FLOW_ID}] ❌ Failed to update state for: {zip_filename}")
    return ok
237
-
238
- # --- Hugging Face Utility Functions ---
239
-
240
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted list of WAV files in the audio dataset.

    Reuses the list cached in *progress_data* when present; otherwise lists
    the repo, caches the result via save_progress, and returns it. Returns
    an empty list on any Hugging Face error.
    """
    cached = progress_data['file_list']
    if cached:
        print(f"[{FLOW_ID}] Using cached file list with {len(cached)} files.")
        return cached

    print(f"[{FLOW_ID}] Fetching full list of WAV files from {HF_AUDIO_DATASET_ID}...")
    try:
        repo_files = HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_AUDIO_DATASET_ID,
            repo_type="dataset",
        )

        # Sort alphabetically so indices are stable across runs.
        wav_files = sorted(name for name in repo_files if name.endswith('.wav'))
        if not wav_files:
            raise FileNotFoundError(f"No WAV files found in dataset '{HF_AUDIO_DATASET_ID}'.")

        print(f"[{FLOW_ID}] Found {len(wav_files)} WAV files.")

        # Cache the listing so subsequent calls skip the API round trip.
        progress_data['file_list'] = wav_files
        save_progress(progress_data)
        return wav_files
    except Exception as exc:
        print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {exc}")
        return []
277
-
278
async def download_wav_file_by_index(file_index: int, repo_file_full_path: str) -> Optional[Path]:
    """Download one WAV file into TEMP_DIR; return its local Path or None.

    *file_index* is used only for logging.
    """
    print(f"[{FLOW_ID}] Downloading file #{file_index}: {repo_file_full_path}")
    try:
        # Download into TEMP_DIR (not the HF cache) so cleanup can unlink it.
        local = hf_hub_download(
            repo_id=HF_AUDIO_DATASET_ID,
            filename=repo_file_full_path,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(TEMP_DIR),
            local_dir_use_symlinks=False,
        )
        print(f"[{FLOW_ID}] Downloaded WAV file to {local}")
        return Path(local)
    except Exception as exc:
        print(f"[{FLOW_ID}] Error downloading WAV file {repo_file_full_path}: {exc}")
        return None
302
-
303
async def upload_transcription_to_hf(wav_filename: str, transcription_data: Dict) -> bool:
    """Upload the transcription JSON for *wav_filename* to the output dataset."""
    # Flatten the repo path into one filename and swap the extension to .json.
    stem = wav_filename.replace('/', '_').replace('\\', '_').rsplit('.', 1)[0]
    json_filename = stem + '.json'

    try:
        print(f"[{FLOW_ID}] Uploading transcription for {wav_filename} as {json_filename} to {HF_OUTPUT_DATASET_ID}...")

        # Serialize in memory; no temp file needed for the upload.
        payload = json.dumps(transcription_data, indent=2, ensure_ascii=False).encode('utf-8')

        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo=json_filename,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"[{FLOW_ID}] Transcription for {wav_filename}"
        )

        print(f"[{FLOW_ID}] Successfully uploaded transcription for {wav_filename}.")
        return True
    except Exception as exc:
        print(f"[{FLOW_ID}] Error uploading transcription for {wav_filename}: {exc}")
        return False
329
-
330
- # --- Core Processing Functions ---
331
-
332
async def send_audio_to_whisper(wav_path: Path, server: WhisperServer) -> Optional[Dict]:
    """Sends a WAV file to a Whisper server for transcription.

    Returns a dict with the server's JSON result plus timing metadata on
    HTTP 200, or None on a non-200 response, timeout, or any exception.
    The server's total_processed/total_time counters (which feed its `fps`
    property) are updated only on success.
    """
    try:
        print(f"[{FLOW_ID}] Sending {wav_path.name} to {server.url}...")

        start_time = time.time()

        # Prepare multipart form data
        form_data = aiohttp.FormData()
        # Open the file in a context manager so the descriptor is closed after the request
        # NOTE(review): the request must stay inside this `with` block — aiohttp
        # streams the file object lazily while sending the multipart body, so the
        # descriptor has to remain open until the POST completes. Indentation in
        # the reviewed copy was ambiguous; confirm the nesting matches this.
        with wav_path.open('rb') as f:
            form_data.add_field('file', f, filename=wav_path.name, content_type='audio/wav')

            async with aiohttp.ClientSession() as session:
                # 10 minute timeout for transcription
                async with session.post(server.url, data=form_data, timeout=600) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        end_time = time.time()

                        # Update server stats (success only)
                        server.total_processed += 1
                        server.total_time += (end_time - start_time)

                        print(f"[{FLOW_ID}] ✓ {wav_path.name} transcribed successfully by {server.url}")

                        return {
                            "file": wav_path.name,
                            "transcription": result,
                            "timestamp": datetime.now().isoformat(),
                            "processing_time_seconds": end_time - start_time
                        }
                    else:
                        error_text = await resp.text()
                        print(f"[{FLOW_ID}] ✗ Error from {server.url}: {resp.status} - {error_text}")
                        return None

    except asyncio.TimeoutError:
        print(f"[{FLOW_ID}] ✗ Timeout from {server.url} for {wav_path.name}")
        return None
    except Exception as e:
        print(f"[{FLOW_ID}] ✗ Exception on {server.url} for {wav_path.name}: {e}")
        return None
375
-
376
async def get_available_servers() -> List[WhisperServer]:
    """Return the servers that are currently idle, under the shared lock."""
    async with server_lock:
        return [srv for srv in servers if not srv.is_processing]
384
-
385
async def assign_file_to_server(file_index: int, server: WhisperServer):
    """Mark *server* busy with *file_index*, under the shared lock."""
    async with server_lock:
        server.assign_file(file_index)
389
-
390
async def release_server(server: WhisperServer):
    """Mark *server* idle again, under the shared lock."""
    async with server_lock:
        server.release()
394
-
395
async def process_batch_dynamic(wav_files: List[str], start_batch_index: int, batch_size: int, state: Dict[str, Any], progress: Dict) -> Tuple[int, int]:
    """
    Processes a batch of WAV files in parallel using available servers.
    Batch size = number of servers. Each server gets one file, processes it, then gets the next.
    Includes retry mechanism for failed files.
    Returns (next_batch_index, uploaded_count)

    On lock-upload failure this returns (start_batch_index, uploaded_count)
    unchanged, i.e. the caller will retry the same batch.
    """
    batch_end = min(start_batch_index + batch_size, len(wav_files))
    uploaded_count = progress.get('uploaded_count', 0)
    max_retries = 3
    failed_files = []  # Track files that failed for retry
    # NOTE(review): failed_files is never appended to — retries are re-queued
    # through files_to_process below; this list appears to be dead code.

    print(f"[{FLOW_ID}] Processing batch from index {start_batch_index} to {batch_end - 1} ({batch_end - start_batch_index} files)")

    # --- Batch-level locking: mark all files in this batch as 'processing' and upload state
    try:
        state.setdefault("file_states", {})
        for idx in range(start_batch_index, batch_end):
            wav_file = wav_files[idx]
            state["file_states"][wav_file] = "processing"

        # Update next_download_index to the end of this batch (0-based)
        state["next_download_index"] = batch_end

        # Upload HF state to establish locks for this batch
        if await upload_hf_state(state):
            print(f"[{FLOW_ID}] ✅ Batch locked: files {start_batch_index}-{batch_end - 1} marked 'processing'")
        else:
            print(f"[{FLOW_ID}] ❌ Failed to upload batch lock")
            return start_batch_index, uploaded_count
    except Exception as e:
        print(f"[{FLOW_ID}] Error while setting up batch locks: {e}")
        return start_batch_index, uploaded_count

    # Work queue of (index, repo path, retry count); retries re-enter at the tail.
    files_to_process = [(idx, wav_files[idx], 0) for idx in range(start_batch_index, batch_end)]  # (idx, wav_file, retry_count)

    # --- Assign files to servers and create tasks
    # Maps each in-flight task to the bookkeeping needed when it completes.
    pending_tasks: Dict[asyncio.Task, Tuple[int, Path, WhisperServer, str, int]] = {}

    try:
        while files_to_process or pending_tasks:
            # Assign new files to available servers
            while files_to_process:
                available = await get_available_servers()
                if not available:
                    break

                file_idx, wav_file, retry_count = files_to_process.pop(0)
                wav_filename = Path(wav_file).name
                server = available[0]

                # Download the WAV file (sequential: one download per assignment)
                wav_path = await download_wav_file_by_index(file_idx + 1, wav_file)
                if not wav_path:
                    if retry_count < max_retries:
                        print(f"[{FLOW_ID}] ⚠️ Download failed for {wav_filename} (retry {retry_count + 1}/{max_retries}), re-queueing...")
                        files_to_process.append((file_idx, wav_file, retry_count + 1))
                    else:
                        state["file_states"][wav_file] = "failed_download"
                        print(f"[{FLOW_ID}] ❌ Download failed permanently for {wav_filename} after {max_retries} retries")
                    continue

                # Assign to server and create task
                await assign_file_to_server(file_idx, server)
                task = asyncio.create_task(send_audio_to_whisper(wav_path, server))
                pending_tasks[task] = (file_idx, wav_path, server, wav_file, retry_count)
                print(f"[{FLOW_ID}] Assigned {wav_filename} to server {servers.index(server) + 1}")

            # Wait for at least one task to complete if there are pending tasks
            if not pending_tasks:
                break

            done, pending = await asyncio.wait(
                pending_tasks.keys(),
                return_when=asyncio.FIRST_COMPLETED
            )

            for task in done:
                file_idx, wav_path, server, wav_file, retry_count = pending_tasks.pop(task)
                wav_filename = Path(wav_file).name

                try:
                    transcription_result = task.result()

                    if transcription_result:
                        # Upload transcription immediately with full path
                        uploaded_ok = await upload_transcription_to_hf(wav_file, transcription_result)
                        if uploaded_ok:
                            # Update state locally but do NOT upload to HF yet
                            # (one state upload happens after the whole batch).
                            state["file_states"][wav_file] = "processed"
                            uploaded_count += 1
                            progress['uploaded_count'] = uploaded_count
                            save_progress(progress)
                            print(f"[{FLOW_ID}] ✅ {wav_filename} uploaded (#{uploaded_count})")
                        else:
                            # Retry failed upload
                            if retry_count < max_retries:
                                print(f"[{FLOW_ID}] ⚠️ Upload failed for {wav_filename} (retry {retry_count + 1}/{max_retries}), re-queueing...")
                                files_to_process.append((file_idx, wav_file, retry_count + 1))
                            else:
                                state["file_states"][wav_file] = "failed_upload"
                                print(f"[{FLOW_ID}] ❌ Upload failed permanently for {wav_filename} after {max_retries} retries")
                    else:
                        # Retry failed transcription
                        if retry_count < max_retries:
                            print(f"[{FLOW_ID}] ⚠️ Transcription failed for {wav_filename} (retry {retry_count + 1}/{max_retries}), re-queueing...")
                            files_to_process.append((file_idx, wav_file, retry_count + 1))
                        else:
                            state["file_states"][wav_file] = "failed_transcription"
                            print(f"[{FLOW_ID}] ❌ Transcription failed permanently for {wav_filename} after {max_retries} retries")

                except Exception as e:
                    if retry_count < max_retries:
                        print(f"[{FLOW_ID}] ⚠️ Error processing {wav_filename}: {e} (retry {retry_count + 1}/{max_retries}), re-queueing...")
                        files_to_process.append((file_idx, wav_file, retry_count + 1))
                    else:
                        print(f"[{FLOW_ID}] ❌ Error processing {wav_filename}: {e} (failed after {max_retries} retries)")
                        state["file_states"][wav_file] = "failed_error"
                finally:
                    # Release the server
                    await release_server(server)
                    # Clean up the WAV file (downloaded into TEMP_DIR above)
                    if wav_path.exists():
                        wav_path.unlink()

        # --- After all files in this batch are uploaded, update HF state once
        if await upload_hf_state(state):
            print(f"[{FLOW_ID}] ✅ Batch state updated on HF: files {start_batch_index}-{batch_end - 1} marked processed")
        else:
            print(f"[{FLOW_ID}] ❌ Failed to update batch state on HF")
    except Exception as e:
        print(f"[{FLOW_ID}] Error in process_batch_dynamic: {e}")

    return batch_end, uploaded_count
530
-
531
async def process_dataset_task(start_index: int):
    """Main task to process the dataset using dynamic server assignment.

    *start_index* is 1-based. Returns True on normal completion (or when
    start_index is past the end of the list), False on a fatal error.

    NOTE(review): UPLOAD_PAUSE_ENABLED / MAX_UPLOADS_BEFORE_PAUSE are logged
    below but never enforced in this function or in process_batch_dynamic —
    confirm whether the pause feature was intentionally removed.
    """

    # Load both local progress and HF state
    progress = load_progress()
    current_state = await download_hf_state()
    file_list = await get_audio_file_list(progress)

    if not file_list:
        print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
        return False

    # Ensure start_index is within bounds
    if start_index > len(file_list):
        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
        return True

    # Determine the actual starting index in the 0-indexed list
    start_list_index = start_index - 1

    print(f"[{FLOW_ID}] Starting audio transcription from file index: {start_index} out of {len(file_list)}.")
    print(f"[{FLOW_ID}] Using {len(servers)} Whisper servers for dynamic processing.")
    print(f"[{FLOW_ID}] Upload pause enabled: {UPLOAD_PAUSE_ENABLED}, Max uploads before pause: {MAX_UPLOADS_BEFORE_PAUSE}")

    # Initialize progress tracking
    if 'uploaded_count' not in progress:
        progress['uploaded_count'] = 0

    # If there was no HF state in the repo, upload a fresh initial state file
    try:
        if not current_state.get("file_states") and current_state.get("next_download_index", 0) == 0:
            print(f"[{FLOW_ID}] No HF state detected; uploading initial state file to {HF_OUTPUT_DATASET_ID}...")
            # Ensure structure
            current_state.setdefault("file_states", {})
            current_state.setdefault("next_download_index", 0)
            if await upload_hf_state(current_state):
                print(f"[{FLOW_ID}] ✅ Initial HF state uploaded.")
            else:
                print(f"[{FLOW_ID}] ❌ Failed to upload initial HF state.")
    except Exception as e:
        print(f"[{FLOW_ID}] Error while uploading initial HF state: {e}")

    global_success = True
    current_batch_index = start_list_index
    batch_size = len(servers)  # Batch size = number of servers (20 files per batch)
    batch_interval_seconds = 600  # 600 seconds = 10 minutes (enforces max 6 batches per hour)

    try:
        batch_count = 0
        while current_batch_index < len(file_list):
            batch_start_time = time.time()

            # Process a batch dynamically.
            # NOTE(review): if the batch lock upload fails, process_batch_dynamic
            # returns the same index, so this loop retries the identical batch
            # every interval indefinitely — confirm that is the intended behavior.
            next_index, uploaded_count = await process_batch_dynamic(
                file_list,
                current_batch_index,
                batch_size,
                current_state,
                progress
            )

            batch_end_time = time.time()
            batch_elapsed = batch_end_time - batch_start_time

            # Update progress
            progress['last_processed_index'] = next_index
            progress['uploaded_count'] = uploaded_count
            save_progress(progress)

            # Update current batch index
            current_batch_index = next_index
            batch_count += 1

            # Log statistics
            print(f"[{FLOW_ID}] Batch complete. Progress: {current_batch_index}/{len(file_list)}, Uploaded: {uploaded_count}")

            # Print server statistics
            print(f"[{FLOW_ID}] Server Statistics:")
            for i, server in enumerate(servers):
                print(f" Server {i+1}: {server.total_processed} files, {server.total_time:.2f}s total, {server.fps:.2f} files/sec")

            # Rate limiting: enforce minimum 10 minutes between batch starts (max 6 batches per hour)
            if current_batch_index < len(file_list):  # Don't wait after the last batch
                wait_time = batch_interval_seconds - batch_elapsed
                if wait_time > 0:
                    print(f"[{FLOW_ID}] Rate limit: batch took {batch_elapsed:.1f}s. Waiting {wait_time:.1f}s before next batch (min 10 min interval)...")
                    await asyncio.sleep(wait_time)
                else:
                    print(f"[{FLOW_ID}] Batch took {batch_elapsed:.1f}s (exceeded 10 min interval). Proceeding immediately to next batch.")

        print(f"[{FLOW_ID}] All files processed successfully! Total batches: {batch_count}")
        return True

    except Exception as e:
        print(f"[{FLOW_ID}] Critical error in process_dataset_task: {e}")
        global_success = False
    return global_success
628
-
629
- # --- FastAPI App and Endpoints ---
630
-
631
# FastAPI application exposing the status and manual-start endpoints.
app = FastAPI(
    title=f"Flow Server {FLOW_ID} API",
    # Fixed: the previous description was copied from an image-captioning
    # service ("processes zip files, captions images"); this app downloads
    # WAV files, transcribes them via Whisper servers, and tracks progress.
    description="Sequentially downloads WAV files from a dataset, transcribes them with Whisper servers, and tracks progress.",
    version="1.0.0"
)
636
-
637
@app.on_event("startup")
async def startup_event():
    """Resume processing automatically when the server boots."""
    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.")

    # Consult both the local progress file and the shared HF state.
    progress = load_progress()
    current_state = await download_hf_state()

    hf_next_index = current_state.get("next_download_index", 0)
    if hf_next_index > 0:
        # Prefer the shared HF state: it reflects work done by any flow instance.
        start_index = hf_next_index
        print(f"[{FLOW_ID}] Using next_download_index from HF state: {start_index}")
    else:
        # Fall back to local progress, never starting below AUTO_START_INDEX.
        start_index = max(progress.get('last_processed_index', 0) + 1, AUTO_START_INDEX)

    # Startup hooks cannot use BackgroundTasks; spawn the long-running job on
    # the event loop instead so server startup is not blocked.
    print(f"[{FLOW_ID}] Auto-starting processing from index: {start_index}...")
    asyncio.create_task(process_dataset_task(start_index))
663
-
664
@app.get("/")
async def root():
    """Status endpoint: local progress plus aggregate server statistics."""
    progress = load_progress()

    # Aggregate per-server throughput stats.
    total_processed = sum(s.total_processed for s in servers)
    total_time = sum(s.total_time for s in servers)
    avg_fps = total_processed / total_time if total_time > 0 else 0

    return {
        "flow_id": FLOW_ID,
        "status": "ready",
        "last_processed_index": progress.get('last_processed_index', 0),
        # Fixed: was progress['file_list'] — the only direct index into the
        # progress dict here; a progress file lacking that key (older format)
        # would raise KeyError. Use .get like every other field.
        "total_files_in_list": len(progress.get('file_list', [])),
        "uploaded_count": progress.get('uploaded_count', 0),
        "total_servers": len(servers),
        "processing_servers": sum(1 for s in servers if s.is_processing),
        "total_files_processed_by_servers": total_processed,
        "avg_files_per_second": avg_fps,
        "upload_limit_paused": progress.get('uploaded_count', 0) >= MAX_UPLOADS_BEFORE_PAUSE
    }
685
-
686
@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest, background_tasks: BackgroundTasks):
    """
    Starts the sequential processing of zip files from the given index in the background.
    """
    idx = request.start_index
    print(f"[{FLOW_ID}] Received request to start processing from index: {idx}. Starting background task.")

    # Processing also auto-starts on boot; this endpoint is a manual override
    # that returns immediately while the work runs in the background.
    background_tasks.add_task(process_dataset_task, idx)

    return {"status": "processing", "start_index": idx, "message": "Dataset processing started in background."}
700
-
701
if __name__ == "__main__":
    import uvicorn

    # Bind to 0.0.0.0 so the port is reachable from outside the container.
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
 
1
+ import os
2
+ import json
3
+ import time
4
+ import asyncio
5
+ import aiohttp
6
+ import zipfile
7
+ import shutil
8
+ from typing import Dict, List, Set, Optional, Tuple, Any
9
+ from urllib.parse import quote
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ import io
13
+
14
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, status
15
+ from pydantic import BaseModel, Field
16
+ from huggingface_hub import HfApi, hf_hub_download
17
+
18
+ # --- Configuration ---
19
+ AUTO_START_INDEX = 1 # Hardcoded default start index if no progress is found
20
+ FLOW_ID = os.getenv("FLOW_ID", "flow_default")
21
+ FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
22
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
23
+ HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
24
+ HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")
25
+
26
+ # Progress and State Tracking
27
+ PROGRESS_FILE = Path("processing_progress.json")
28
+ HF_STATE_FILE = "processing_state_transcriptions.json"
29
+ LOCAL_STATE_FOLDER = Path(".state")
30
+ LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
31
+
32
+ # Processing configuration
33
+ MAX_UPLOADS_BEFORE_PAUSE = 120 # Pause uploading after 120 files
34
+ UPLOAD_PAUSE_ENABLED = True
35
+
36
+ # Directory within the HF dataset where the audio files are located
37
+ AUDIO_FILE_PREFIX = "audio/"
38
+
39
+ WHISPER_SERVERS = [
40
+ f"https://makeitfr-mineo-{i}.hf.space/transcribe" for i in range(1, 21)
41
+ ]
42
+
43
+ # Temporary storage for audio files
44
+ TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
45
+ TEMP_DIR.mkdir(exist_ok=True)
46
+
47
+ # --- Models ---
48
class ProcessStartRequest(BaseModel):
    """Request body for /start_processing."""

    # 1-indexed start position; must be >= 1, defaulting to AUTO_START_INDEX.
    start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the audio file to start processing from (1-indexed).")
50
+
51
class WhisperServer:
    """One Whisper transcription endpoint plus its usage statistics."""

    def __init__(self, url: str):
        self.url = url
        # Assignment state: set by assign_file(), cleared by release().
        self.is_processing = False
        self.current_file_index: Optional[int] = None
        # Throughput counters, updated by the caller on successful requests.
        self.total_processed = 0
        self.total_time = 0.0

    @property
    def fps(self):
        """Files per second over all timed work; 0 when nothing timed yet."""
        return self.total_processed / self.total_time if self.total_time > 0 else 0

    def assign_file(self, file_index: int):
        """Record that this server is now working on *file_index*."""
        self.is_processing = True
        self.current_file_index = file_index

    def release(self):
        """Clear the assignment so the server can accept new work."""
        self.is_processing = False
        self.current_file_index = None
73
+
74
+ # Global state for whisper servers
75
+ servers = [WhisperServer(url) for url in WHISPER_SERVERS]
76
+ server_lock = asyncio.Lock() # Lock for thread-safe server state access
77
+
78
+ # --- Progress and State Management Functions ---
79
+
80
def load_progress() -> Dict:
    """Load local processing progress from PROGRESS_FILE.

    Returns a dict guaranteed to contain 'last_processed_index',
    'processed_files', 'file_list' and 'uploaded_count'. Data found on
    disk is merged over these defaults so files written by older versions
    of this script cannot cause KeyError in callers.
    """
    defaults: Dict = {
        "last_processed_index": 0,
        "processed_files": {},  # {index: repo_path}
        "file_list": [],        # full list of audio files found in the dataset
        "uploaded_count": 0,
    }
    if PROGRESS_FILE.exists():
        try:
            with PROGRESS_FILE.open('r') as f:
                loaded = json.load(f)
            # Merge so every expected key is present even for old files.
            defaults.update(loaded)
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
    return defaults
95
+
96
def save_progress(progress_data: Dict, progress_file: Optional[Path] = None):
    """Persist the processing progress dict to disk as JSON.

    Args:
        progress_data: The progress structure produced by load_progress().
        progress_file: Target file; defaults to the module-level
            PROGRESS_FILE when omitted (existing one-argument calls are
            unchanged).

    Never raises: any write failure is logged so the processing loop
    keeps running.
    """
    if progress_file is None:
        progress_file = PROGRESS_FILE
    try:
        with progress_file.open('w') as f:
            json.dump(progress_data, f, indent=4)
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {progress_file}: {e}")
103
+
104
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load state from a JSON file, migrating old layouts to the new structure.

    Guarantees the returned dict has a dict-valued "file_states" key and a
    "next_download_index" key.

    Returns:
        The migrated state dict, or default_value when the file is missing,
        unreadable, corrupted, or does not contain a JSON object.
    """
    if os.path.exists(file_path):
        try:
            with open(file_path, "r") as f:
                data = json.load(f)
            # Guard against a non-object top level (e.g. a JSON list): the
            # original code would raise TypeError on key assignment below.
            if isinstance(data, dict):
                # Migration: older state files may lack these keys.
                if "file_states" not in data or not isinstance(data["file_states"], dict):
                    data["file_states"] = {}
                if "next_download_index" not in data:
                    data["next_download_index"] = 0
                return data
            print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
        # OSError covers unreadable files; previously only decode errors
        # were handled.
        except (json.JSONDecodeError, OSError):
            print(f"[{FLOW_ID}] WARNING: Corrupted state file: {file_path}")
    return default_value
118
+
119
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Write *data* to *file_path* as pretty-printed (indent=2) JSON."""
    with open(file_path, "w") as out:
        json.dump(data, out, indent=2)
123
+
124
async def download_hf_state() -> Dict[str, Any]:
    """Fetch the shared transcription-state file from the HF output dataset.

    On any download failure the function falls back to whatever copy exists
    locally, or to an empty default state when none does.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}
    try:
        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}. Using local/default.")
    # Whether or not the download succeeded, parse the local copy (if any).
    return load_json_state(str(local_path), default_state)
141
+
142
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Persist *state* locally and push it to the HF output dataset.

    Returns:
        True when the upload succeeded, False otherwise (failure is logged).
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        save_json_state(str(local_path), state)
        api = HfApi(token=HF_TOKEN)
        api.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Update transcription state: next_index={state.get('next_download_index')}",
        )
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
        return False
    return True
158
+
159
+ # --- Hugging Face Utility Functions ---
160
+
161
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted list of .wav files in the audio dataset.

    The listing is cached in progress_data['file_list'] (and persisted via
    save_progress) so the repo is only enumerated once. Returns an empty
    list when the remote listing fails.
    """
    cached = progress_data['file_list']
    if cached:
        return cached
    try:
        api = HfApi(token=HF_TOKEN)
        all_files = api.list_repo_files(repo_id=HF_AUDIO_DATASET_ID, repo_type="dataset")
        wav_files = sorted(name for name in all_files if name.endswith('.wav'))
        progress_data['file_list'] = wav_files
        save_progress(progress_data)
        return wav_files
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list: {e}")
        return []
174
+
175
+ # --- Core Processing Logic ---
176
+
177
async def transcribe_with_server(server: WhisperServer, wav_path: Path) -> Optional[Dict]:
    """POST the audio at *wav_path* to *server* and return the JSON result.

    Updates the server's throughput statistics on a 200 response. Returns
    None on a non-200 status or on any transport/file error (logged).
    """
    started = time.time()
    try:
        async with aiohttp.ClientSession() as session:
            with open(wav_path, 'rb') as audio:
                form = aiohttp.FormData()
                form.add_field('file', audio, filename=wav_path.name)
                async with session.post(server.url, data=form, timeout=600) as resp:
                    if resp.status != 200:
                        print(f"[{FLOW_ID}] Server {server.url} returned status {resp.status}")
                        return None
                    payload = await resp.json()
                    # Only successful transcriptions count toward fps stats.
                    server.total_processed += 1
                    server.total_time += time.time() - started
                    return payload
    except Exception as e:
        print(f"[{FLOW_ID}] Error transcribing with {server.url}: {e}")
    return None
196
+
197
async def process_file_task(wav_file: str, state: Dict, progress: Dict):
    """Download one audio file, transcribe it, and record the outcome.

    Polls (once per second) until a Whisper server is free, then marks
    *wav_file* as "processed" or "failed_transcription" in state["file_states"]
    and bumps progress["uploaded_count"] on success. The temp file is removed
    afterwards. Never raises: failures are logged and recorded in *state*.
    """
    # Acquire the first idle server; the lock prevents two tasks from
    # claiming the same server.
    server = None
    while server is None:
        async with server_lock:
            for candidate in servers:
                if not candidate.is_processing:
                    candidate.is_processing = True
                    server = candidate
                    break
        if server is None:
            await asyncio.sleep(1)

    try:
        # BUGFIX: hf_hub_download mirrors the full repo path under local_dir,
        # so the file lands at TEMP_DIR / wav_file (including any "audio/"
        # prefix) — NOT at TEMP_DIR / basename as the old code assumed, which
        # made every prefixed file look missing and fail transcription.
        wav_path = TEMP_DIR / wav_file

        # Download (blocking call; acceptable here since each task owns one file).
        hf_hub_download(
            repo_id=HF_AUDIO_DATASET_ID,
            filename=wav_file,
            repo_type="dataset",
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
            token=HF_TOKEN
        )

        # Transcribe on the claimed server.
        result = await transcribe_with_server(server, wav_path)

        if result:
            state["file_states"][wav_file] = "processed"
            progress["uploaded_count"] = progress.get("uploaded_count", 0) + 1
            print(f"[{FLOW_ID}] Success: {wav_file}")
            # NOTE(review): the transcription payload is currently discarded;
            # persist `result` somewhere if the text output is needed.
        else:
            state["file_states"][wav_file] = "failed_transcription"
            print(f"[{FLOW_ID}] ❌ Failed: {wav_file}")

        # Clean up the downloaded audio regardless of outcome.
        if wav_path.exists():
            wav_path.unlink()

    except Exception as e:
        print(f"[{FLOW_ID}] Error processing {wav_file}: {e}")
        state["file_states"][wav_file] = "failed_transcription"
    finally:
        server.release()
244
+
245
async def main_processing_loop():
    """Endless scheduler: retry failed files first, then pull new ones in chunks.

    Each cycle re-downloads the shared HF state, builds a work list of all
    previously failed files plus up to 100 not-yet-seen files, fans each batch
    out across the Whisper servers, and persists state after every batch.
    """
    print(f"[{FLOW_ID}] Starting main processing loop...")

    while True:
        state = await download_hf_state()
        progress = load_progress()
        file_list = await get_audio_file_list(progress)

        # No listing yet (empty dataset or listing failed): back off and retry.
        if not file_list:
            await asyncio.sleep(60)
            continue

        # 1. Handpick failed_transcription files.
        # NOTE(review): failed files are retried on every cycle with no retry
        # cap — a permanently broken file will be reattempted forever.
        failed_files = [f for f, s in state.get("file_states", {}).items() if s == "failed_transcription"]

        # 2. Also check for new files based on next_download_index.
        next_idx = state.get("next_download_index", 0)
        new_files = file_list[next_idx:next_idx + 100]  # Take a chunk of new files

        # Combine: prioritize failed files, then add files never seen before
        # (anything already present in file_states is skipped).
        files_to_process = failed_files + [f for f in new_files if f not in state["file_states"]]

        if not files_to_process:
            print(f"[{FLOW_ID}] No files to process. Sleeping...")
            await asyncio.sleep(60)
            continue

        print(f"[{FLOW_ID}] Processing {len(files_to_process)} files ({len(failed_files)} failed, {len(files_to_process)-len(failed_files)} new)...")

        # Process in batches sized to the server pool so every server stays busy.
        batch_size = len(servers)
        for i in range(0, len(files_to_process), batch_size):
            batch = files_to_process[i:i + batch_size]
            tasks = [process_file_task(f, state, progress) for f in batch]
            await asyncio.gather(*tasks)

            # Update next_download_index if we processed new files.
            processed_new = [f for f in batch if f in new_files]
            if processed_new:
                # This is a simple way to update index; in reality, you'd want to be more precise
                # but for this fix, we'll just increment based on what we found.
                # Advance past the last new file attempted (success or failure).
                last_new_file = processed_new[-1]
                state["next_download_index"] = file_list.index(last_new_file) + 1

            # Save and upload state after each batch so a crash loses at most
            # one batch of bookkeeping.
            await upload_hf_state(state)
            save_progress(progress)

        await asyncio.sleep(10)
294
+
295
# --- FastAPI App ---

# Single ASGI app instance; served by uvicorn in the __main__ block below.
app = FastAPI(title=f"Flow Server {FLOW_ID} API")
298
+
299
@app.on_event("startup")
async def startup_event():
    """Kick off the background processing loop when the server starts."""
    # Keep a strong reference on app.state: asyncio holds only weak
    # references to tasks, so an unreferenced background task can be
    # garbage-collected and silently stop mid-flight.
    app.state.processing_task = asyncio.create_task(main_processing_loop())
302
+
303
@app.get("/")
async def root():
    """Health/status endpoint summarizing the flow's current progress."""
    progress = load_progress()
    state = await download_hf_state()
    statuses = state.get("file_states", {}).values()
    failed_count = sum(1 for value in statuses if value == "failed_transcription")
    return {
        "flow_id": FLOW_ID,
        "status": "running",
        "next_download_index": state.get("next_download_index", 0),
        "failed_transcriptions": failed_count,
        "uploaded_count": progress.get("uploaded_count", 0),
    }
315
+
316
@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest):
    """Manually reset the shared download index (e.g. to reprocess a range)."""
    state = await download_hf_state()
    # The public API is 1-indexed; the stored index is 0-indexed.
    state["next_download_index"] = request.start_index - 1
    await upload_hf_state(state)
    return {"status": "index_reset", "new_index": request.start_index}
323
+
324
if __name__ == "__main__":
    # Run standalone (outside a managed ASGI deployment) on the flow's port.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)