Update app.py
app.py CHANGED

Old version (removed lines are marked with -):

@@ -4,66 +4,35 @@ import time
 import asyncio
 import aiohttp
 import zipfile
-import io
 import shutil
-from typing import Dict, List, Set, Optional,
 from urllib.parse import quote
 from datetime import datetime
 from pathlib import Path

-from fastapi import FastAPI, BackgroundTasks, HTTPException, status
-from fastapi.responses import HTMLResponse
-from fastapi.templating import Jinja2Templates
 from pydantic import BaseModel, Field
-from huggingface_hub import HfApi, hf_hub_download
-import uvicorn

 # --- Configuration ---
 FLOW_ID = os.getenv("FLOW_ID", "flow_default")
 FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
-MANAGER_URL = os.getenv("MANAGER_URL", "https://fred808-fcord.hf.space")
-MANAGER_COMPLETE_TASK_URL = f"{MANAGER_URL}/task/complete"
 HF_TOKEN = os.getenv("HF_TOKEN", "")
-HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")
-HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium")

 # Using the full list from the user's original code for actual deployment
 CAPTION_SERVERS = [
     "https://fred808-pil-4-1.hf.space/analyze",
     "https://fred808-pil-4-2.hf.space/analyze",
-    "https://fred1012-fred1012-gw0j2h.hf.space/analyze",
-    "https://fred1012-fred1012-wqs6c2.hf.space/analyze",
-    "https://fred1012-fred1012-oncray.hf.space/analyze",
-    "https://fred1012-fred1012-4goge7.hf.space/analyze",
-    "https://fred1012-fred1012-z0eh7m.hf.space/analyze",
-    "https://fred1012-fred1012-u95rte.hf.space/analyze",
-    "https://fred1012-fred1012-igje22.hf.space/analyze",
-    "https://fred1012-fred1012-ibkuf8.hf.space/analyze",
-    "https://fred1012-fred1012-nwqthy.hf.space/analyze",
-    "https://fred1012-fred1012-4ldqj4.hf.space/analyze",
-    "https://fred1012-fred1012-pivlzg.hf.space/analyze",
-    "https://fred1012-fred1012-ptlc5u.hf.space/analyze",
-    "https://fred1012-fred1012-u7lh57.hf.space/analyze",
-    "https://fred1012-fred1012-q8djv1.hf.space/analyze",
-    "https://fredalone-fredalone-ozugrp.hf.space/analyze",
-    "https://fredalone-fredalone-9brxj2.hf.space/analyze",
-    "https://fredalone-fredalone-p8vq9a.hf.space/analyze",
-    "https://fredalone-fredalone-vbli2y.hf.space/analyze",
-    "https://fredalone-fredalone-uggger.hf.space/analyze",
-    "https://fredalone-fredalone-nmi7e8.hf.space/analyze",
-    "https://fredalone-fredalone-d1f26d.hf.space/analyze",
-    "https://fredalone-fredalone-461jp2.hf.space/analyze",
-    "https://fredalone-fredalone-3enfg4.hf.space/analyze",
-    "https://fredalone-fredalone-dqdbpv.hf.space/analyze",
-    "https://fredalone-fredalone-ivtjua.hf.space/analyze",
-    "https://fredalone-fredalone-6bezt2.hf.space/analyze",
-    "https://fredalone-fredalone-e0wfnk.hf.space/analyze",
-    "https://fredalone-fredalone-zu2t7j.hf.space/analyze",
-    "https://fredalone-fredalone-dqtv1o.hf.space/analyze",
-    "https://fredalone-fredalone-wclyog.hf.space/analyze",
-    "https://fredalone-fredalone-t27vig.hf.space/analyze",
     "https://fredalone-fredalone-gahbxh.hf.space/analyze",
     "https://fredalone-fredalone-kw2po4.hf.space/analyze",
     "https://fredalone-fredalone-8h285h.hf.space/analyze"
@@ -75,8 +44,8 @@ TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
 TEMP_DIR.mkdir(exist_ok=True)

 # --- Models ---
-class

 class CaptionServer:
     def __init__(self, url):
@@ -90,116 +59,144 @@ class CaptionServer:
     def fps(self):
         return self.total_processed / self.total_time if self.total_time > 0 else 0

-    all_zip_files: List[str] = Field(default_factory=list)
-    processed_files: Set[str] = Field(default_factory=set)
-    current_index: int = 0
-    total_files: int = 0
-    status: str = "Idle"
-    current_file: Optional[str] = None
-    current_file_progress: str = "0/0"
-    last_update: str = datetime.now().isoformat()
-    is_running: bool = False
-
-# Global state for caption servers and the overall server state
 servers = [CaptionServer(url) for url in CAPTION_SERVERS]
 server_index = 0
-state = ServerState()
-# Lock for thread-safe access to the global state
-state_lock = asyncio.Lock()

-# ---

-def
-    """
-    """Helper to get HfFileSystem instance."""
-    return HfFileSystem(token=HF_TOKEN)

-async def
-    """
-        if fs.exists(state_path):
-            print(f"[{FLOW_ID}] Loading state from {state_path}...")
-            with fs.open(state_path, 'rb') as f:
-                data = json.load(f)
-            # Convert list of processed files back to a set
-            if 'processed_files' in data and isinstance(data['processed_files'], list):
-                data['processed_files'] = set(data['processed_files'])
-            state = ServerState.parse_obj(data)
-            print(f"[{FLOW_ID}] State loaded successfully. Current index: {state.current_index}")
-        else:
-            print(f"[{FLOW_ID}] State file {state_path} not found. Starting with default state.")
-    except Exception as e:
-        print(f"[{FLOW_ID}] Error loading state from HF: {e}. Starting with default state.")
-        state = ServerState()

-    state.last_update = datetime.now().isoformat()
-    # Convert set of processed files to a list for JSON serialization
-    data_to_save = state.dict()
-    data_to_save['processed_files'] = list(state.processed_files)

-async def
-    """
-    global state
-    api = get_hf_api()

 # --- Core Processing Functions (Modified) ---

@@ -208,30 +205,37 @@ async def get_available_server(timeout: float = 300.0) -> CaptionServer:
     global server_index
     start_time = time.time()
     while True:
         for _ in range(len(servers)):
             server = servers[server_index]
             server_index = (server_index + 1) % len(servers)
             if not server.busy:
                 return server

         await asyncio.sleep(0.5)

         if time.time() - start_time > timeout:
             raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")

 async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
     """Sends a single image to a caption server for processing."""
     MAX_RETRIES = 3
     for attempt in range(MAX_RETRIES):
         server = None
         try:
             server = await get_available_server()
             server.busy = True
             start_time = time.time()

             if attempt == 0:
                 print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")

             form_data = aiohttp.FormData()
             form_data.add_field('file',
                                 image_path.open('rb'),
@@ -239,21 +243,24 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
                                 content_type='image/jpeg')
             form_data.add_field('model_choice', MODEL_TYPE)

             async with aiohttp.ClientSession() as session:
                 async with session.post(server.url, data=form_data, timeout=600) as resp:
                     if resp.status == 200:
                         result = await resp.json()
                         caption = result.get("caption")

                         if caption:
-                            # Update progress counter
                             progress_tracker['completed'] += 1
-                            async with state_lock:
-                                state.current_file_progress = f"{progress_tracker['completed']}/{progress_tracker['total']}"
-
                             if progress_tracker['completed'] % 50 == 0:
                                 print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")

                             return {
                                 "course": course_name,
                                 "image_path": image_path.name,
@@ -262,18 +269,18 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
                             }
                         else:
                             print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
-                            continue
                     else:
                         error_text = await resp.text()
                         print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
-                        continue

         except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
             print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
-            continue
         except Exception as e:
             print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
-            continue
         finally:
             if server:
                 end_time = time.time()
@@ -284,406 +291,185 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
     print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
     return None

-async def
-    """
-    try:
-        print(f"[{FLOW_ID}] Downloading file: {repo_file_full_path}. Full name: {zip_full_name}")
-
-        # Use hf_hub_download to get the file path
-        zip_path = hf_hub_download(
-            repo_id=HF_DATASET_ID,
-            filename=repo_file_full_path,  # Use the full path in the repo
-            repo_type="dataset",
-            token=HF_TOKEN,
-        )
-
-        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")
-
-        # Create a temporary directory for extraction
-        extract_dir = TEMP_DIR / course_name / zip_full_name.replace('.', '_')
-        extract_dir.mkdir(parents=True, exist_ok=True)
-
-        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-            zip_ref.extractall(extract_dir)
-
-        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")
-
-        # Clean up the downloaded zip file
-        os.remove(zip_path)
-
-        # Return the extraction directory and the full zip file name
-        return extract_dir, zip_full_name
-
-    except Exception as e:
-        print(f"[{FLOW_ID}] Error downloading or extracting zip for {repo_file_full_path}: {e}")
-        return None
-
-async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
-    """Uploads the final captions JSON file to the output dataset."""
-    caption_filename = Path(zip_full_name).with_suffix('.json').name

-    print(f"[{FLOW_ID}]
-    json_content = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')
-
-    api = get_hf_api()
-    api.upload_file(
-        path_or_fileobj=io.BytesIO(json_content),
-        path_in_repo=caption_filename,
-        repo_id=HF_OUTPUT_DATASET_ID,
-        repo_type="dataset",
-        commit_message=f"[{FLOW_ID}] Captions for {zip_full_name}"
-    )
-
-    print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
-    return True
-
-    except Exception as e:
-        print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
         return False
-
-async def process_next_file_task():
-    """Continuous task to process files based on the current index."""
-    global state

-        if
            continue

-        # Check if we have files to process
-        if current_index >= state.total_files:
-            async with state_lock:
-                state.status = "Finished processing all files."
-                state.is_running = False
-                state.current_file = None
-                state.current_file_progress = "0/0"
-            print(f"[{FLOW_ID}] Reached end of file list. Stopping processing.")
-            await save_state_to_hf()
-            await asyncio.sleep(2)
-            continue
-
-        # Process the current file
-        repo_file_full_path = None
-        async with state_lock:
-            repo_file_full_path = state.all_zip_files[current_index]
-
-            if repo_file_full_path in state.processed_files:
-                state.current_index += 1
-                state.status = f"Skipping processed file: {Path(repo_file_full_path).name}"
-                state.current_file = Path(repo_file_full_path).name
-                print(f"[{FLOW_ID}] Skipping already processed file: {repo_file_full_path}")
-                await save_state_to_hf()
-                continue
-
-        # Mark the file as in-progress
-        state.status = f"Processing file {current_index + 1}/{state.total_files}"
-        state.current_file = Path(repo_file_full_path).name
-        state.current_file_progress = "0/0"
-        await save_state_to_hf()
-
-        # --- Process the file ---
         extract_dir = None

         try:
-            download_result = await download_and_extract_zip(repo_file_full_path)

-            if
                raise Exception("Failed to download or extract zip file.")
-
-            extract_dir, zip_full_name = download_result
-            course_name = zip_full_name.split('_')[0]

-            # Find
             image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
             print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

             if not image_paths:
                 print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
-
-                async with state_lock:
-                    state.processed_files.add(repo_file_full_path)
-                    state.current_index += 1
-                    state.current_file = None
-                    state.current_file_progress = "0/0"
-                    state.status = "Idle"
-                await save_state_to_hf()
             else:
-                #
                 progress_tracker = {
                     'total': len(image_paths),
                     'completed': 0
                 }
-
-                # Process images
                 semaphore = asyncio.Semaphore(len(servers))
                 async def limited_send_image_for_captioning(image_path, course_name, progress_tracker):
                     async with semaphore:
                         return await send_image_for_captioning(image_path, course_name, progress_tracker)

                 caption_tasks = [limited_send_image_for_captioning(p, course_name, progress_tracker) for p in image_paths]
                 results = await asyncio.gather(*caption_tasks)
                 all_captions = [r for r in results if r is not None]

-                # Final progress report

-                # Upload
-                if all_captions
                     if await upload_captions_to_hf(zip_full_name, all_captions):
-                        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}")
-                        #
-                        state.processed_files.add(repo_file_full_path)
-                        state.current_index += 1
-                        state.current_file = None
-                        state.current_file_progress = "0/0"
-                        state.status = "Idle"
-                        await save_state_to_hf()
                     else:
-                        print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}.
-
-                        async with state_lock:
-                            state.status = f"Upload failed for {zip_full_name}. Retrying later."
-                        await save_state_to_hf()
-                        await asyncio.sleep(30)  # Wait before retry
                 else:
-                    print(f"[{FLOW_ID}] No captions generated for {zip_full_name}.
-
-                    async with state_lock:
-                        state.status = f"No captions for {zip_full_name}. Retrying later."
-                    await save_state_to_hf()
-                    await asyncio.sleep(30)  # Wait before retry

         except Exception as e:
-            state.status = f"Error processing {Path(repo_file_full_path).name}: {error_message[:100]}..."
-            await save_state_to_hf()
-            await asyncio.sleep(30)  # Wait before retry

         finally:
-            # Cleanup
             if extract_dir and extract_dir.exists():

 # --- FastAPI App and Endpoints ---

 app = FastAPI(
     title=f"Flow Server {FLOW_ID} API",
-    description="
-    version="
 )

-# Setup Jinja2 templates for the UI
-templates = Jinja2Templates(directory="templates")
-
 @app.on_event("startup")
 async def startup_event():
-    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.

-    # Calculate server stats
-    server_stats = [
-        {
-            "url": s.url,
-            "busy": s.busy,
-            "processed": s.total_processed,
-            "fps": f"{s.fps:.2f}"
-        } for s in servers
-    ]
-
-    # Calculate overall FPS
-    total_processed = sum(s.total_processed for s in servers)
-    total_time = sum(s.total_time for s in servers)
-    overall_fps = total_processed / total_time if total_time > 0 else 0
-
-    context = {
-        "request": request,
-        "flow_id": FLOW_ID,
-        "status": state.status,
-        "is_running": state.is_running,
-        "total_files": state.total_files,
-        "processed_count": processed_count,
-        "remaining_count": remaining_count,
-        "current_index": state.current_index,
-        "current_file": state.current_file if state.current_file else "N/A",
-        "current_file_progress": state.current_file_progress,
-        "last_update": state.last_update,
-        "overall_fps": f"{overall_fps:.2f}",
-        "server_stats": server_stats
-    }
-    return templates.TemplateResponse("index.html", context)
-
-@app.post("/set_index")
-async def set_index(request: Request, background_tasks: BackgroundTasks):
-    """Endpoint to manually set the start index."""
-    global state

-    try:
-        new_index = int(form.get("start_index"))
-    except (TypeError, ValueError):
-        raise HTTPException(status_code=400, detail="Invalid index value.")
-
-    async with state_lock:
-        if 0 <= new_index < state.total_files:
-            state.current_index = new_index
-            state.status = f"Index set to {new_index}. Restarting processing."
-
-            # If the loop is not running, start it
-            if not state.is_running:
-                state.is_running = True
-                background_tasks.add_task(process_next_file_task)
-
-            await save_state_to_hf()
-            print(f"[{FLOW_ID}] Index manually set to {new_index}.")
-            return {"status": "success", "message": f"Start index set to {new_index}. Processing will resume from this point."}
-        elif new_index == state.total_files:
-            state.current_index = new_index
-            state.is_running = False
-            state.status = "Finished processing all files."
-            await save_state_to_hf()
-            return {"status": "success", "message": "Index set to end of list. Processing stopped."}
-        else:
-            raise HTTPException(status_code=400, detail=f"Index {new_index} is out of bounds (0 to {state.total_files}).")
-
-@app.post("/control_processing")
-async def control_processing(request: Request, background_tasks: BackgroundTasks):
-    """Endpoint to start/stop the processing loop."""
-    global state

-    if action == "start":
-        if not state.is_running:
-            # Reset state if we're at the end
-            if state.current_index >= state.total_files:
-                state.current_index = 0
-                state.status = "Reset to start and processing..."
-
-            state.is_running = True
-            state.status = "Processing started."
-
-            # Start the processing task
-            background_tasks.add_task(process_next_file_task)
-            await save_state_to_hf()
-
-            print(f"[{FLOW_ID}] Processing manually started from index {state.current_index}")
-            return {"status": "success", "message": "Processing loop started."}
-        else:
-            return {"status": "info", "message": "Processing is already running."}
-
-    elif action == "stop":
-        if state.is_running:
-            state.is_running = False
-            state.status = "Processing stopped by user."
-            await save_state_to_hf()
-
-            print(f"[{FLOW_ID}] Processing manually stopped")
-            return {"status": "success", "message": "Processing loop stopped."}
-        else:
-            return {"status": "info", "message": "Processing is already stopped."}
-    else:
-        raise HTTPException(status_code=400, detail="Invalid action.")
-
-@app.get("/status")
-async def get_status():
-    """API endpoint to get the current server status as JSON."""
-    async with state_lock:
-        processed_count = len(state.processed_files)
-
-    server_stats = [
-        {
-            "url": s.url,
-            "busy": s.busy,
-            "processed": s.total_processed,
-            "fps": f"{s.fps:.2f}"
-        } for s in servers
-    ]
-
-    total_processed = sum(s.total_processed for s in servers)
-    total_time = sum(s.total_time for s in servers)
-    overall_fps = total_processed / total_time if total_time > 0 else 0
-
-    return {
-        "flow_id": FLOW_ID,
-        "status": state.status,
-        "is_running": state.is_running,
-        "total_files": state.total_files,
-        "processed_count": processed_count,
-        "remaining_count": state.total_files - processed_count,
-        "current_index": state.current_index,
-        "current_file": state.current_file,
-        "current_file_progress": state.current_file_progress,
-        "last_update": state.last_update,
-        "overall_fps": f"{overall_fps:.2f}",
-        "server_stats": server_stats
-    }
-
-# The original /process_course endpoint is now obsolete as the server manages its own queue
-# @app.post("/process_course")
-# async def process_course(request: ProcessCourseRequest, background_tasks: BackgroundTasks):
-#     return {"status": "obsolete", "message": "The server now manages its own processing queue based on the index."}

 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)

New version (added lines are marked with +):
 import asyncio
 import aiohttp
 import zipfile
 import shutil
+from typing import Dict, List, Set, Optional, Tuple
 from urllib.parse import quote
 from datetime import datetime
 from pathlib import Path
+import io

+from fastapi import FastAPI, BackgroundTasks, HTTPException, status
 from pydantic import BaseModel, Field
+from huggingface_hub import HfApi, hf_hub_download

 # --- Configuration ---
+AUTO_START_INDEX = 20  # Hardcoded default start index if no progress is found
 FLOW_ID = os.getenv("FLOW_ID", "flow_default")
 FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
 HF_TOKEN = os.getenv("HF_TOKEN", "")
+HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")  # Source dataset for zip files
+HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium")  # Target dataset for captions
+
+# Progress Tracking File
+PROGRESS_FILE = Path("processing_progress.json")
+# Directory within the HF dataset where the zip files are located
+ZIP_FILE_PREFIX = "frames/"

 # Using the full list from the user's original code for actual deployment
 CAPTION_SERVERS = [
     "https://fred808-pil-4-1.hf.space/analyze",
     "https://fred808-pil-4-2.hf.space/analyze",
+    # ... (rest of the servers)
     "https://fredalone-fredalone-gahbxh.hf.space/analyze",
     "https://fredalone-fredalone-kw2po4.hf.space/analyze",
     "https://fredalone-fredalone-8h285h.hf.space/analyze"

 TEMP_DIR.mkdir(exist_ok=True)

 # --- Models ---
+class ProcessStartRequest(BaseModel):
+    start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the zip file to start processing from (1-indexed).")
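A note on the new request model: Field(AUTO_START_INDEX, ge=1, ...) makes 20 the default and lets Pydantic reject zero or negative indices before the handler ever runs, so FastAPI answers bad input with a 422 instead of starting a task. A minimal sketch of that validation behavior (illustration only, not part of this commit):

# Illustration: the ge=1 bound is enforced at parse time by Pydantic.
from pydantic import ValidationError

ProcessStartRequest()                    # ok, start_index defaults to AUTO_START_INDEX (20)
ProcessStartRequest(start_index=5)       # ok
try:
    ProcessStartRequest(start_index=0)   # fails the ge=1 constraint
except ValidationError as err:
    print(err)                           # FastAPI surfaces this as a 422 response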

 class CaptionServer:
     def __init__(self, url):

     def fps(self):
         return self.total_processed / self.total_time if self.total_time > 0 else 0

+# Global state for caption servers
 servers = [CaptionServer(url) for url in CAPTION_SERVERS]
 server_index = 0

+# --- Progress Tracking Functions ---

+def load_progress() -> Dict:
+    """Loads the processing progress from the JSON file."""
+    if PROGRESS_FILE.exists():
+        try:
+            with PROGRESS_FILE.open('r') as f:
+                return json.load(f)
+        except json.JSONDecodeError:
+            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted. Starting fresh.")
+            # Fall through to return default structure
+
+    # Default structure
+    return {
+        "last_processed_index": 0,
+        "processed_files": {},  # {index: repo_path}
+        "file_list": []  # Full list of all zip files found in the dataset
+    }
+
+def save_progress(progress_data: Dict):
+    """Saves the processing progress to the JSON file."""
+    try:
+        with PROGRESS_FILE.open('w') as f:
+            json.dump(progress_data, f, indent=4)
+    except Exception as e:
+        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress to {PROGRESS_FILE}: {e}")

+# --- Hugging Face Utility Functions ---

+async def get_zip_file_list(progress_data: Dict) -> List[str]:
+    """
+    Fetches the list of all zip files from the dataset, or uses the cached list.
+    Updates the progress_data with the file list if a new list is fetched.
+    """
+    if progress_data['file_list']:
+        print(f"[{FLOW_ID}] Using cached file list with {len(progress_data['file_list'])} files.")
+        return progress_data['file_list']

+    print(f"[{FLOW_ID}] Fetching full list of zip files from {HF_DATASET_ID}...")
+    try:
+        api = HfApi(token=HF_TOKEN)
+        repo_files = api.list_repo_files(
+            repo_id=HF_DATASET_ID,
+            repo_type="dataset"
+        )

+        # Filter for zip files in the specified directory and sort them alphabetically for consistent indexing
+        zip_files = sorted([
+            f for f in repo_files
+            if f.startswith(ZIP_FILE_PREFIX) and f.endswith('.zip')
+        ])

+        if not zip_files:
+            raise FileNotFoundError(f"No zip files found in '{ZIP_FILE_PREFIX}' directory of dataset '{HF_DATASET_ID}'.")
+
+        print(f"[{FLOW_ID}] Found {len(zip_files)} zip files.")
+
+        # Update and save the progress data
+        progress_data['file_list'] = zip_files
+        save_progress(progress_data)
+
+        return zip_files
+
+    except Exception as e:
+        print(f"[{FLOW_ID}] Error fetching file list from Hugging Face: {e}")
+        return []

+async def download_and_extract_zip_by_index(file_index: int, repo_file_full_path: str) -> Optional[Path]:
+    """Downloads the zip file for the given index and extracts its contents."""

+    # Extract the base name for the extraction directory
+    zip_full_name = Path(repo_file_full_path).name
+    course_name = zip_full_name.replace('.zip', '')  # Use the file name as the course/job name
+
+    print(f"[{FLOW_ID}] Processing file #{file_index}: {repo_file_full_path}. Full name: {zip_full_name}")
+
+    try:
+        # Use hf_hub_download to get the file path
+        zip_path = hf_hub_download(
+            repo_id=HF_DATASET_ID,
+            filename=repo_file_full_path,  # Use the full path in the repo
+            repo_type="dataset",
+            token=HF_TOKEN,
+        )
+
+        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")
+
+        # Create a temporary directory for extraction
+        extract_dir = TEMP_DIR / course_name
+        # Ensure a clean directory for extraction
+        if extract_dir.exists():
+            shutil.rmtree(extract_dir)
+        extract_dir.mkdir(exist_ok=True)
+
+        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extractall(extract_dir)

+        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")
+
+        # Clean up the downloaded zip file to save space
+        os.remove(zip_path)
+
+        return extract_dir
+
+    except Exception as e:
+        print(f"[{FLOW_ID}] Error downloading or extracting zip for {repo_file_full_path}: {e}")
+        return None
+
+async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
+    """Uploads the final captions JSON file to the output dataset."""
+    # Use the full zip name, replacing the extension with .json
+    caption_filename = Path(zip_full_name).with_suffix('.json').name
+
+    try:
+        print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")
+
+        # Create JSON content in memory
+        json_content = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')
+
+        api = HfApi(token=HF_TOKEN)
+        api.upload_file(
+            path_or_fileobj=io.BytesIO(json_content),
+            path_in_repo=caption_filename,
+            repo_id=HF_OUTPUT_DATASET_ID,
+            repo_type="dataset",
+            commit_message=f"[{FLOW_ID}] Captions for {zip_full_name}"
+        )
+
+        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
+        return True
+
+    except Exception as e:
+        print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
+        return False

 # --- Core Processing Functions (Modified) ---

     global server_index
     start_time = time.time()
     while True:
+        # Round-robin check for an available server
         for _ in range(len(servers)):
             server = servers[server_index]
             server_index = (server_index + 1) % len(servers)
             if not server.busy:
                 return server

+        # If all servers are busy, wait for a short period and check again
         await asyncio.sleep(0.5)

+        # Check if timeout has been reached
         if time.time() - start_time > timeout:
             raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")

 async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
     """Sends a single image to a caption server for processing."""
+    # This function now handles server selection and retries internally
     MAX_RETRIES = 3
     for attempt in range(MAX_RETRIES):
         server = None
         try:
+            # 1. Get an available server (will wait if all are busy, with a timeout)
             server = await get_available_server()
             server.busy = True
             start_time = time.time()

+            # Print a less verbose message only on the first attempt
             if attempt == 0:
                 print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")

+            # 2. Prepare request data
             form_data = aiohttp.FormData()
             form_data.add_field('file',
                                 image_path.open('rb'),

                                 content_type='image/jpeg')
             form_data.add_field('model_choice', MODEL_TYPE)

+            # 3. Send request
             async with aiohttp.ClientSession() as session:
+                # Increased timeout to 10 minutes (600s) as requested by user's problem description
                 async with session.post(server.url, data=form_data, timeout=600) as resp:
                     if resp.status == 200:
                         result = await resp.json()
                         caption = result.get("caption")

                         if caption:
+                            # Update progress counter
                             progress_tracker['completed'] += 1
                             if progress_tracker['completed'] % 50 == 0:
                                 print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")

+                            # Log success only if it's not a progress report interval
+                            if progress_tracker['completed'] % 50 != 0:
+                                print(f"[{FLOW_ID}] Success: {image_path.name} captioned by {server.url}")
+
                             return {
                                 "course": course_name,
                                 "image_path": image_path.name,

                             }
                         else:
                             print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
+                            continue  # Retry with a different server
                     else:
                         error_text = await resp.text()
                         print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
+                        continue  # Retry with a different server

         except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
             print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
+            continue  # Retry with a different server
         except Exception as e:
             print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
+            continue  # Retry with a different server
         finally:
             if server:
                 end_time = time.time()

     print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
     return None

+async def process_dataset_task(start_index: int):
+    """Main task to process the dataset sequentially starting from a given index."""

+    progress = load_progress()
+    file_list = await get_zip_file_list(progress)

+    if not file_list:
+        print(f"[{FLOW_ID}] ERROR: Cannot proceed. File list is empty.")
         return False

+    # Ensure start_index is within bounds
+    if start_index > len(file_list):
+        print(f"[{FLOW_ID}] WARNING: Start index {start_index} is greater than the total number of files ({len(file_list)}). Exiting.")
+        return True

+    # Determine the actual starting index in the 0-indexed list
+    start_list_index = start_index - 1
+
+    print(f"[{FLOW_ID}] Starting dataset processing from file index: {start_index} out of {len(file_list)}.")
+
+    global_success = True
+
+    for i in range(start_list_index, len(file_list)):
+        file_index = i + 1  # 1-indexed for user display and progress tracking
+        repo_file_full_path = file_list[i]
+        zip_full_name = Path(repo_file_full_path).name
+        course_name = zip_full_name.replace('.zip', '')  # Use the file name as the course/job name

+        # Check if the file has already been successfully processed
+        if str(file_index) in progress['processed_files']:
+            print(f"[{FLOW_ID}] Skipping file #{file_index} ({zip_full_name}): Already processed according to progress file.")
+            progress['last_processed_index'] = file_index
+            save_progress(progress)
             continue

         extract_dir = None
+        current_file_success = False

         try:
+            # 1. Download and Extract
+            extract_dir = await download_and_extract_zip_by_index(file_index, repo_file_full_path)

+            if not extract_dir:
                 raise Exception("Failed to download or extract zip file.")

+            # 2. Find Images
+            # Use recursive glob to find images in subdirectories
             image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
             print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

             if not image_paths:
                 print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
+                current_file_success = True
             else:
+                # 3. Process Images (Captioning)
                 progress_tracker = {
                     'total': len(image_paths),
                     'completed': 0
                 }
+                print(f"[{FLOW_ID}] Starting captioning for {progress_tracker['total']} images in {zip_full_name}...")
+
+                # Create a semaphore to limit concurrent tasks to the number of available servers
                 semaphore = asyncio.Semaphore(len(servers))
+
                 async def limited_send_image_for_captioning(image_path, course_name, progress_tracker):
                     async with semaphore:
                         return await send_image_for_captioning(image_path, course_name, progress_tracker)

+                # Create a list of tasks for parallel captioning
                 caption_tasks = [limited_send_image_for_captioning(p, course_name, progress_tracker) for p in image_paths]
+
+                # Run all captioning tasks concurrently
                 results = await asyncio.gather(*caption_tasks)
+
+                # Filter out failed results
                 all_captions = [r for r in results if r is not None]

+                # Final progress report for the current file
+                if len(all_captions) == len(image_paths):
+                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully completed all {len(all_captions)} captions.")
+                    current_file_success = True
+                else:
+                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} captions. Marking as partial failure.")
+                    current_file_success = False

+                # 4. Upload Results
+                if all_captions:
+                    print(f"[{FLOW_ID}] Uploading {len(all_captions)} captions for {zip_full_name}...")
                     if await upload_captions_to_hf(zip_full_name, all_captions):
+                        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
+                        # Partial success in captioning is still a success for the upload step
+                        pass
                     else:
+                        print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}.")
+                        current_file_success = False
                 else:
+                    print(f"[{FLOW_ID}] No captions generated. Skipping upload for {zip_full_name}.")
+                    current_file_success = False

         except Exception as e:
+            print(f"[{FLOW_ID}] Critical error in process_dataset_task for file #{file_index} ({zip_full_name}): {e}")
+            current_file_success = False
+            global_success = False  # Mark overall task as failed if any file fails critically

         finally:
+            # 5. Cleanup and Update Progress
             if extract_dir and extract_dir.exists():
+                shutil.rmtree(extract_dir, ignore_errors=True)
+                print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")
+
+            if current_file_success:
+                # Update progress only on successful completion of the file
+                progress['last_processed_index'] = file_index
+                progress['processed_files'][str(file_index)] = repo_file_full_path
+                save_progress(progress)
+                print(f"[{FLOW_ID}] Progress saved: File #{file_index} marked as processed.")
+            else:
+                # If a file fails, we stop the continuous loop to allow for manual intervention or a fresh start
+                print(f"[{FLOW_ID}] File #{file_index} failed. Stopping continuous processing.")
+                global_success = False
+                break
+
+    print(f"[{FLOW_ID}] All processing loops complete. Overall success: {global_success}")
+    return global_success
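The semaphore-plus-gather combination above is the core of the fan-out: every image gets its own task immediately, but at most len(servers) of them hold a request slot at any moment. A standalone sketch of the same pattern (dummy worker in place of send_image_for_captioning, illustration only):

# Illustration: bounded concurrency with asyncio.Semaphore + gather.
import asyncio

async def caption(name: str) -> str:          # stand-in for send_image_for_captioning
    await asyncio.sleep(0.1)                  # simulate one HTTP round trip
    return f"caption for {name}"

async def main():
    sem = asyncio.Semaphore(3)                # e.g. 3 available caption servers

    async def limited(name):
        async with sem:                       # at most 3 coroutines past this point
            return await caption(name)

    results = await asyncio.gather(*(limited(f"img_{i}.jpg") for i in range(10)))
    print(len(results))                       # 10, completed in ~4 waves of 3

asyncio.run(main())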

 # --- FastAPI App and Endpoints ---

 app = FastAPI(
     title=f"Flow Server {FLOW_ID} API",
+    description="Sequentially processes zip files from a dataset, captions images, and tracks progress.",
+    version="1.0.0"
 )

 @app.on_event("startup")
 async def startup_event():
+    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}.")

+    # Automatically start the processing task
+    progress = load_progress()
+    # Start from the last processed index + 1, or the hardcoded AUTO_START_INDEX if the progress file is new/empty
+    start_index = progress.get('last_processed_index', 0) + 1
+    if start_index < AUTO_START_INDEX:
+        start_index = AUTO_START_INDEX
+
+    # Note: FastAPI's startup events can't directly use BackgroundTasks, but we can use asyncio.create_task
+    # to run the long-running process in the background without blocking the server startup.
+    print(f"[{FLOW_ID}] Auto-starting processing from index: {start_index}...")
+    asyncio.create_task(process_dataset_task(start_index))
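One caveat worth knowing with this pattern: the asyncio event loop only keeps a weak reference to tasks created via create_task, so a fire-and-forget task can in principle be garbage-collected mid-flight (this is documented asyncio behavior). A common hardening, sketched here and not part of this commit, is to hold the handle somewhere:

# Sketch: keep a strong reference so the background task cannot be garbage-collected.
background_task_refs = set()

def spawn(coro):
    task = asyncio.create_task(coro)
    background_task_refs.add(task)                      # strong reference while running
    task.add_done_callback(background_task_refs.discard)
    return task

# Inside startup_event, one would then call: spawn(process_dataset_task(start_index))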

+@app.get("/")
+async def root():
+    progress = load_progress()
+    return {
+        "flow_id": FLOW_ID,
+        "status": "ready",
+        "last_processed_index": progress['last_processed_index'],
+        "total_files_in_list": len(progress['file_list']),
+        "processed_files_count": len(progress['processed_files']),
+        "total_servers": len(servers),
+        "busy_servers": sum(1 for s in servers if s.busy),
+    }
+
+@app.post("/start_processing")
+async def start_processing(request: ProcessStartRequest, background_tasks: BackgroundTasks):
+    """
+    Starts the sequential processing of zip files from the given index in the background.
+    """
+    start_index = request.start_index

+    print(f"[{FLOW_ID}] Received request to start processing from index: {start_index}. Starting background task.")

+    # Start the heavy processing in a background task so the API call returns immediately
+    # Note: The server is already auto-starting, but this allows for manual restart/override.
+    background_tasks.add_task(process_dataset_task, start_index)

+    return {"status": "processing", "start_index": start_index, "message": "Dataset processing started in background."}

 if __name__ == "__main__":
+    import uvicorn
+    # Note: When running in the sandbox, we need to use 0.0.0.0 to expose the port.
     uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
|