Fred808 committed on
Commit
058bb22
·
verified ·
1 Parent(s): 76b9935

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +327 -0
app.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import aiohttp
5
+ from typing import Dict, List, Set, Optional
6
+ from pathlib import Path
7
+ from datetime import datetime
8
+
9
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, status
10
+ from pydantic import BaseModel
11
+ from huggingface_hub import HfApi, HfFileSystem
12
+ import uvicorn
13
+
14
# --- Configuration ---
# Port the Manager Server listens on.
MANAGER_PORT = 8000

# Hugging Face configuration.
# The access token is taken from the HF_TOKEN environment variable so the
# secret does not have to live in source control; when the variable is
# unset it falls back to the original empty string.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
HF_DATASET_ID = "Fred808/BG3"
HF_DATASET_REPO_TYPE = "dataset"
FRAMES_FOLDER = "frames"
STATE_FILE_PATH = "flow_processing_state.json"

# Flow Server configuration (hardcoded as per user request).
# NOTE: These URLs must be accessible to the Manager Server.
# For local testing, you might use localhost with different ports
# (e.g., 8001 and 8002).
FLOW_SERVERS = {
    "flow1": "http://localhost:8001",
    "flow2": "http://localhost:8002",
}
32
+
33
+ # --- State Management Models ---
34
class TaskStatus(BaseModel):
    """Processing record kept for one course (one zip file)."""

    # One of: UNPROCESSED, IN_PROGRESS, COMPLETED, FAILED.
    status: str
    # Identifier of the flow server the task was handed to (flow1 or flow2).
    assigned_to: Optional[str] = None
    # Timestamps; left as None until the corresponding transition happens.
    assigned_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    # Failure description reported for the task, if any.
    error_message: Optional[str] = None
40
+
41
class ProcessingState(BaseModel):
    """Complete manager state; this is what gets serialized to the state file."""

    # Maps course name (the zip file name) to its processing record.
    tasks: Dict[str, TaskStatus] = {}

    # Maps each flow server id to the course it is currently working on;
    # None means the server has nothing assigned.
    flow_assignments: Dict[str, Optional[str]] = {
        "flow1": None,
        "flow2": None,
    }
50
+
51
class CompleteTaskRequest(BaseModel):
    """Request body a flow server posts to /task/complete."""

    # Id of the reporting flow server.
    flow_id: str
    # Course (task key) the report is about.
    course_name: str
    # True when processing finished successfully, False on failure.
    success: bool
    # Optional failure description; stored on the task when success is False.
    error_message: Optional[str] = None
56
+
57
+ # --- Global State and Initialization ---
58
# FastAPI application that exposes the manager's HTTP endpoints.
app = FastAPI(
    title="BG3 Processing Manager",
    description="Coordinates flow servers for BG3 dataset processing.",
    version="1.0.0",
)

# Hugging Face clients, both created with the same token.
api = HfApi(token=HF_TOKEN)
fs = HfFileSystem(token=HF_TOKEN)

# In-memory processing state; rebound when a saved state is loaded.
state = ProcessingState()

# Flipped to True when coordinate_loop() is entered.
is_coordinating = False
68
+
69
+ # --- Persistence Functions ---
70
+
71
def get_full_path(filename: str) -> str:
    """Return ``filename`` prefixed with the dataset repo id and a slash."""
    return "{}/{}".format(HF_DATASET_ID, filename)
73
+
74
async def load_state_from_hf():
    """
    Load the persisted processing state from the Hugging Face dataset repo.

    - If the state file exists, it is read, parsed, and installed as the
      module-level ``state``.
    - If it does not exist, tasks are discovered via initialize_tasks() and
      the fresh state is saved.
    - If checking, reading, or parsing fails, the error is logged and tasks
      are discovered via initialize_tasks() (the result is not saved here).

    Raises:
        RuntimeError: propagated from initialize_tasks() when discovery fails.
    """
    global state

    # HfFileSystem addresses dataset repos as "datasets/<repo_id>/..."; a bare
    # "<repo_id>/..." path is resolved against a model repo instead.
    state_path = f"datasets/{HF_DATASET_ID}/{STATE_FILE_PATH}"

    try:
        state_exists = fs.exists(state_path)
        if state_exists:
            print(f"Loading state from {STATE_FILE_PATH}...")

            # HfApi has no read_file(); read the file through the fsspec
            # interface of HfFileSystem instead.
            content = fs.read_text(state_path, encoding="utf-8")

            state = ProcessingState(**json.loads(content))
            print(f"State loaded. Total tasks: {len(state.tasks)}")
    except Exception as e:
        print(f"Error loading state from HF: {e}")
        # Fallback to initialization if loading fails
        await initialize_tasks()
        return

    if not state_exists:
        # Kept outside the try so a discovery failure is not retried by the
        # fallback handler above.
        print(f"State file {STATE_FILE_PATH} not found. Initializing.")
        await initialize_tasks()
        await save_state_to_hf()  # Save initial state
100
+
101
async def save_state_to_hf():
    """
    Serialize the in-memory state to JSON and upload it to the dataset repo.

    Any failure is logged and swallowed; nothing is raised to the caller.
    """
    try:
        print(f"Saving state to {STATE_FILE_PATH}...")
        serialized = state.model_dump_json(indent=2)
        payload = serialized.encode('utf-8')

        upload_kwargs = {
            "path_or_fileobj": payload,
            "path_in_repo": STATE_FILE_PATH,
            "repo_id": HF_DATASET_ID,
            "repo_type": HF_DATASET_REPO_TYPE,
            "commit_message": "Update processing state",
        }
        api.upload_file(**upload_kwargs)
        print("State saved successfully.")
    except Exception as exc:
        print(f"Error saving state to HF: {exc}")
116
+
117
async def initialize_tasks():
    """
    Discover course zip files in the frames folder and rebuild ``state.tasks``.

    Every ``*.zip`` entry found becomes a task keyed by the file name without
    its ``.zip`` suffix. Courses that are already tracked keep their existing
    TaskStatus; new ones start as UNPROCESSED. Courses that are no longer
    listed are dropped from the task map.

    Raises:
        RuntimeError: if listing the folder or rebuilding the task map fails.
    """
    global state
    print(f"Discovering zip files in {FRAMES_FOLDER}/...")

    try:
        # HfFileSystem addresses dataset repos as "datasets/<repo_id>/...".
        file_list = fs.ls(f"datasets/{HF_DATASET_ID}/{FRAMES_FOLDER}", detail=False)

        # Path.stem removes only the trailing ".zip", unlike str.replace,
        # which would also strip ".zip" occurring earlier in the name.
        course_names = [Path(f).stem for f in file_list if f.endswith(".zip")]

        state.tasks = {
            course_name: (
                # Keep existing status if it was already tracked
                state.tasks[course_name]
                if course_name in state.tasks
                else TaskStatus(status="UNPROCESSED")
            )
            for course_name in course_names
        }
        print(f"Found {len(course_names)} zip files. Total tasks: {len(state.tasks)}")

    except Exception as e:
        print(f"Error discovering files from HF: {e}")
        # If discovery fails, we can't proceed.
        raise RuntimeError(f"Failed to discover files: {e}") from e
147
+
148
+ # --- Core Coordination Logic ---
149
+
150
async def _revert_assignment(flow_id: str, course_name: Optional[str]) -> None:
    """Undo a failed hand-off: reset the course to UNPROCESSED, free the flow, save."""
    if not course_name:
        # Nothing was assigned, so there is nothing to roll back.
        return
    state.tasks[course_name] = TaskStatus(status="UNPROCESSED")
    state.flow_assignments[flow_id] = None
    await save_state_to_hf()


async def assign_next_task(flow_id: str):
    """
    Finds the next UNPROCESSED task and assigns it to the given flow server.

    The task is marked IN_PROGRESS and the state is saved before the flow
    server is notified via ``POST {flow_url}/process_course``. If the server
    answers with a non-200 status, or the request raises, the task is put
    back to UNPROCESSED and the flow's slot is freed.

    Args:
        flow_id: Key into FLOW_SERVERS. Unknown ids are logged and ignored
            without touching any state.
    """
    global state

    # Validate the target before mutating anything; otherwise a task would be
    # left IN_PROGRESS for a server that can never be notified.
    flow_url = FLOW_SERVERS.get(flow_id)
    if not flow_url:
        print(f"Error: Unknown flow_id {flow_id}")
        return

    # 1. Find an UNPROCESSED task (first one in dict order)
    next_course = next(
        (
            course_name
            for course_name, task_status in state.tasks.items()
            if task_status.status == "UNPROCESSED"
        ),
        None,
    )

    if next_course is None:
        print(f"No UNPROCESSED tasks left for {flow_id}.")
    else:
        # 2. Update state to IN_PROGRESS
        state.tasks[next_course] = TaskStatus(
            status="IN_PROGRESS",
            assigned_to=flow_id,
            assigned_at=datetime.now(),
        )
        state.flow_assignments[flow_id] = next_course

    # 3. Persist state change
    await save_state_to_hf()

    # 4. Notify the Flow Server
    # NOTE(review): when no task is left, the server is still notified with
    # course_name=None, as in the original code - confirm the flow servers
    # expect this.
    try:
        print(f"Assigning '{next_course}' to {flow_id} at {flow_url}/process_course")
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{flow_url}/process_course",
                json={"course_name": next_course},
            ) as response:
                if response.status != 200:
                    print(f"Error sending task to {flow_id}: {response.status} - {await response.text()}")
                    # Revert state if assignment fails
                    await _revert_assignment(flow_id, next_course)
                else:
                    print(f"Successfully assigned {next_course} to {flow_id}.")

    except aiohttp.ClientConnectorError as e:
        print(f"Connection Error: Could not connect to {flow_id} at {flow_url}. Reverting task status. Error: {e}")
        # Revert state if connection fails
        await _revert_assignment(flow_id, next_course)
    except Exception as e:
        print(f"Unexpected error during assignment to {flow_id}. Error: {e}")
        # Revert state for safety
        await _revert_assignment(flow_id, next_course)
217
+
218
+
219
async def coordinate_loop():
    """
    Start coordination: load state, then hand a task to every idle flow server.

    Runs once; afterwards further assignments are driven by /task/complete
    calls. ``is_coordinating`` stays True after a successful start and is
    reset to False if start-up fails, so that the start can be retried.
    """
    global is_coordinating
    if is_coordinating:
        print("Coordinator is already running.")
        return

    is_coordinating = True
    print("Starting coordination loop...")

    try:
        # Load state and initialize tasks on startup
        await load_state_from_hf()

        # Check and assign tasks to any free flow server
        idle_flows = [
            flow_id
            for flow_id in FLOW_SERVERS
            if state.flow_assignments.get(flow_id) is None
        ]
        # gather() runs the assignments concurrently and keeps them referenced
        # until they finish, unlike a bare create_task() whose result is dropped.
        await asyncio.gather(*(assign_next_task(flow_id) for flow_id in idle_flows))

    except Exception as e:
        print(f"Coordination loop failed to start: {e}")
        # Clear the flag so a later start attempt is not rejected as
        # "already running" after a failed start.
        is_coordinating = False
245
+
246
+ # --- API Endpoints ---
247
+
248
@app.on_event("startup")
async def startup_event():
    """Kick off coordinate_loop() in the background when the app starts."""
    # A BackgroundTasks object only runs its tasks when it is attached to a
    # response; one constructed here and discarded never executes anything.
    # Schedule the coroutine on the running event loop instead, and keep a
    # reference on app.state so the task is not garbage-collected mid-run.
    app.state.coordinator_task = asyncio.create_task(coordinate_loop())
252
+
253
@app.get("/")
async def root():
    """Return a summary: coordinator flag, flow assignments, and task counts."""
    statuses = [task.status for task in state.tasks.values()]
    summary = {
        "message": "BG3 Processing Manager API",
        "status": "running",
        "is_coordinating": is_coordinating,
        "flow_assignments": state.flow_assignments,
        "total_tasks": len(state.tasks),
    }
    # Per-status counters, appended in the same key order as before.
    summary["unprocessed"] = statuses.count("UNPROCESSED")
    summary["in_progress"] = statuses.count("IN_PROGRESS")
    summary["completed"] = statuses.count("COMPLETED")
    return summary
265
+
266
# Strong references to in-flight follow-up assignments. The event loop only
# keeps weak references to tasks, so an unreferenced task may be collected
# before it finishes.
_pending_assignments = set()


@app.post("/task/complete")
async def task_complete(request: CompleteTaskRequest):
    """
    Endpoint for flow servers to report task completion.

    Marks the course COMPLETED or FAILED, frees the reporting flow server's
    slot, saves the state, and schedules the next assignment for that server.

    Raises:
        HTTPException: 400 if ``flow_id`` is not a configured flow server,
            404 if ``course_name`` is not a known task.
    """
    global state
    flow_id = request.flow_id
    course_name = request.course_name

    # Reject unknown reporters up front so they cannot add stray keys to
    # flow_assignments or trigger an assignment for a non-existent server.
    if flow_id not in FLOW_SERVERS:
        raise HTTPException(status_code=400, detail=f"Unknown flow server: {flow_id}")

    if course_name not in state.tasks:
        raise HTTPException(status_code=404, detail=f"Unknown course: {course_name}")

    task = state.tasks[course_name]

    if task.assigned_to != flow_id:
        # This is a safety check, should not happen in normal operation
        print(f"Warning: {flow_id} reported completion for a task not assigned to it: {course_name}")

    if request.success:
        print(f"Task COMPLETED: {course_name} by {flow_id}")
        task.status = "COMPLETED"
        task.completed_at = datetime.now()
        task.error_message = None
    else:
        print(f"Task FAILED: {course_name} by {flow_id}. Error: {request.error_message}")
        # For now, mark as FAILED. A more robust system might retry.
        task.status = "FAILED"
        task.completed_at = datetime.now()
        task.error_message = request.error_message

    # Free up the flow server slot
    state.flow_assignments[flow_id] = None

    # Persist state change
    await save_state_to_hf()

    # Assign the next task to the now-free flow server without delaying the
    # response; the set holds the task until it is done.
    follow_up = asyncio.create_task(assign_next_task(flow_id))
    _pending_assignments.add(follow_up)
    follow_up.add_done_callback(_pending_assignments.discard)

    return {"status": "success", "message": f"Task {course_name} marked as {'COMPLETED' if request.success else 'FAILED'}. Next task assigned."}
306
+
307
@app.post("/start_coordination")
async def start_coordination(background_tasks: BackgroundTasks):
    """
    Manually trigger the coordination loop.

    Queues coordinate_loop() as a background task unless the coordinator
    flag is already set, in which case only an informational reply is sent.
    """
    if not is_coordinating:
        background_tasks.add_task(coordinate_loop)
        return {"status": "success", "message": "Coordination loop started."}

    return {"status": "info", "message": "Coordination is already running."}
317
+
318
@app.get("/state")
async def get_state():
    """Return the module-level ProcessingState object as the response."""
    current_state = state
    return current_state
324
+
325
if __name__ == "__main__":
    # Note: When running in the sandbox, we need to use 0.0.0.0 to expose the port.
    bind_host = "0.0.0.0"
    uvicorn.run(app, host=bind_host, port=MANAGER_PORT)