Update app.py
Browse files
app.py
CHANGED
|
@@ -4,14 +4,14 @@ import time
|
|
| 4 |
import asyncio
|
| 5 |
import aiohttp
|
| 6 |
import zipfile
|
|
|
|
| 7 |
import shutil
|
| 8 |
from typing import Dict, List, Set, Optional, Any
|
| 9 |
from urllib.parse import quote
|
| 10 |
from datetime import datetime
|
| 11 |
from pathlib import Path
|
| 12 |
-
import io
|
| 13 |
|
| 14 |
-
from fastapi import FastAPI, BackgroundTasks, HTTPException, status, Request
|
| 15 |
from fastapi.responses import HTMLResponse
|
| 16 |
from fastapi.templating import Jinja2Templates
|
| 17 |
from pydantic import BaseModel, Field
|
|
@@ -19,21 +19,17 @@ from huggingface_hub import HfApi, hf_hub_download, HfFileSystem
|
|
| 19 |
import uvicorn
|
| 20 |
|
| 21 |
# --- Configuration ---
|
| 22 |
-
# Flow Server ID and Port will be set via environment variables for easy deployment
|
| 23 |
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
|
| 24 |
-
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
|
| 25 |
-
|
| 26 |
-
# Manager Server Configuration
|
| 27 |
MANAGER_URL = os.getenv("MANAGER_URL", "https://fred808-fcord.hf.space")
|
| 28 |
MANAGER_COMPLETE_TASK_URL = f"{MANAGER_URL}/task/complete"
|
| 29 |
-
|
| 30 |
-
# Hugging Face Configuration
|
| 31 |
-
HF_TOKEN = os.getenv("HF_TOKEN", "") # User provided token
|
| 32 |
HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")
|
| 33 |
-
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium")
|
|
|
|
| 34 |
|
| 35 |
-
# Using the full list from the user's original code
|
| 36 |
-
|
| 37 |
"https://fred808-pil-4-1.hf.space/analyze",
|
| 38 |
"https://fred808-pil-4-2.hf.space/analyze",
|
| 39 |
"https://fred808-pil-4-3.hf.space/analyze",
|
|
@@ -78,126 +74,156 @@ MODEL_TYPE = "Florence-2-large"
|
|
| 78 |
TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
|
| 79 |
TEMP_DIR.mkdir(exist_ok=True)
|
| 80 |
|
| 81 |
-
# State persistence file name in the output dataset
|
| 82 |
-
STATE_FILENAME = f"processing_state_{FLOW_ID}.json"
|
| 83 |
-
|
| 84 |
# --- Models ---
|
| 85 |
class ProcessCourseRequest(BaseModel):
|
| 86 |
course_name: Optional[str] = None
|
| 87 |
-
start_index: int = 0 # New field for configurable start index
|
| 88 |
|
| 89 |
-
class CaptionServer
|
| 90 |
-
url:
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
| 95 |
|
| 96 |
@property
|
| 97 |
def fps(self):
|
| 98 |
return self.total_processed / self.total_time if self.total_time > 0 else 0
|
| 99 |
|
| 100 |
-
class
|
| 101 |
-
#
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
server_index = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
async def load_state_from_hf():
|
| 120 |
-
"""
|
| 121 |
global state
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
print(f"[{FLOW_ID}] State loaded successfully. Processed files: {len(state.processed_files)}")
|
| 142 |
-
return True
|
| 143 |
-
else:
|
| 144 |
-
print(f"[{FLOW_ID}] State file not found. Initializing with default servers.")
|
| 145 |
-
state.servers = [CaptionServer(url=url) for url in INITIAL_CAPTION_SERVERS]
|
| 146 |
-
return False
|
| 147 |
-
except Exception as e:
|
| 148 |
-
print(f"[{FLOW_ID}] Error loading state: {e}. Initializing with default servers.")
|
| 149 |
-
state.servers = [CaptionServer(url=url) for url in INITIAL_CAPTION_SERVERS]
|
| 150 |
-
return False
|
| 151 |
|
| 152 |
async def save_state_to_hf():
|
| 153 |
-
"""Saves the current
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
async with state_lock:
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
try:
|
| 157 |
-
|
| 158 |
-
data_to_save = state.dict()
|
| 159 |
-
data_to_save['processed_files'] = list(state.processed_files) # Convert set to list for JSON
|
| 160 |
-
|
| 161 |
-
json_content = json.dumps(data_to_save, indent=2, ensure_ascii=False).encode('utf-8')
|
| 162 |
-
|
| 163 |
-
api = HfApi(token=HF_TOKEN)
|
| 164 |
api.upload_file(
|
| 165 |
path_or_fileobj=io.BytesIO(json_content),
|
| 166 |
-
path_in_repo=
|
| 167 |
repo_id=HF_OUTPUT_DATASET_ID,
|
| 168 |
repo_type="dataset",
|
| 169 |
-
commit_message=f"[{FLOW_ID}] Update
|
| 170 |
)
|
| 171 |
print(f"[{FLOW_ID}] State saved successfully.")
|
| 172 |
return True
|
| 173 |
except Exception as e:
|
| 174 |
-
print(f"[{FLOW_ID}] Error saving state: {e}")
|
| 175 |
return False
|
| 176 |
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
async def get_available_server(timeout: float = 300.0) -> CaptionServer:
|
| 180 |
-
"""Round-robin selection of an available caption server
|
| 181 |
global server_index
|
| 182 |
start_time = time.time()
|
| 183 |
-
|
| 184 |
while True:
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
# Round-robin check for an available server
|
| 191 |
-
for _ in range(len(state.servers)):
|
| 192 |
-
server = state.servers[server_index % len(state.servers)]
|
| 193 |
-
server_index = (server_index + 1) % len(state.servers)
|
| 194 |
-
if not server.busy:
|
| 195 |
-
return server
|
| 196 |
|
| 197 |
-
# If all servers are busy, wait for a short period and check again
|
| 198 |
await asyncio.sleep(0.5)
|
| 199 |
|
| 200 |
-
# Check if timeout has been reached
|
| 201 |
if time.time() - start_time > timeout:
|
| 202 |
raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")
|
| 203 |
|
|
@@ -207,21 +233,13 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress
|
|
| 207 |
for attempt in range(MAX_RETRIES):
|
| 208 |
server = None
|
| 209 |
try:
|
| 210 |
-
# 1. Get an available server (will wait if all are busy, with a timeout)
|
| 211 |
server = await get_available_server()
|
| 212 |
-
|
| 213 |
-
async with state_lock:
|
| 214 |
-
# Find the server in the global list and mark it busy
|
| 215 |
-
server_in_state = next(s for s in state.servers if s.url == server.url)
|
| 216 |
-
server_in_state.busy = True
|
| 217 |
-
|
| 218 |
start_time = time.time()
|
| 219 |
|
| 220 |
-
# Print a less verbose message only on the first attempt
|
| 221 |
if attempt == 0:
|
| 222 |
print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")
|
| 223 |
|
| 224 |
-
# 2. Prepare request data
|
| 225 |
form_data = aiohttp.FormData()
|
| 226 |
form_data.add_field('file',
|
| 227 |
image_path.open('rb'),
|
|
@@ -229,24 +247,21 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress
|
|
| 229 |
content_type='image/jpeg')
|
| 230 |
form_data.add_field('model_choice', MODEL_TYPE)
|
| 231 |
|
| 232 |
-
# 3. Send request
|
| 233 |
async with aiohttp.ClientSession() as session:
|
| 234 |
-
# Increased timeout to 10 minutes (600s)
|
| 235 |
async with session.post(server.url, data=form_data, timeout=600) as resp:
|
| 236 |
if resp.status == 200:
|
| 237 |
result = await resp.json()
|
| 238 |
caption = result.get("caption")
|
| 239 |
|
| 240 |
if caption:
|
| 241 |
-
# Update progress counter
|
| 242 |
progress_tracker['completed'] += 1
|
|
|
|
|
|
|
|
|
|
| 243 |
if progress_tracker['completed'] % 50 == 0:
|
| 244 |
print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")
|
| 245 |
|
| 246 |
-
# Log success only if it's not a progress report interval
|
| 247 |
-
if progress_tracker['completed'] % 50 != 0:
|
| 248 |
-
print(f"[{FLOW_ID}] Success: {image_path.name} captioned by {server.url}")
|
| 249 |
-
|
| 250 |
return {
|
| 251 |
"course": course_name,
|
| 252 |
"image_path": image_path.name,
|
|
@@ -255,51 +270,76 @@ async def send_image_for_captioning(image_path: Path, course_name: str, progress
|
|
| 255 |
}
|
| 256 |
else:
|
| 257 |
print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
|
| 258 |
-
continue
|
| 259 |
else:
|
| 260 |
error_text = await resp.text()
|
| 261 |
print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
|
| 262 |
-
continue
|
| 263 |
|
| 264 |
-
except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
continue # Retry with a different server
|
| 268 |
except Exception as e:
|
| 269 |
print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
|
| 270 |
-
continue
|
| 271 |
finally:
|
| 272 |
if server:
|
| 273 |
end_time = time.time()
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
server_in_state = next(s for s in state.servers if s.url == server.url)
|
| 278 |
-
server_in_state.busy = False
|
| 279 |
-
server_in_state.total_processed += 1
|
| 280 |
-
server_in_state.total_time += (end_time - start_time)
|
| 281 |
-
except StopIteration:
|
| 282 |
-
# Server might have been removed while processing
|
| 283 |
-
print(f"[{FLOW_ID}] Warning: Completed task on a server that was likely removed: {server.url}")
|
| 284 |
|
| 285 |
print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
|
| 286 |
return None
|
| 287 |
|
| 288 |
-
async def
|
| 289 |
-
"""
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
caption_filename = Path(zip_full_name).with_suffix('.json').name
|
| 295 |
|
| 296 |
try:
|
| 297 |
print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")
|
| 298 |
|
| 299 |
-
# Create JSON content in memory
|
| 300 |
json_content = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')
|
| 301 |
|
| 302 |
-
api =
|
| 303 |
api.upload_file(
|
| 304 |
path_or_fileobj=io.BytesIO(json_content),
|
| 305 |
path_in_repo=caption_filename,
|
|
@@ -315,225 +355,147 @@ async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> boo
|
|
| 315 |
print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
|
| 316 |
return False
|
| 317 |
|
| 318 |
-
async def
|
| 319 |
-
"""
|
| 320 |
-
|
| 321 |
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
# List all files in the frames directory
|
| 326 |
-
repo_files = api.list_repo_files(
|
| 327 |
-
repo_id=HF_DATASET_ID,
|
| 328 |
-
repo_type="dataset"
|
| 329 |
-
)
|
| 330 |
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
if f.startswith(f"frames/{course_name}") and f.endswith('.zip')
|
| 335 |
-
]
|
| 336 |
|
| 337 |
-
if not matching_files:
|
| 338 |
-
print(f"[{FLOW_ID}] No zip files found starting with '{course_name}' in frames/ directory.")
|
| 339 |
-
return None, None, None
|
| 340 |
-
|
| 341 |
async with state_lock:
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")
|
| 364 |
-
|
| 365 |
-
# Create a temporary directory for extraction
|
| 366 |
-
extract_dir = TEMP_DIR / course_name
|
| 367 |
-
extract_dir.mkdir(exist_ok=True)
|
| 368 |
-
|
| 369 |
-
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
| 370 |
-
zip_ref.extractall(extract_dir)
|
| 371 |
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
# Return the extraction directory, the full zip file name, and the repo path
|
| 375 |
-
return extract_dir, zip_full_name, repo_file_full_path
|
| 376 |
-
|
| 377 |
-
except Exception as e:
|
| 378 |
-
print(f"[{FLOW_ID}] Error downloading or extracting zip for {course_name}: {e}")
|
| 379 |
-
return None, None, None
|
| 380 |
-
|
| 381 |
-
async def process_course_task(course_name: str, start_index: int = 0):
|
| 382 |
-
"""Main task to process a single course, looping until all files are processed."""
|
| 383 |
-
print(f"[{FLOW_ID}] Starting continuous processing for course: {course_name} with start index {start_index}")
|
| 384 |
-
|
| 385 |
-
global_success = True
|
| 386 |
-
|
| 387 |
-
# Update state before starting the loop
|
| 388 |
-
async with state_lock:
|
| 389 |
-
state.last_processed_course = course_name
|
| 390 |
-
state.last_processed_index = start_index
|
| 391 |
-
await save_state_to_hf()
|
| 392 |
-
|
| 393 |
-
# Loop to continuously check for new files matching the course_name prefix
|
| 394 |
-
while True:
|
| 395 |
extract_dir = None
|
| 396 |
zip_full_name = None
|
| 397 |
-
|
| 398 |
|
| 399 |
try:
|
| 400 |
-
|
| 401 |
-
download_result = await download_and_extract_zip(course_name)
|
| 402 |
|
| 403 |
-
if download_result is None
|
| 404 |
-
|
| 405 |
-
print(f"[{FLOW_ID}] No new files found for {course_name}. Exiting loop.")
|
| 406 |
-
break
|
| 407 |
|
| 408 |
-
extract_dir, zip_full_name
|
| 409 |
-
|
| 410 |
-
# --- Start Processing the single file ---
|
| 411 |
|
| 412 |
-
#
|
| 413 |
image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
|
| 414 |
-
|
| 415 |
-
# Apply start_index logic
|
| 416 |
-
if start_index > 0:
|
| 417 |
-
original_count = len(image_paths)
|
| 418 |
-
image_paths = image_paths[start_index:]
|
| 419 |
-
print(f"[{FLOW_ID}] Applying start index {start_index}. Processing {len(image_paths)} images from {original_count} in {zip_full_name}.")
|
| 420 |
-
# Reset start_index for subsequent zip files
|
| 421 |
-
start_index = 0
|
| 422 |
-
else:
|
| 423 |
-
print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")
|
| 424 |
-
|
| 425 |
-
current_file_success = False
|
| 426 |
|
| 427 |
if not image_paths:
|
| 428 |
-
print(f"[{FLOW_ID}] No images
|
| 429 |
-
|
| 430 |
else:
|
| 431 |
# Initialize progress tracker
|
| 432 |
progress_tracker = {
|
| 433 |
'total': len(image_paths),
|
| 434 |
'completed': 0
|
| 435 |
}
|
| 436 |
-
print(f"[{FLOW_ID}] Starting captioning for {progress_tracker['total']} images in {zip_full_name}...")
|
| 437 |
-
|
| 438 |
-
# Create a semaphore to limit concurrent tasks to the number of available servers
|
| 439 |
async with state_lock:
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
|
|
|
| 443 |
async def limited_send_image_for_captioning(image_path, course_name, progress_tracker):
|
| 444 |
async with semaphore:
|
| 445 |
return await send_image_for_captioning(image_path, course_name, progress_tracker)
|
| 446 |
|
| 447 |
-
# Create a list of tasks for parallel captioning
|
| 448 |
caption_tasks = [limited_send_image_for_captioning(p, course_name, progress_tracker) for p in image_paths]
|
| 449 |
-
|
| 450 |
-
# Run all captioning tasks concurrently
|
| 451 |
results = await asyncio.gather(*caption_tasks)
|
| 452 |
-
|
| 453 |
-
# Filter out failed results
|
| 454 |
all_captions = [r for r in results if r is not None]
|
| 455 |
|
| 456 |
-
# Final progress report
|
| 457 |
if len(all_captions) == len(image_paths):
|
| 458 |
print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully completed all {len(all_captions)} captions.")
|
| 459 |
-
|
| 460 |
else:
|
| 461 |
print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} captions.")
|
| 462 |
-
|
| 463 |
|
| 464 |
# Upload results
|
| 465 |
if all_captions and zip_full_name:
|
| 466 |
-
# Use the full zip file name for the upload as requested
|
| 467 |
-
print(f"[{FLOW_ID}] Uploading {len(all_captions)} captions for {zip_full_name}...")
|
| 468 |
if await upload_captions_to_hf(zip_full_name, all_captions):
|
| 469 |
print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
|
| 470 |
-
# If
|
| 471 |
-
|
| 472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
else:
|
| 474 |
-
print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}.")
|
| 475 |
-
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
else:
|
| 478 |
-
print(f"[{FLOW_ID}] No captions generated or zip_full_name is missing. Skipping upload for {zip_full_name}.")
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
# Mark the file as processed and save state
|
| 485 |
-
if current_file_success:
|
| 486 |
-
async with state_lock:
|
| 487 |
-
state.processed_files.add(repo_file_full_path)
|
| 488 |
-
await save_state_to_hf()
|
| 489 |
|
| 490 |
except Exception as e:
|
| 491 |
error_message = str(e)
|
| 492 |
-
print(f"[{FLOW_ID}] Critical error in
|
| 493 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
|
| 495 |
finally:
|
| 496 |
-
# Cleanup temporary files
|
| 497 |
if extract_dir and extract_dir.exists():
|
| 498 |
print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")
|
| 499 |
shutil.rmtree(extract_dir, ignore_errors=True)
|
| 500 |
-
|
| 501 |
-
# If
|
| 502 |
-
if
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
# --- Final Report after the loop is complete ---
|
| 506 |
-
print(f"[{FLOW_ID}] All processing loops complete for {course_name}.")
|
| 507 |
-
|
| 508 |
-
# Report completion to manager
|
| 509 |
-
final_error_message = error_message if not global_success else None
|
| 510 |
-
await report_completion(course_name, global_success, final_error_message)
|
| 511 |
-
|
| 512 |
-
return global_success
|
| 513 |
-
|
| 514 |
-
async def report_completion(course_name: str, success: bool, error_message: Optional[str] = None):
|
| 515 |
-
"""Reports the task result back to the Manager Server."""
|
| 516 |
-
print(f"[{FLOW_ID}] Reporting completion for {course_name} (Success: {success})...")
|
| 517 |
-
|
| 518 |
-
payload = {
|
| 519 |
-
"flow_id": FLOW_ID,
|
| 520 |
-
"course_name": course_name,
|
| 521 |
-
"success": success,
|
| 522 |
-
"error_message": error_message
|
| 523 |
-
}
|
| 524 |
-
|
| 525 |
-
try:
|
| 526 |
-
async with aiohttp.ClientSession() as session:
|
| 527 |
-
async with session.post(MANAGER_COMPLETE_TASK_URL, json=payload) as resp:
|
| 528 |
-
if resp.status != 200:
|
| 529 |
-
print(f"[{FLOW_ID}] ERROR: Manager reported non-200 status: {resp.status} - {await resp.text()}")
|
| 530 |
-
else:
|
| 531 |
-
print(f"[{FLOW_ID}] Successfully reported completion to Manager.")
|
| 532 |
-
|
| 533 |
-
except aiohttp.ClientError as e:
|
| 534 |
-
print(f"[{FLOW_ID}] CRITICAL ERROR: Could not connect to Manager at {MANAGER_COMPLETE_TASK_URL}. Task completion not reported. Error: {e}")
|
| 535 |
-
except Exception as e:
|
| 536 |
-
print(f"[{FLOW_ID}] Unexpected error during reporting: {e}")
|
| 537 |
|
| 538 |
# --- FastAPI App and Endpoints ---
|
| 539 |
|
|
@@ -543,78 +505,167 @@ app = FastAPI(
|
|
| 543 |
version="2.0.0"
|
| 544 |
)
|
| 545 |
|
|
|
|
|
|
|
|
|
|
| 546 |
@app.on_event("startup")
|
| 547 |
async def startup_event():
|
| 548 |
-
print(f"Flow Server {FLOW_ID} starting up...")
|
| 549 |
-
# Load state first before starting the server
|
| 550 |
-
await load_state_from_hf()
|
| 551 |
print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}. Manager URL: {MANAGER_URL}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
|
| 553 |
@app.get("/", response_class=HTMLResponse)
|
| 554 |
-
async def
|
| 555 |
-
"""
|
| 556 |
async with state_lock:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
context = {
|
| 558 |
"request": request,
|
| 559 |
"flow_id": FLOW_ID,
|
| 560 |
-
"status":
|
| 561 |
-
"
|
| 562 |
-
"
|
| 563 |
-
"
|
| 564 |
-
"
|
| 565 |
-
"
|
| 566 |
-
"
|
| 567 |
-
"
|
|
|
|
|
|
|
|
|
|
| 568 |
}
|
| 569 |
-
return templates.TemplateResponse("
|
| 570 |
|
| 571 |
-
@app.post("/
|
| 572 |
-
async def
|
| 573 |
-
"""
|
| 574 |
-
|
| 575 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
|
| 577 |
async with state_lock:
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
-
@app.post("/
|
| 589 |
-
async def
|
| 590 |
-
"""
|
| 591 |
-
|
| 592 |
-
initial_count = len(state.servers)
|
| 593 |
-
state.servers = [s for s in state.servers if s.url != server_url]
|
| 594 |
-
if len(state.servers) == initial_count:
|
| 595 |
-
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Server not found.")
|
| 596 |
|
| 597 |
-
await
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
@app.post("/process_course")
|
| 601 |
-
async def process_course_api(request: ProcessCourseRequest, background_tasks: BackgroundTasks):
|
| 602 |
-
"""
|
| 603 |
-
Receives a course name and optional start index and starts processing in the background.
|
| 604 |
-
"""
|
| 605 |
-
course_name = request.course_name
|
| 606 |
-
start_index = request.start_index
|
| 607 |
|
| 608 |
-
|
| 609 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
|
| 618 |
if __name__ == "__main__":
|
| 619 |
-
# Note: When running in the sandbox, we need to use 0.0.0.0 to expose the port.
|
| 620 |
uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
|
|
|
|
| 4 |
import asyncio
|
| 5 |
import aiohttp
|
| 6 |
import zipfile
|
| 7 |
+
import io
|
| 8 |
import shutil
|
| 9 |
from typing import Dict, List, Set, Optional, Any
|
| 10 |
from urllib.parse import quote
|
| 11 |
from datetime import datetime
|
| 12 |
from pathlib import Path
|
|
|
|
| 13 |
|
| 14 |
+
from fastapi import FastAPI, BackgroundTasks, HTTPException, status, Request
|
| 15 |
from fastapi.responses import HTMLResponse
|
| 16 |
from fastapi.templating import Jinja2Templates
|
| 17 |
from pydantic import BaseModel, Field
|
|
|
|
| 19 |
import uvicorn
|
| 20 |
|
| 21 |
# --- Configuration ---
|
|
|
|
| 22 |
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
|
| 23 |
+
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
|
|
|
|
|
|
|
| 24 |
MANAGER_URL = os.getenv("MANAGER_URL", "https://fred808-fcord.hf.space")
|
| 25 |
MANAGER_COMPLETE_TASK_URL = f"{MANAGER_URL}/task/complete"
|
| 26 |
+
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
|
|
|
|
|
|
| 27 |
HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")
|
| 28 |
+
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium")
|
| 29 |
+
STATE_FILE_NAME = f"{FLOW_ID}_state.json"
|
| 30 |
|
| 31 |
+
# Using the full list from the user's original code for actual deployment
|
| 32 |
+
CAPTION_SERVERS = [
|
| 33 |
"https://fred808-pil-4-1.hf.space/analyze",
|
| 34 |
"https://fred808-pil-4-2.hf.space/analyze",
|
| 35 |
"https://fred808-pil-4-3.hf.space/analyze",
|
|
|
|
| 74 |
TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
|
| 75 |
TEMP_DIR.mkdir(exist_ok=True)
|
| 76 |
|
|
|
|
|
|
|
|
|
|
| 77 |
# --- Models ---
|
| 78 |
class ProcessCourseRequest(BaseModel):
|
| 79 |
course_name: Optional[str] = None
|
|
|
|
| 80 |
|
| 81 |
+
class CaptionServer:
|
| 82 |
+
def __init__(self, url):
|
| 83 |
+
self.url = url
|
| 84 |
+
self.busy = False
|
| 85 |
+
self.total_processed = 0
|
| 86 |
+
self.total_time = 0
|
| 87 |
+
self.model = MODEL_TYPE
|
| 88 |
|
| 89 |
@property
|
| 90 |
def fps(self):
|
| 91 |
return self.total_processed / self.total_time if self.total_time > 0 else 0
|
| 92 |
|
| 93 |
+
class ServerState(BaseModel):
|
| 94 |
+
# The list of all zip files in the dataset (frames/ directory)
|
| 95 |
+
all_zip_files: List[str] = Field(default_factory=list)
|
| 96 |
+
# The set of zip files that have been successfully processed and uploaded
|
| 97 |
+
processed_files: Set[str] = Field(default_factory=set)
|
| 98 |
+
# The index in all_zip_files from which the next download should start
|
| 99 |
+
current_index: int = 0
|
| 100 |
+
# Total number of files to process
|
| 101 |
+
total_files: int = 0
|
| 102 |
+
# Status of the current operation
|
| 103 |
+
status: str = "Idle"
|
| 104 |
+
# Name of the file currently being processed
|
| 105 |
+
current_file: Optional[str] = None
|
| 106 |
+
# Progress within the current file
|
| 107 |
+
current_file_progress: str = "0/0"
|
| 108 |
+
# Timestamp of the last update
|
| 109 |
+
last_update: str = datetime.now().isoformat()
|
| 110 |
+
# Flag to control the processing loop
|
| 111 |
+
is_running: bool = False
|
| 112 |
+
|
| 113 |
+
# Global state for caption servers and the overall server state
|
| 114 |
+
servers = [CaptionServer(url) for url in CAPTION_SERVERS]
|
| 115 |
server_index = 0
|
| 116 |
+
state = ServerState()
|
| 117 |
+
# Lock for thread-safe access to the global state
|
| 118 |
+
state_lock = asyncio.Lock()
|
| 119 |
+
|
| 120 |
+
# --- Persistence Functions ---
|
| 121 |
|
| 122 |
+
def get_hf_api():
|
| 123 |
+
"""Helper to get HfApi instance."""
|
| 124 |
+
return HfApi(token=HF_TOKEN)
|
| 125 |
+
|
| 126 |
+
def get_hf_fs():
|
| 127 |
+
"""Helper to get HfFileSystem instance."""
|
| 128 |
+
return HfFileSystem(token=HF_TOKEN)
|
| 129 |
|
| 130 |
async def load_state_from_hf():
|
| 131 |
+
"""Loads the state from the Hugging Face output dataset."""
|
| 132 |
global state
|
| 133 |
+
fs = get_hf_fs()
|
| 134 |
+
state_path = f"{HF_OUTPUT_DATASET_ID}/{STATE_FILE_NAME}"
|
| 135 |
+
|
| 136 |
+
async with state_lock:
|
| 137 |
+
try:
|
| 138 |
+
if fs.exists(state_path):
|
| 139 |
+
print(f"[{FLOW_ID}] Loading state from {state_path}...")
|
| 140 |
+
with fs.open(state_path, 'rb') as f:
|
| 141 |
+
data = json.load(f)
|
| 142 |
+
# Convert list of processed files back to a set
|
| 143 |
+
if 'processed_files' in data and isinstance(data['processed_files'], list):
|
| 144 |
+
data['processed_files'] = set(data['processed_files'])
|
| 145 |
+
state = ServerState.parse_obj(data)
|
| 146 |
+
print(f"[{FLOW_ID}] State loaded successfully. Current index: {state.current_index}")
|
| 147 |
+
else:
|
| 148 |
+
print(f"[{FLOW_ID}] State file {state_path} not found. Starting with default state.")
|
| 149 |
+
except Exception as e:
|
| 150 |
+
print(f"[{FLOW_ID}] Error loading state from HF: {e}. Starting with default state.")
|
| 151 |
+
state = ServerState()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
async def save_state_to_hf():
|
| 154 |
+
"""Saves the current state to the Hugging Face output dataset."""
|
| 155 |
+
global state
|
| 156 |
+
api = get_hf_api()
|
| 157 |
+
state_path = STATE_FILE_NAME
|
| 158 |
+
|
| 159 |
async with state_lock:
|
| 160 |
+
state.last_update = datetime.now().isoformat()
|
| 161 |
+
# Convert set of processed files to a list for JSON serialization
|
| 162 |
+
data_to_save = state.dict()
|
| 163 |
+
data_to_save['processed_files'] = list(state.processed_files)
|
| 164 |
+
|
| 165 |
+
json_content = json.dumps(data_to_save, indent=2, ensure_ascii=False).encode('utf-8')
|
| 166 |
+
|
| 167 |
try:
|
| 168 |
+
print(f"[{FLOW_ID}] Saving state to {state_path} in {HF_OUTPUT_DATASET_ID}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
api.upload_file(
|
| 170 |
path_or_fileobj=io.BytesIO(json_content),
|
| 171 |
+
path_in_repo=state_path,
|
| 172 |
repo_id=HF_OUTPUT_DATASET_ID,
|
| 173 |
repo_type="dataset",
|
| 174 |
+
commit_message=f"[{FLOW_ID}] Update server state. Index: {state.current_index}"
|
| 175 |
)
|
| 176 |
print(f"[{FLOW_ID}] State saved successfully.")
|
| 177 |
return True
|
| 178 |
except Exception as e:
|
| 179 |
+
print(f"[{FLOW_ID}] Error saving state to HF: {e}")
|
| 180 |
return False
|
| 181 |
|
| 182 |
+
async def update_file_list():
|
| 183 |
+
"""Fetches the list of all zip files from the BG3 dataset."""
|
| 184 |
+
global state
|
| 185 |
+
api = get_hf_api()
|
| 186 |
+
|
| 187 |
+
async with state_lock:
|
| 188 |
+
try:
|
| 189 |
+
state.status = "Updating file list..."
|
| 190 |
+
print(f"[{FLOW_ID}] Fetching file list from {HF_DATASET_ID}...")
|
| 191 |
+
repo_files = api.list_repo_files(
|
| 192 |
+
repo_id=HF_DATASET_ID,
|
| 193 |
+
repo_type="dataset"
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Filter for zip files in the 'frames/' directory
|
| 197 |
+
zip_files = sorted([
|
| 198 |
+
f for f in repo_files
|
| 199 |
+
if f.startswith("frames/") and f.endswith('.zip')
|
| 200 |
+
])
|
| 201 |
+
|
| 202 |
+
state.all_zip_files = zip_files
|
| 203 |
+
state.total_files = len(zip_files)
|
| 204 |
+
state.status = "File list updated."
|
| 205 |
+
print(f"[{FLOW_ID}] Found {state.total_files} zip files.")
|
| 206 |
+
except Exception as e:
|
| 207 |
+
state.status = f"Error updating file list: {e}"
|
| 208 |
+
print(f"[{FLOW_ID}] Error updating file list: {e}")
|
| 209 |
+
|
| 210 |
+
await save_state_to_hf()
|
| 211 |
+
|
| 212 |
+
# --- Core Processing Functions (Modified) ---
|
| 213 |
|
| 214 |
async def get_available_server(timeout: float = 300.0) -> CaptionServer:
|
| 215 |
+
"""Round-robin selection of an available caption server."""
|
| 216 |
global server_index
|
| 217 |
start_time = time.time()
|
|
|
|
| 218 |
while True:
|
| 219 |
+
for _ in range(len(servers)):
|
| 220 |
+
server = servers[server_index]
|
| 221 |
+
server_index = (server_index + 1) % len(servers)
|
| 222 |
+
if not server.busy:
|
| 223 |
+
return server
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
|
|
|
| 225 |
await asyncio.sleep(0.5)
|
| 226 |
|
|
|
|
| 227 |
if time.time() - start_time > timeout:
|
| 228 |
raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")
|
| 229 |
|
|
|
|
| 233 |
for attempt in range(MAX_RETRIES):
|
| 234 |
server = None
|
| 235 |
try:
|
|
|
|
| 236 |
server = await get_available_server()
|
| 237 |
+
server.busy = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
start_time = time.time()
|
| 239 |
|
|
|
|
| 240 |
if attempt == 0:
|
| 241 |
print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")
|
| 242 |
|
|
|
|
| 243 |
form_data = aiohttp.FormData()
|
| 244 |
form_data.add_field('file',
|
| 245 |
image_path.open('rb'),
|
|
|
|
| 247 |
content_type='image/jpeg')
|
| 248 |
form_data.add_field('model_choice', MODEL_TYPE)
|
| 249 |
|
|
|
|
| 250 |
async with aiohttp.ClientSession() as session:
|
|
|
|
| 251 |
async with session.post(server.url, data=form_data, timeout=600) as resp:
|
| 252 |
if resp.status == 200:
|
| 253 |
result = await resp.json()
|
| 254 |
caption = result.get("caption")
|
| 255 |
|
| 256 |
if caption:
|
| 257 |
+
# Update progress counter and global state
|
| 258 |
progress_tracker['completed'] += 1
|
| 259 |
+
async with state_lock:
|
| 260 |
+
state.current_file_progress = f"{progress_tracker['completed']}/{progress_tracker['total']}"
|
| 261 |
+
|
| 262 |
if progress_tracker['completed'] % 50 == 0:
|
| 263 |
print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
return {
|
| 266 |
"course": course_name,
|
| 267 |
"image_path": image_path.name,
|
|
|
|
| 270 |
}
|
| 271 |
else:
|
| 272 |
print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
|
| 273 |
+
continue
|
| 274 |
else:
|
| 275 |
error_text = await resp.text()
|
| 276 |
print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
|
| 277 |
+
continue
|
| 278 |
|
| 279 |
+
except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
|
| 280 |
+
print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
|
| 281 |
+
continue
|
|
|
|
| 282 |
except Exception as e:
|
| 283 |
print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
|
| 284 |
+
continue
|
| 285 |
finally:
|
| 286 |
if server:
|
| 287 |
end_time = time.time()
|
| 288 |
+
server.busy = False
|
| 289 |
+
server.total_processed += 1
|
| 290 |
+
server.total_time += (end_time - start_time)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
|
| 293 |
return None
|
| 294 |
|
| 295 |
+
async def download_and_extract_zip(repo_file_full_path: str) -> Optional[tuple[Path, str]]:
    """Download one zip file from the input dataset and extract its contents.

    Args:
        repo_file_full_path: Full path of the zip file inside the HF input
            dataset repository.

    Returns:
        A ``(extract_dir, zip_full_name)`` tuple on success, or ``None`` when
        the download or extraction failed (the error is logged, not raised).
    """
    zip_full_name = Path(repo_file_full_path).name
    # Course name is assumed to be the prefix before the first underscore.
    course_name = zip_full_name.split('_')[0]

    try:
        print(f"[{FLOW_ID}] Downloading file: {repo_file_full_path}. Full name: {zip_full_name}")

        # hf_hub_download is blocking network/disk I/O; run it in a worker
        # thread so the event loop (caption requests, status endpoints)
        # stays responsive while a large zip downloads.
        zip_path = await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_DATASET_ID,
            filename=repo_file_full_path,  # full path inside the repo
            repo_type="dataset",
            token=HF_TOKEN,
        )

        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")

        # Create a dedicated extraction directory under the temp root.
        extract_dir = TEMP_DIR / course_name / zip_full_name.replace('.', '_')
        extract_dir.mkdir(parents=True, exist_ok=True)

        def _extract() -> None:
            # Blocking unzip, executed off the event loop as well.
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

        await asyncio.to_thread(_extract)

        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")

        # Clean up the downloaded zip file to save disk space.
        # NOTE(review): hf_hub_download returns a path inside the HF cache;
        # deleting it defeats caching but matches the original behaviour.
        os.remove(zip_path)

        # Return the extraction directory and the full zip file name.
        return extract_dir, zip_full_name

    except Exception as e:
        print(f"[{FLOW_ID}] Error downloading or extracting zip for {repo_file_full_path}: {e}")
        return None
|
| 332 |
+
|
| 333 |
+
async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
|
| 334 |
+
"""Uploads the final captions JSON file to the output dataset."""
|
| 335 |
caption_filename = Path(zip_full_name).with_suffix('.json').name
|
| 336 |
|
| 337 |
try:
|
| 338 |
print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")
|
| 339 |
|
|
|
|
| 340 |
json_content = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')
|
| 341 |
|
| 342 |
+
api = get_hf_api()
|
| 343 |
api.upload_file(
|
| 344 |
path_or_fileobj=io.BytesIO(json_content),
|
| 345 |
path_in_repo=caption_filename,
|
|
|
|
| 355 |
print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
|
| 356 |
return False
|
| 357 |
|
| 358 |
+
async def _mark_current_file_processed(repo_file_full_path: str) -> None:
    """Record the file as done, advance the queue index and persist state."""
    async with state_lock:
        state.processed_files.add(repo_file_full_path)
        state.current_index += 1
        state.current_file = None
        state.current_file_progress = "0/0"
        state.status = "Idle"
        await save_state_to_hf()


async def process_next_file_task():
    """Continuously process zip files from the dataset, one at a time.

    Walks ``state.all_zip_files`` starting at ``state.current_index``: each
    zip is downloaded and extracted, every image inside is captioned via the
    caption servers, the resulting JSON is uploaded to the output dataset,
    and the index advances.  Files that fail to caption or upload are NOT
    advanced past; the loop retries them after a 60s back-off.
    """
    global state

    if not state.is_running:
        print(f"[{FLOW_ID}] Processing loop is not running. Exiting task.")
        return

    while state.is_running:
        repo_file_full_path = None
        current_index = -1

        async with state_lock:
            current_index = state.current_index
            if current_index >= state.total_files:
                state.status = "Finished processing all files."
                state.is_running = False
                print(f"[{FLOW_ID}] Reached end of file list. Stopping processing.")
                await save_state_to_hf()
                break

            repo_file_full_path = state.all_zip_files[current_index]

            if repo_file_full_path in state.processed_files:
                state.current_index += 1
                state.status = f"Skipping processed file: {Path(repo_file_full_path).name}"
                state.current_file = Path(repo_file_full_path).name
                print(f"[{FLOW_ID}] Skipping already processed file: {repo_file_full_path}")
                await save_state_to_hf()
                continue

            # Mark the file as in-progress in the state.
            state.status = f"Processing file {current_index + 1}/{state.total_files}"
            state.current_file = Path(repo_file_full_path).name
            state.current_file_progress = "0/0"
            await save_state_to_hf()

        # --- Start Processing ---
        extract_dir = None
        zip_full_name = None
        global_success = False
        # BUGFIX: always define all_captions.  Previously, a zip with no
        # images left it undefined, so the upload check below raised
        # NameError and the loop never advanced past an empty zip.
        all_captions: List[Dict] = []

        try:
            download_result = await download_and_extract_zip(repo_file_full_path)

            if download_result is None:
                raise Exception("Failed to download or extract zip file.")

            extract_dir, zip_full_name = download_result
            course_name = zip_full_name.split('_')[0]

            # Find images anywhere under the extraction directory.
            image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

            if not image_paths:
                print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
                global_success = True
            else:
                # Shared mutable counter the caption coroutines update.
                progress_tracker = {
                    'total': len(image_paths),
                    'completed': 0
                }
                async with state_lock:
                    state.current_file_progress = f"0/{len(image_paths)}"
                    await save_state_to_hf()

                # Bound concurrency to the number of caption servers.
                semaphore = asyncio.Semaphore(len(servers))
                async def limited_send_image_for_captioning(image_path, course_name, progress_tracker):
                    async with semaphore:
                        return await send_image_for_captioning(image_path, course_name, progress_tracker)

                caption_tasks = [limited_send_image_for_captioning(p, course_name, progress_tracker) for p in image_paths]
                results = await asyncio.gather(*caption_tasks)
                all_captions = [r for r in results if r is not None]

                # Final progress report.
                if len(all_captions) == len(image_paths):
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully completed all {len(all_captions)} captions.")
                    global_success = True
                else:
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} captions.")
                    global_success = False

            # Upload results.
            if all_captions and zip_full_name:
                if await upload_captions_to_hf(zip_full_name, all_captions):
                    print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
                    # If upload succeeded we mark the file as processed,
                    # regardless of partial success; the uploaded JSON
                    # reflects the actual number of captions.
                    if global_success:
                        print(f"[{FLOW_ID}] Fully processed and uploaded: {zip_full_name}")
                    else:
                        print(f"[{FLOW_ID}] Partially processed but uploaded: {zip_full_name}. Needs manual review.")

                    # Mark as processed only because upload succeeded.
                    await _mark_current_file_processed(repo_file_full_path)
                else:
                    print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}. Will retry this file later.")
                    # Do NOT increment the index or mark as processed, so it
                    # will be retried.
                    async with state_lock:
                        state.status = f"Error uploading captions for {zip_full_name}. Retrying later."
                        await save_state_to_hf()
                    # Wait before retrying to avoid immediate re-attempt on a
                    # transient error.
                    await asyncio.sleep(60)
            elif global_success:
                # BUGFIX: an empty zip has nothing to upload but is complete;
                # previously the index never advanced here, so the loop
                # re-downloaded the same empty zip forever.
                await _mark_current_file_processed(repo_file_full_path)
            else:
                print(f"[{FLOW_ID}] No captions generated or zip_full_name is missing. Skipping upload for {zip_full_name}. Will retry later.")
                # Do NOT increment the index or mark as processed.
                async with state_lock:
                    state.status = f"No captions generated for {zip_full_name}. Retrying later."
                    await save_state_to_hf()
                await asyncio.sleep(60)

        except Exception as e:
            error_message = str(e)
            print(f"[{FLOW_ID}] Critical error in process_next_file_task for {repo_file_full_path}: {error_message}")
            async with state_lock:
                state.status = f"CRITICAL ERROR for {Path(repo_file_full_path).name}. Retrying later. Error: {error_message[:50]}..."
                await save_state_to_hf()
            # Wait before retrying.
            await asyncio.sleep(60)

        finally:
            # Cleanup temporary files.
            if extract_dir and extract_dir.exists():
                print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")
                shutil.rmtree(extract_dir, ignore_errors=True)

            # If the loop is still running, wait a short time before checking
            # for the next file.
            if state.is_running:
                await asyncio.sleep(5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
# --- FastAPI App and Endpoints ---
|
| 501 |
|
|
|
|
| 505 |
version="2.0.0"
|
| 506 |
)
|
| 507 |
|
| 508 |
+
# Setup Jinja2 templates for the UI
|
| 509 |
+
templates = Jinja2Templates(directory="templates")
|
| 510 |
+
|
| 511 |
@app.on_event("startup")
async def startup_event():
    """Restore persisted state on boot and launch the processing loop."""
    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}. Manager URL: {MANAGER_URL}")
    # 1. Load state from persistence (HF)
    await load_state_from_hf()
    # 2. Update the list of all files from the dataset
    await update_file_list()
    # 3. Start the continuous processing task if the index is valid
    if state.current_index < state.total_files:
        state.is_running = True
        # BUGFIX: BackgroundTasks().add_task(...) on a freshly constructed
        # object never runs -- FastAPI only executes BackgroundTasks that are
        # attached to a response.  Schedule the loop on the event loop
        # directly instead.
        asyncio.create_task(process_next_file_task())
    else:
        state.is_running = False
        print(f"[{FLOW_ID}] Index {state.current_index} is out of bounds. Starting in Idle mode.")
|
| 525 |
+
|
| 526 |
|
| 527 |
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the status-and-controls dashboard page."""
    async with state_lock:
        done = len(state.processed_files)

        # Per-server rows for the dashboard table.
        per_server = []
        for srv in servers:
            per_server.append({
                "url": srv.url,
                "busy": srv.busy,
                "processed": srv.total_processed,
                "fps": f"{srv.fps:.2f}",
            })

        # Aggregate throughput across every caption server.
        images_done = sum(srv.total_processed for srv in servers)
        elapsed = sum(srv.total_time for srv in servers)
        throughput = images_done / elapsed if elapsed > 0 else 0

        context = {
            "request": request,
            "flow_id": FLOW_ID,
            "status": state.status,
            "is_running": state.is_running,
            "total_files": state.total_files,
            "processed_count": done,
            "remaining_count": state.total_files - done,
            "current_index": state.current_index,
            "current_file": state.current_file if state.current_file else "N/A",
            "current_file_progress": state.current_file_progress,
            "last_update": state.last_update,
            "overall_fps": f"{throughput:.2f}",
            "server_stats": per_server,
        }
    return templates.TemplateResponse("index.html", context)
|
| 565 |
|
| 566 |
+
@app.post("/set_index")
async def set_index(request: Request, background_tasks: BackgroundTasks):
    """Manually reposition the processing queue at a given file index."""
    global state

    submitted = await request.form()
    raw_value = submitted.get("start_index")
    try:
        target = int(raw_value)
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="Invalid index value.")

    async with state_lock:
        if 0 <= target < state.total_files:
            state.current_index = target
            state.status = f"Index set to {target}. Restarting processing."

            # Kick the processing loop back into gear if it was idle.
            if not state.is_running:
                state.is_running = True
                background_tasks.add_task(process_next_file_task)

            await save_state_to_hf()
            print(f"[{FLOW_ID}] Index manually set to {target}.")
            return {"status": "success", "message": f"Start index set to {target}. Processing will resume from this point."}

        if target == state.total_files:
            # One past the last file means "everything is done".
            state.current_index = target
            state.is_running = False
            state.status = "Finished processing all files."
            await save_state_to_hf()
            return {"status": "success", "message": "Index set to end of list. Processing stopped."}

        raise HTTPException(status_code=400, detail=f"Index {target} is out of bounds (0 to {state.total_files}).")
|
| 598 |
|
| 599 |
+
@app.post("/control_processing")
async def control_processing(request: Request, background_tasks: BackgroundTasks):
    """Start or stop the processing loop via the "action" form field."""
    global state

    payload = await request.form()
    command = payload.get("action")

    async with state_lock:
        if command == "start":
            startable = (not state.is_running) and state.current_index < state.total_files
            if startable:
                state.is_running = True
                state.status = "Processing started."
                background_tasks.add_task(process_next_file_task)
                await save_state_to_hf()
                return {"status": "success", "message": "Processing loop started."}
            if state.current_index >= state.total_files:
                return {"status": "error", "message": "Cannot start. All files have been processed."}
            return {"status": "info", "message": "Processing is already running."}

        if command == "stop":
            if not state.is_running:
                return {"status": "info", "message": "Processing is already stopped."}
            state.is_running = False
            state.status = "Processing stopped by user."
            await save_state_to_hf()
            return {"status": "success", "message": "Processing loop stopped."}

        raise HTTPException(status_code=400, detail="Invalid action.")
|
| 629 |
+
|
| 630 |
+
@app.get("/status")
async def get_status():
    """Return the current server status as a JSON payload."""
    async with state_lock:
        done = len(state.processed_files)

        stats = [
            {
                "url": srv.url,
                "busy": srv.busy,
                "processed": srv.total_processed,
                "fps": f"{srv.fps:.2f}",
            }
            for srv in servers
        ]

        frames = sum(srv.total_processed for srv in servers)
        seconds = sum(srv.total_time for srv in servers)
        throughput = frames / seconds if seconds > 0 else 0

        return {
            "flow_id": FLOW_ID,
            "status": state.status,
            "is_running": state.is_running,
            "total_files": state.total_files,
            "processed_count": done,
            "remaining_count": state.total_files - done,
            "current_index": state.current_index,
            "current_file": state.current_file,
            "current_file_progress": state.current_file_progress,
            "last_update": state.last_update,
            "overall_fps": f"{throughput:.2f}",
            "server_stats": stats,
        }
|
| 663 |
+
|
| 664 |
+
# The original /process_course endpoint is now obsolete as the server manages its own queue
|
| 665 |
+
# @app.post("/process_course")
|
| 666 |
+
# async def process_course(request: ProcessCourseRequest, background_tasks: BackgroundTasks):
|
| 667 |
+
# return {"status": "obsolete", "message": "The server now manages its own processing queue based on the index."}
|
| 668 |
+
|
| 669 |
|
| 670 |
if __name__ == "__main__":
    # Run the FastAPI app directly; FLOW_PORT comes from the environment
    # (defaults to 8001 per the configuration block at the top of the file).
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)
|