Fred808 committed on
Commit
f5f8b6b
·
verified ·
1 Parent(s): 5b60a10

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +475 -360
app.py CHANGED
@@ -1,360 +1,475 @@
1
- import os
2
- import json
3
- import asyncio
4
- import aiohttp
5
- from typing import Dict, List, Set, Optional
6
- from pathlib import Path
7
- from datetime import datetime
8
-
9
- from fastapi import FastAPI, BackgroundTasks, HTTPException, status
10
- from pydantic import BaseModel
11
- from huggingface_hub import HfApi, HfFileSystem
12
- import uvicorn
13
-
14
# --- Configuration ---
# Manager Server will run on port 8000
MANAGER_PORT = 8000

# Hugging Face Configuration
# FIX: the token was hard-coded as an empty string; read it from the
# environment instead so deployments can inject a real token without
# committing secrets (matches how the flow servers are configured).
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_DATASET_ID = "Fred808/BG3"  # Dataset where the zip files are located
HF_DATASET_REPO_TYPE = "dataset"
FRAMES_FOLDER = "frames"
STATE_FILE_PATH = "flow_processing_state.json"

# Flow Server Configuration (hardcoded as per user request)
# NOTE: These URLs must be accessible to the Manager Server.
# For local testing, you might use localhost with different ports (e.g., 8001 and 8002)
FLOW_SERVERS = {
    "flow1": "http://localhost:8001",
    "flow2": "https://fred808-flowcap2.hf.space",
}
32
-
33
# --- State Management Models ---
class TaskStatus(BaseModel):
    """Lifecycle record for a single zip-file processing task."""

    status: str  # One of: UNPROCESSED, IN_PROGRESS, COMPLETED, FAILED
    assigned_to: Optional[str] = None  # flow1 or flow2
    assigned_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error_message: Optional[str] = None


class ProcessingState(BaseModel):
    """Complete persistent manager state: every task plus current flow assignments."""

    # Keyed by the zip file name, which doubles as the course name.
    tasks: Dict[str, TaskStatus] = {}

    # Which course (if any) each flow server is currently working on.
    flow_assignments: Dict[str, Optional[str]] = {
        "flow1": None,
        "flow2": None,
    }


class CompleteTaskRequest(BaseModel):
    """Payload a flow server POSTs when it finishes (or fails) a task."""

    flow_id: str
    course_name: str
    success: bool
    error_message: Optional[str] = None
56
-
57
- # --- Global State and Initialization ---
58
# --- Global State and Initialization ---
app = FastAPI(
    title="BG3 Processing Manager",
    description="Coordinates flow servers for BG3 dataset processing.",
    version="1.0.0",
)

# Shared Hugging Face clients and the single in-memory state object.
api = HfApi(token=HF_TOKEN)
fs = HfFileSystem(token=HF_TOKEN)
state = ProcessingState()
is_coordinating = False  # Guards against launching the coordinator twice.
68
-
69
- # --- Persistence Functions ---
70
-
71
def get_full_path(filename: str) -> str:
    """Return the repo-relative path ``<dataset_id>/<filename>``.

    BUG FIX: the previous body returned a literal ``(unknown)`` placeholder
    (a templating/scrape artifact) instead of interpolating ``filename``.
    """
    return f"{HF_DATASET_ID}/{filename}"
73
-
74
async def load_state_from_hf():
    """Load persisted processing state from the HF dataset, or initialize it.

    Falls back to `initialize_tasks()` when the state file is missing or any
    step fails, so the manager always ends up with a usable in-memory state.
    """
    global state
    # HfFileSystem addresses dataset repos with a "datasets/" prefix; a bare
    # repo id would be resolved as a *model* repo.
    state_fs_path = f"datasets/{HF_DATASET_ID}/{STATE_FILE_PATH}"
    try:
        # Check if the state file exists
        if fs.exists(state_fs_path):
            print(f"Loading state from {STATE_FILE_PATH}...")

            # BUG FIX: HfApi has no `read_file` method; read through the
            # HfFileSystem handle instead.
            with fs.open(state_fs_path, "r", encoding="utf-8") as f:
                content = f.read()

            data = json.loads(content)
            state = ProcessingState(**data)
            print(f"State loaded. Total tasks: {len(state.tasks)}")
        else:
            print(f"State file {STATE_FILE_PATH} not found. Initializing.")
            await initialize_tasks()
            await save_state_to_hf()  # Persist the freshly built state.

    except Exception as e:
        print(f"Error loading state from HF: {e}")
        # Fallback to initialization if loading fails.
        await initialize_tasks()
100
-
101
async def save_state_to_hf():
    """Persist the current in-memory state as JSON to the HF dataset repo."""
    try:
        print(f"Saving state to {STATE_FILE_PATH}...")
        content = state.model_dump_json(indent=2).encode('utf-8')

        # FIX: `upload_file` is a blocking network call; run it in a worker
        # thread so the event loop stays responsive while the commit uploads.
        await asyncio.to_thread(
            api.upload_file,
            path_or_fileobj=content,
            path_in_repo=STATE_FILE_PATH,
            repo_id=HF_DATASET_ID,
            repo_type=HF_DATASET_REPO_TYPE,
            commit_message="Update processing state",
        )
        print("State saved successfully.")
    except Exception as e:
        print(f"Error saving state to HF: {e}")
116
-
117
async def initialize_tasks():
    """Discover zip files in the frames folder and build/refresh the task table.

    Course names are cross-checked against the Fred808/BG1 dataset when that
    list is available; otherwise every zip found in BG3 becomes a task.

    Raises:
        RuntimeError: when listing the BG3 frames folder fails entirely.
    """
    global state
    print(f"Discovering zip files in {FRAMES_FOLDER}/...")

    # 1. Fetch the list of valid course names from Fred808/BG1.
    #    The base names of the files in the BG1 repo root serve as the list.
    print("Fetching valid course names from Fred808/BG1...")
    try:
        # FIX: HfFileSystem needs a "datasets/" prefix for dataset repos; a
        # bare id is interpreted as a model repo.
        bg1_files = fs.ls("datasets/Fred808/BG1", detail=False)
        # Path(f).stem strips the extension ('course_name.zip' -> 'course_name').
        valid_course_names = {Path(f).stem for f in bg1_files if not Path(f).name.startswith('.')}

        if not valid_course_names:
            print("Warning: Fred808/BG1 dataset seems empty or contains no processable files. Using all found zip files.")
            # An empty set below means "accept every zip found in BG3".

    except Exception as e:
        print(f"Error fetching course names from Fred808/BG1: {e}. Falling back to all zip files in BG3.")
        valid_course_names = set()

    # 2. List zip files in the frames folder of the main dataset (BG3).
    try:
        file_list = fs.ls(f"datasets/{HF_DATASET_ID}/{FRAMES_FOLDER}", detail=False)

        # NOTE: the old `not f.endswith(".zip.json")` guard was redundant —
        # a path ending in ".zip" can never also end in ".zip.json".
        zip_files = [Path(f).name for f in file_list if f.endswith(".zip")]

        new_tasks = {}
        for zip_file in zip_files:
            course_name = zip_file.replace(".zip", "")

            # 3. Filter: only process courses present in the BG1 list (when non-empty).
            if valid_course_names and course_name not in valid_course_names:
                print(f"Skipping {course_name}: Not found in Fred808/BG1 list.")
                continue

            # Keep any status we were already tracking for this course.
            new_tasks[course_name] = state.tasks.get(course_name, TaskStatus(status="UNPROCESSED"))

        # Merge with existing state: drop tasks whose zip disappeared, but keep
        # history for anything that was already started or finished.
        existing_tasks_to_keep = {
            k: v for k, v in state.tasks.items()
            if k in new_tasks or v.status in ["IN_PROGRESS", "COMPLETED", "FAILED"]
        }

        # New discoveries win over stale entries.
        state.tasks = {**existing_tasks_to_keep, **new_tasks}

        print(f"Found {len(zip_files)} zip files in {HF_DATASET_ID}/{FRAMES_FOLDER}. Valid course names from BG1: {len(valid_course_names)}. Total tasks tracked: {len(state.tasks)}")

    except Exception as e:
        print(f"Error discovering files from HF: {e}")
        # If discovery fails, we can't proceed.
        raise RuntimeError(f"Failed to discover files: {e}")
180
-
181
- # --- Core Coordination Logic ---
182
-
183
async def assign_next_task(flow_id: str):
    """Find the next UNPROCESSED task, mark it IN_PROGRESS, and push it to a flow server.

    On any delivery failure the task is reverted to UNPROCESSED so another
    attempt (or another flow server) can pick it up later.
    """
    global state

    async def _revert_assignment():
        # Undo the IN_PROGRESS marking made below and persist the rollback.
        # (Extracted: this was copy-pasted three times in the original.)
        if next_course:
            state.tasks[next_course] = TaskStatus(status="UNPROCESSED")
            state.flow_assignments[flow_id] = None
            await save_state_to_hf()

    # 1. Find an UNPROCESSED task (dict insertion order decides priority).
    next_course = None
    for course_name, task_status in state.tasks.items():
        if task_status.status == "UNPROCESSED":
            next_course = course_name
            break

    if next_course is None:
        print(f"No UNPROCESSED tasks left for {flow_id}.")
        course_to_assign = None
    else:
        # 2. Update state to IN_PROGRESS.
        state.tasks[next_course] = TaskStatus(
            status="IN_PROGRESS",
            assigned_to=flow_id,
            assigned_at=datetime.now(),
        )
        state.flow_assignments[flow_id] = next_course
        course_to_assign = next_course

    # 3. Persist the state change.
    await save_state_to_hf()

    # 4. Notify the flow server; a None course_name tells it to stop.
    flow_url = FLOW_SERVERS.get(flow_id)
    if not flow_url:
        print(f"Error: Unknown flow_id {flow_id}")
        return

    try:
        print(f"Assigning '{course_to_assign}' to {flow_id} at {flow_url}/process_course")
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{flow_url}/process_course",
                json={"course_name": course_to_assign}
            ) as response:
                if response.status != 200:
                    print(f"Error sending task to {flow_id}: {response.status} - {await response.text()}")
                    await _revert_assignment()
                else:
                    print(f"Successfully assigned {course_to_assign} to {flow_id}.")

    except aiohttp.ClientConnectorError as e:
        print(f"Connection Error: Could not connect to {flow_id} at {flow_url}. Reverting task status. Error: {e}")
        await _revert_assignment()
    except Exception as e:
        print(f"Unexpected error during assignment to {flow_id}. Error: {e}")
        await _revert_assignment()
250
-
251
-
252
async def coordinate_loop():
    """Bootstrap the coordinator: load state, then hand tasks to idle flow servers.

    After this initial kick-off the system is event-driven — subsequent
    assignments happen from the /task/complete endpoint.
    """
    global is_coordinating
    if is_coordinating:
        print("Coordinator is already running.")
        return

    is_coordinating = True
    print("Starting coordination loop...")

    try:
        # Load persisted state (and discover tasks) before assigning anything.
        await load_state_from_hf()

        # Kick off an assignment for every flow server that is currently idle.
        idle_flows = [fid for fid in FLOW_SERVERS if state.flow_assignments.get(fid) is None]
        for flow_id in idle_flows:
            asyncio.create_task(assign_next_task(flow_id))

    except Exception as e:
        print(f"Coordination loop failed to start: {e}")
278
-
279
- # --- API Endpoints ---
280
-
281
@app.on_event("startup")
async def startup_event():
    """Launch the coordinator when the server boots."""
    # BUG FIX: instantiating `BackgroundTasks()` outside a request context does
    # nothing — FastAPI only runs BackgroundTasks attached to a response.
    # Schedule the coroutine on the running event loop instead.
    asyncio.create_task(coordinate_loop())
285
-
286
@app.get("/")
async def root():
    """Health/summary endpoint: coordinator status plus task counts by state."""
    counts = {"UNPROCESSED": 0, "IN_PROGRESS": 0, "COMPLETED": 0}
    for task in state.tasks.values():
        if task.status in counts:
            counts[task.status] += 1

    return {
        "message": "BG3 Processing Manager API",
        "status": "running",
        "is_coordinating": is_coordinating,
        "flow_assignments": state.flow_assignments,
        "total_tasks": len(state.tasks),
        "unprocessed": counts["UNPROCESSED"],
        "in_progress": counts["IN_PROGRESS"],
        "completed": counts["COMPLETED"],
    }
298
-
299
@app.post("/task/complete")
async def task_complete(request: CompleteTaskRequest):
    """Receive a completion/failure report from a flow server.

    Updates the task record, frees the flow server's slot, persists state,
    and immediately schedules the next assignment for that server.
    """
    global state
    flow_id = request.flow_id
    course_name = request.course_name

    if course_name not in state.tasks:
        raise HTTPException(status_code=404, detail=f"Unknown course: {course_name}")

    task = state.tasks[course_name]

    # Sanity check only — a mismatch is logged but the report is still honored.
    if task.assigned_to != flow_id:
        print(f"Warning: {flow_id} reported completion for a task not assigned to it: {course_name}")

    task.completed_at = datetime.now()
    if request.success:
        print(f"Task COMPLETED: {course_name} by {flow_id}")
        task.status = "COMPLETED"
        task.error_message = None
    else:
        # A more robust system might retry; for now a failure is terminal.
        print(f"Task FAILED: {course_name} by {flow_id}. Error: {request.error_message}")
        task.status = "FAILED"
        task.error_message = request.error_message

    # Free up the flow server slot and persist the change.
    state.flow_assignments[flow_id] = None
    await save_state_to_hf()

    # Hand the now-idle flow server its next task.
    asyncio.create_task(assign_next_task(flow_id))

    outcome = "COMPLETED" if request.success else "FAILED"
    return {"status": "success", "message": f"Task {course_name} marked as {outcome}. Next task assigned."}
339
-
340
@app.post("/start_coordination")
async def start_coordination(background_tasks: BackgroundTasks):
    """Manually kick off the coordination loop (no-op if already running)."""
    if is_coordinating:
        return {"status": "info", "message": "Coordination is already running."}

    background_tasks.add_task(coordinate_loop)
    return {"status": "success", "message": "Coordination loop started."}


@app.get("/state")
async def get_state():
    """Expose the full in-memory processing state for inspection."""
    return state
357
-
358
if __name__ == "__main__":
    # Bind to 0.0.0.0 so the port is reachable from outside the sandbox/container.
    uvicorn.run(app, host="0.0.0.0", port=MANAGER_PORT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import asyncio
5
+ import aiohttp
6
+ import zipfile
7
+ from typing import Dict, List, Set, Optional
8
+ from urllib.parse import quote
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ import io
12
+
13
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, status
14
+ from pydantic import BaseModel, Field
15
+ from huggingface_hub import HfApi, hf_hub_download
16
+ import uvicorn
17
+
18
# --- Configuration ---
# Flow server identity and port come from the environment for easy deployment.
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))  # Default to 8001 for flow1

# Manager Server Configuration
MANAGER_URL = os.getenv("MANAGER_URL", "https://fred808-fcord.hf.space")
MANAGER_COMPLETE_TASK_URL = f"{MANAGER_URL}/task/complete"

# Hugging Face Configuration
HF_TOKEN = os.getenv("HF_TOKEN", "")  # User provided token
HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium")  # Target dataset for captions

# Pool of caption-inference endpoints; work is spread across them round-robin.
# (Full list carried over from the original deployment.)
CAPTION_SERVERS = [
    "https://fred808-pil-4-1.hf.space/analyze",
    "https://fred808-pil-4-2.hf.space/analyze",
    "https://fred808-pil-4-3.hf.space/analyze",
    "https://fred1012-fred1012-gw0j2h.hf.space/analyze",
    "https://fred1012-fred1012-wqs6c2.hf.space/analyze",
    "https://fred1012-fred1012-oncray.hf.space/analyze",
    "https://fred1012-fred1012-4goge7.hf.space/analyze",
    "https://fred1012-fred1012-z0eh7m.hf.space/analyze",
    "https://fred1012-fred1012-u95rte.hf.space/analyze",
    "https://fred1012-fred1012-igje22.hf.space/analyze",
    "https://fred1012-fred1012-ibkuf8.hf.space/analyze",
    "https://fred1012-fred1012-nwqthy.hf.space/analyze",
    "https://fred1012-fred1012-4ldqj4.hf.space/analyze",
    "https://fred1012-fred1012-pivlzg.hf.space/analyze",
    "https://fred1012-fred1012-ptlc5u.hf.space/analyze",
    "https://fred1012-fred1012-u7lh57.hf.space/analyze",
    "https://fred1012-fred1012-q8djv1.hf.space/analyze",
    "https://fredalone-fredalone-ozugrp.hf.space/analyze",
    "https://fredalone-fredalone-9brxj2.hf.space/analyze",
    "https://fredalone-fredalone-p8vq9a.hf.space/analyze",
    "https://fredalone-fredalone-vbli2y.hf.space/analyze",
    "https://fredalone-fredalone-uggger.hf.space/analyze",
    "https://fredalone-fredalone-nmi7e8.hf.space/analyze",
    "https://fredalone-fredalone-d1f26d.hf.space/analyze",
    "https://fredalone-fredalone-461jp2.hf.space/analyze",
    "https://fredalone-fredalone-3enfg4.hf.space/analyze",
    "https://fredalone-fredalone-dqdbpv.hf.space/analyze",
    "https://fredalone-fredalone-ivtjua.hf.space/analyze",
    "https://fredalone-fredalone-6bezt2.hf.space/analyze",
    "https://fredalone-fredalone-e0wfnk.hf.space/analyze",
    "https://fredalone-fredalone-zu2t7j.hf.space/analyze",
    "https://fredalone-fredalone-dqtv1o.hf.space/analyze",
    "https://fredalone-fredalone-wclyog.hf.space/analyze",
    "https://fredalone-fredalone-t27vig.hf.space/analyze",
    "https://fredalone-fredalone-gahbxh.hf.space/analyze",
    "https://fredalone-fredalone-kw2po4.hf.space/analyze",
    "https://fredalone-fredalone-8h285h.hf.space/analyze"
]
MODEL_TYPE = "Florence-2-large"

# Per-flow scratch directory for downloaded/extracted images.
TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)
77
+
78
# --- Models ---
class ProcessCourseRequest(BaseModel):
    """Body of /process_course; a missing course_name means 'nothing to do'."""

    course_name: Optional[str] = None


class CaptionServer:
    """Tracks availability and throughput stats for one caption endpoint."""

    def __init__(self, url):
        self.url = url
        self.busy = False          # True while a request is in flight.
        self.total_processed = 0   # Attempts finished (success or failure).
        self.total_time = 0        # Cumulative seconds spent on attempts.
        self.model = MODEL_TYPE

    @property
    def fps(self):
        # Average images per second; 0 before any work has been timed.
        return self.total_processed / self.total_time if self.total_time > 0 else 0


# Global pool of caption servers plus the round-robin cursor.
servers = [CaptionServer(url) for url in CAPTION_SERVERS]
server_index = 0
97
+
98
+ # --- Core Processing Functions ---
99
+
100
async def get_available_server(timeout: float = 300.0) -> CaptionServer:
    """Return the next idle caption server (round-robin), waiting up to `timeout` seconds."""
    global server_index
    waited_from = time.time()
    while True:
        # One full round-robin sweep over the pool.
        for _ in range(len(servers)):
            candidate = servers[server_index]
            server_index = (server_index + 1) % len(servers)
            if not candidate.busy:
                return candidate

        # Everyone is busy: back off briefly, then give up once past the deadline.
        await asyncio.sleep(0.5)
        if time.time() - waited_from > timeout:
            raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")
118
+
119
async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
    """Caption a single image via the server pool; return a caption record or None.

    Retries up to MAX_RETRIES times, each attempt on whichever server the
    round-robin selector hands out. `progress_tracker` must hold integer
    'total' and 'completed' entries and is mutated in place.
    """
    MAX_RETRIES = 3
    # BUG FIX: read the image once up front instead of handing an open file
    # object to FormData — the old code never closed that handle (fd leak) and
    # a consumed handle could not be re-sent on retry.
    image_bytes = image_path.read_bytes()

    for attempt in range(MAX_RETRIES):
        server = None
        start_time = None
        try:
            # 1. Get an available server (waits when all are busy, with a timeout).
            server = await get_available_server()
            server.busy = True
            start_time = time.time()

            # Keep the log quiet: announce only the first attempt.
            if attempt == 0:
                print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")

            # 2. Prepare the multipart request.
            form_data = aiohttp.FormData()
            form_data.add_field('file',
                                image_bytes,
                                filename=image_path.name,
                                content_type='image/jpeg')
            form_data.add_field('model_choice', MODEL_TYPE)

            # 3. Send the request (10-minute ceiling for slow model servers).
            async with aiohttp.ClientSession() as session:
                async with session.post(server.url, data=form_data, timeout=600) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        caption = result.get("caption")

                        if caption:
                            progress_tracker['completed'] += 1
                            if progress_tracker['completed'] % 50 == 0:
                                print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")
                            else:
                                print(f"[{FLOW_ID}] Success: {image_path.name} captioned by {server.url}")

                            return {
                                "course": course_name,
                                "image_path": image_path.name,
                                "caption": caption,
                                "timestamp": datetime.now().isoformat()
                            }
                        print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
                        continue  # Retry with a different server

                    error_text = await resp.text()
                    print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
                    continue  # Retry with a different server

        except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
            print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
            continue
        except Exception as e:
            print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
            continue
        finally:
            # Always release the server and record timing stats.
            # BUG FIX: guard on start_time so a failure while *acquiring* the
            # server cannot raise NameError here and mask the real error.
            if server is not None and start_time is not None:
                server.busy = False
                server.total_processed += 1
                server.total_time += (time.time() - start_time)

    print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
    return None
190
+
191
async def download_and_extract_zip(course_name: str, processed_files: Set[str]) -> tuple[Optional[Path], Optional[str], Optional[str]]:
    """Download the next unprocessed zip for `course_name` and extract it.

    Returns (extract_dir, zip_full_name, repo_file_full_path); all three are
    None when there is nothing new to process or when any step fails.
    """
    print(f"[{FLOW_ID}] Looking for files starting with '{course_name}' in frames/ directory...")

    try:
        api = HfApi(token=HF_TOKEN)

        # List every file in the dataset, then keep zips under frames/ whose
        # name starts with the course name.
        repo_files = api.list_repo_files(
            repo_id=HF_DATASET_ID,
            repo_type="dataset"
        )
        matching_files = [
            f for f in repo_files
            if f.startswith(f"frames/{course_name}") and f.endswith('.zip')
        ]

        if not matching_files:
            print(f"[{FLOW_ID}] No zip files found starting with '{course_name}' in frames/ directory.")
            # BUG FIX: this branch previously returned a 2-tuple, breaking the
            # 3-tuple contract every caller unpacks.
            return None, None, None

        # Skip zips this flow has already handled during the current run.
        unprocessed_files = [f for f in matching_files if f not in processed_files]
        if not unprocessed_files:
            print(f"[{FLOW_ID}] No new zip files found for '{course_name}'.")
            return None, None, None

        repo_file_full_path = unprocessed_files[0]  # e.g. frames/COURSE_full_name.zip
        zip_full_name = Path(repo_file_full_path).name
        print(f"[{FLOW_ID}] Found new matching file: {repo_file_full_path}. Full name: {zip_full_name}")

        # hf_hub_download resolves/caches the file locally and returns its path.
        zip_path = hf_hub_download(
            repo_id=HF_DATASET_ID,
            filename=repo_file_full_path,  # Full path within the repo.
            repo_type="dataset",
            token=HF_TOKEN,
        )

        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")

        extract_dir = TEMP_DIR / course_name
        extract_dir.mkdir(exist_ok=True)

        # NOTE(review): extractall trusts archive member paths; acceptable for
        # our own datasets, but unsafe for untrusted zips (path traversal).
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)

        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")
        return extract_dir, zip_full_name, repo_file_full_path

    except Exception as e:
        print(f"[{FLOW_ID}] Error downloading or extracting zip for {course_name}: {e}")
        return None, None, None
252
+
253
async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
    """Upload the captions for one zip as `<zip name>.json` to the output dataset.

    Returns True on success, False on any failure (which is logged).
    """
    # The output JSON is named after the full zip file name, per requirements.
    caption_filename = Path(zip_full_name).with_suffix('.json').name

    try:
        print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {caption_filename} to {HF_OUTPUT_DATASET_ID}...")

        # Serialize in memory; no temp file needed.
        json_content = json.dumps(captions, indent=2, ensure_ascii=False).encode('utf-8')

        api = HfApi(token=HF_TOKEN)
        # FIX: `upload_file` blocks on network I/O; run it off the event loop
        # so concurrent captioning work is not stalled during the commit.
        await asyncio.to_thread(
            api.upload_file,
            path_or_fileobj=io.BytesIO(json_content),
            path_in_repo=caption_filename,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"[{FLOW_ID}] Captions for {zip_full_name}",
        )

        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
        return True

    except Exception as e:
        print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {e}")
        return False
282
+
283
async def process_course_task(course_name: str):
    """Process every zip matching `course_name`: download, caption, upload.

    Loops until no new zip files are found, then reports the overall result
    back to the Manager so this flow server's slot is freed.
    """
    import shutil  # kept local, as in the original cleanup code

    print(f"[{FLOW_ID}] Starting continuous processing for course: {course_name}")

    processed_files = set()
    all_processed_files_log = []
    global_success = True
    # BUG FIX: pre-initialize so the final report below cannot hit a NameError
    # when a failure happens outside the `except` branch (e.g. a failed upload).
    error_message = None

    # Keep checking for new files matching the course_name prefix.
    while True:
        extract_dir = None
        zip_full_name = None
        repo_file_full_path = None
        # BUG FIX: initialize before the try so the `finally` check can never
        # raise NameError if the download call itself blows up.
        download_result = None

        try:
            # Returns (extract_dir, zip_full_name, repo_file_full_path) or Nones.
            download_result = await download_and_extract_zip(course_name, processed_files)

            if download_result is None or download_result[0] is None:
                if download_result is not None and download_result[1] is None:
                    # Clean "nothing left to do" signal — stop looping.
                    print(f"[{FLOW_ID}] No new files found for {course_name}. Exiting loop.")
                    break
                # An error occurred during search/download.
                raise Exception("Failed to download or extract zip file.")

            extract_dir, zip_full_name, repo_file_full_path = download_result

            # Mark immediately so the next iteration skips this file.
            processed_files.add(repo_file_full_path)
            all_processed_files_log.append(repo_file_full_path)

            # Recursive glob so images inside subdirectories are found too.
            image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

            current_file_success = False

            if not image_paths:
                print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
                current_file_success = True
            else:
                progress_tracker = {'total': len(image_paths), 'completed': 0}
                print(f"[{FLOW_ID}] Starting captioning for {progress_tracker['total']} images in {zip_full_name}...")

                # Cap concurrency at the size of the caption-server pool.
                semaphore = asyncio.Semaphore(len(servers))

                async def caption_with_limit(path):
                    async with semaphore:
                        return await send_image_for_captioning(path, course_name, progress_tracker)

                # Caption every image concurrently, then drop failures.
                results = await asyncio.gather(*(caption_with_limit(p) for p in image_paths))
                all_captions = [r for r in results if r is not None]

                if len(all_captions) == len(image_paths):
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully completed all {len(all_captions)} captions.")
                    current_file_success = True
                else:
                    print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} captions.")

                # Upload whatever we got; partial results still mark the
                # overall task as failed but are not thrown away.
                if all_captions and zip_full_name:
                    print(f"[{FLOW_ID}] Uploading {len(all_captions)} captions for {zip_full_name}...")
                    if await upload_captions_to_hf(zip_full_name, all_captions):
                        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
                        if not current_file_success:
                            global_success = False
                    else:
                        print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}.")
                        current_file_success = False
                        global_success = False
                else:
                    print(f"[{FLOW_ID}] No captions generated or zip_full_name is missing. Skipping upload for {zip_full_name}.")
                    current_file_success = False
                    global_success = False

        except Exception as e:
            error_message = str(e)
            print(f"[{FLOW_ID}] Critical error in process_course_task for {course_name}: {error_message}")
            global_success = False

        finally:
            # Remove the per-file scratch directory.
            if extract_dir and extract_dir.exists():
                shutil.rmtree(extract_dir, ignore_errors=True)
                print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")

            # Unrecoverable failure before extraction: stop the loop entirely.
            if download_result is None and extract_dir is None:
                break

    # --- Final report after the loop is complete ---
    print(f"[{FLOW_ID}] All processing loops complete for {course_name}.")
    print(f"[{FLOW_ID}] Total files processed: {len(all_processed_files_log)}")
    print(f"[{FLOW_ID}] List of processed files: {all_processed_files_log}")

    # BUG FIX: actually report back to the Manager — the call was commented
    # out, which left the Manager's flow slot occupied forever.
    final_error_message = error_message if not global_success else None
    await report_completion(course_name, global_success, final_error_message)

    return global_success
408
+
409
async def report_completion(course_name: str, success: bool, error_message: Optional[str] = None):
    """POST the final task outcome to the Manager's /task/complete endpoint."""
    print(f"[{FLOW_ID}] Reporting completion for {course_name} (Success: {success})...")

    payload = {
        "flow_id": FLOW_ID,
        "course_name": course_name,
        "success": success,
        "error_message": error_message,
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(MANAGER_COMPLETE_TASK_URL, json=payload) as resp:
                if resp.status == 200:
                    print(f"[{FLOW_ID}] Successfully reported completion to Manager.")
                else:
                    print(f"[{FLOW_ID}] ERROR: Manager reported non-200 status: {resp.status} - {await resp.text()}")

    except aiohttp.ClientError as e:
        # The Manager never hears about this task; its slot stays occupied.
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not connect to Manager at {MANAGER_COMPLETE_TASK_URL}. Task completion not reported. Error: {e}")
    except Exception as e:
        print(f"[{FLOW_ID}] Unexpected error during reporting: {e}")
432
+
433
# --- FastAPI App and Endpoints ---

app = FastAPI(
    title=f"Flow Server {FLOW_ID} API",
    description="Fetches, extracts, and captions images for a given course.",
    version="1.0.0",
)


@app.on_event("startup")
async def startup_event():
    # Announce identity and configuration once the server is up.
    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}. Manager URL: {MANAGER_URL}")


@app.get("/")
async def root():
    """Health endpoint: flow identity plus caption-server pool utilization."""
    busy_count = sum(1 for s in servers if s.busy)
    return {
        "flow_id": FLOW_ID,
        "status": "ready",
        "manager_url": MANAGER_URL,
        "total_servers": len(servers),
        "busy_servers": busy_count,
    }
454
+
455
@app.post("/process_course")
async def process_course(request: ProcessCourseRequest, background_tasks: BackgroundTasks):
    """Accept a course from the Manager and process it in the background.

    Returns immediately; the heavy lifting runs as a background task. An
    empty/missing course name is the Manager's signal that nothing is left.
    """
    course_name = request.course_name

    if not course_name:
        print(f"[{FLOW_ID}] Received empty course name. Stopping processing loop.")
        return {"status": "stopped", "message": "No more courses to process."}

    print(f"[{FLOW_ID}] Received course: {course_name}. Starting background task.")
    background_tasks.add_task(process_course_task, course_name)

    return {"status": "processing", "course_name": course_name, "message": "Processing started in background."}
472
+
473
if __name__ == "__main__":
    # Bind to 0.0.0.0 so the port is reachable from outside the sandbox/container.
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)