File size: 29,770 Bytes
ef25d3a
 
 
 
 
 
a303362
 
 
ef25d3a
 
 
 
a303362
 
 
ef25d3a
a303362
ef25d3a
 
 
 
 
 
 
 
 
 
a303362
ef25d3a
a303362
ef25d3a
bdff23f
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdff23f
ef25d3a
 
 
a303362
ef25d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
 
a303362
 
 
 
 
 
 
bdff23f
 
 
a303362
 
 
bdff23f
 
 
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a303362
ef25d3a
a303362
 
 
ef25d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a303362
 
ef25d3a
a303362
 
 
ef25d3a
a303362
ef25d3a
a303362
ef25d3a
 
a303362
ef25d3a
 
 
 
 
 
a303362
 
 
ef25d3a
 
 
 
 
 
a303362
 
 
 
 
ef25d3a
 
a303362
 
ef25d3a
 
 
 
 
 
 
 
 
 
a303362
ef25d3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a303362
 
 
ef25d3a
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
 
a303362
ef25d3a
 
a303362
ef25d3a
a303362
 
ef25d3a
a303362
 
ef25d3a
a303362
ef25d3a
 
 
 
 
a303362
ef25d3a
a303362
ef25d3a
 
 
 
a303362
 
 
 
 
ef25d3a
 
 
 
 
a303362
ef25d3a
 
 
a303362
ef25d3a
 
a303362
ef25d3a
 
a303362
ef25d3a
a303362
ef25d3a
 
 
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
a303362
 
 
 
 
 
 
 
ef25d3a
a303362
 
 
 
 
 
ef25d3a
 
 
a303362
 
 
 
 
 
ef25d3a
 
a303362
ef25d3a
 
 
 
a303362
 
 
ef25d3a
 
 
 
 
 
a303362
ef25d3a
 
a303362
 
ef25d3a
 
 
 
a303362
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
 
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
a303362
 
 
 
ef25d3a
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
a303362
 
 
 
ef25d3a
a303362
 
ef25d3a
a303362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef25d3a
 
bf4e87d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
import os
import json
import time
import asyncio
import aiohttp
import zipfile
import io
import shutil
from typing import Dict, List, Set, Optional, Any
from urllib.parse import quote
from datetime import datetime
from pathlib import Path

from fastapi import FastAPI, BackgroundTasks, HTTPException, status, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from huggingface_hub import HfApi, hf_hub_download, HfFileSystem
import uvicorn

# --- Configuration ---
# Settings are read once, at import time. Values obtained through os.getenv can be
# overridden with the environment variable of the same name; the remaining
# constants are derived from them.
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
# os.getenv yields strings, so the fallback is written as a string too.
FLOW_PORT = int(os.getenv("FLOW_PORT", "8001"))
MANAGER_URL = os.getenv("MANAGER_URL", "https://fred808-fcord.hf.space")
MANAGER_COMPLETE_TASK_URL = f"{MANAGER_URL}/task/complete"
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3")
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "fred808/helium")
STATE_FILE_NAME = f"{FLOW_ID}_state.json"

# Caption endpoints that images are POSTed to. Every entry must include the
# "/analyze" path: the URL is used verbatim as the request target.
CAPTION_SERVERS = [
    "https://fred808-pil-4-1.hf.space/analyze",
    "https://fred808-pil-4-2.hf.space/analyze",
    "https://fred808-pil-4-3.hf.space/analyze",
    "https://fred1012-fred1012-gw0j2h.hf.space/analyze",
    "https://fred1012-fred1012-wqs6c2.hf.space/analyze",
    "https://fred1012-fred1012-oncray.hf.space/analyze",
    "https://fred1012-fred1012-4goge7.hf.space/analyze",
    "https://fred1012-fred1012-z0eh7m.hf.space/analyze",
    "https://fred1012-fred1012-u95rte.hf.space/analyze",
    "https://fred1012-fred1012-igje22.hf.space/analyze",
    "https://fred1012-fred1012-ibkuf8.hf.space/analyze",
    "https://fred1012-fred1012-nwqthy.hf.space/analyze",
    "https://fred1012-fred1012-4ldqj4.hf.space/analyze",
    "https://fred1012-fred1012-pivlzg.hf.space/analyze",
    "https://fred1012-fred1012-ptlc5u.hf.space/analyze",
    "https://fred1012-fred1012-u7lh57.hf.space/analyze",
    "https://fred1012-fred1012-q8djv1.hf.space/analyze",
    "https://fredalone-fredalone-ozugrp.hf.space/analyze",
    "https://fredalone-fredalone-9brxj2.hf.space/analyze",
    "https://fredalone-fredalone-p8vq9a.hf.space/analyze",
    "https://fredalone-fredalone-vbli2y.hf.space/analyze",
    "https://fredalone-fredalone-uggger.hf.space/analyze",
    "https://fredalone-fredalone-nmi7e8.hf.space/analyze",
    "https://fredalone-fredalone-d1f26d.hf.space/analyze",
    "https://fredalone-fredalone-461jp2.hf.space/analyze",
    "https://fredalone-fredalone-3enfg4.hf.space/analyze",
    "https://fredalone-fredalone-dqdbpv.hf.space/analyze",
    "https://fredalone-fredalone-ivtjua.hf.space/analyze",
    "https://fredalone-fredalone-6bezt2.hf.space/analyze",
    "https://fredalone-fredalone-e0wfnk.hf.space/analyze",
    "https://fredalone-fredalone-zu2t7j.hf.space/analyze",
    "https://fredalone-fredalone-dqtv1o.hf.space/analyze",
    "https://fredalone-fredalone-wclyog.hf.space/analyze",
    "https://fredalone-fredalone-t27vig.hf.space/analyze",
    "https://fredalone-fredalone-gahbxh.hf.space/analyze",
    "https://fredalone-fredalone-kw2po4.hf.space/analyze",
    # This entry was previously missing the "/analyze" path used by all the others.
    "https://fredalone-fredalone-8h285h.hf.space/analyze",
]
# Value sent as the "model_choice" form field with every caption request.
MODEL_TYPE = "Florence-2-large"

# Temporary storage for images (created on import if it does not exist yet).
TEMP_DIR = Path(f"temp_images_{FLOW_ID}")
TEMP_DIR.mkdir(parents=True, exist_ok=True)

# --- Models ---
class ProcessCourseRequest(BaseModel):
    """Request payload carrying an optional course name."""

    # Omitted or null in the payload -> None.
    course_name: Optional[str] = Field(default=None)

class CaptionServer:
    """Bookkeeping record for a single remote captioning endpoint."""

    def __init__(self, url):
        # A new record starts idle with zeroed counters.
        self.url, self.busy = url, False
        self.total_processed, self.total_time = 0, 0
        self.model = MODEL_TYPE

    @property
    def fps(self):
        """Processed count divided by accumulated time; 0 while no time is recorded."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0

class ServerState(BaseModel):
    """Progress record for this flow server.

    The model is dumped to JSON when the state is saved and rebuilt from that
    JSON when the state is loaded.
    """

    # The list of all zip files in the dataset (frames/ directory)
    all_zip_files: List[str] = Field(default_factory=list)
    # The set of zip files that have been successfully processed and uploaded
    processed_files: Set[str] = Field(default_factory=set)
    # The index in all_zip_files from which the next download should start
    current_index: int = 0
    # Total number of files to process
    total_files: int = 0
    # Status of the current operation
    status: str = "Idle"
    # Name of the file currently being processed
    current_file: Optional[str] = None
    # Progress within the current file
    current_file_progress: str = "0/0"
    # Timestamp of the last update. A default_factory is required here: a plain
    # `datetime.now().isoformat()` default is evaluated once, when the class body
    # runs, so every later ServerState() would carry the import-time timestamp.
    last_update: str = Field(default_factory=lambda: datetime.now().isoformat())
    # Flag to control the processing loop
    is_running: bool = False

# Module-wide mutable state shared by the processing loop and the HTTP handlers.
# One CaptionServer record per configured endpoint URL.
servers = list(map(CaptionServer, CAPTION_SERVERS))
# Position in `servers` where the next round-robin scan starts.
server_index = 0
# Current progress record; rebound when a saved state is loaded.
state = ServerState()
# asyncio lock serialising access to `state` between coroutines. It is an
# asyncio.Lock, so it coordinates coroutines on one event loop, not OS threads.
state_lock = asyncio.Lock()

# --- Persistence Functions ---

def get_hf_api():
    """Build an HfApi client that uses HF_TOKEN.

    Raises:
        ValueError: if HF_TOKEN is empty.
    """
    if HF_TOKEN:
        return HfApi(token=HF_TOKEN)
    raise ValueError("HF_TOKEN environment variable is not set. Cannot access Hugging Face.")

def get_hf_fs():
    """Build an HfFileSystem that uses HF_TOKEN.

    Raises:
        ValueError: if HF_TOKEN is empty.
    """
    if HF_TOKEN:
        return HfFileSystem(token=HF_TOKEN)
    raise ValueError("HF_TOKEN environment variable is not set. Cannot access Hugging Face.")

async def load_state_from_hf():
    """Restore the global ``state`` from the JSON state file in the output dataset.

    If the file does not exist the current ``state`` is left untouched. If reading
    or parsing fails, ``state`` is reset to a fresh ``ServerState()``.

    Raises:
        ValueError: propagated from get_hf_fs() when HF_TOKEN is not set.
    """
    global state
    fs = get_hf_fs()
    # HfFileSystem addresses dataset repositories as "datasets/<repo_id>/<path>".
    # Without that prefix the path is resolved against a *model* repo of the same
    # name, so the state file (uploaded with repo_type="dataset") was never found.
    state_path = f"datasets/{HF_OUTPUT_DATASET_ID}/{STATE_FILE_NAME}"

    async with state_lock:
        try:
            if not fs.exists(state_path):
                print(f"[{FLOW_ID}] State file {state_path} not found. Starting with default state.")
                return

            print(f"[{FLOW_ID}] Loading state from {state_path}...")
            with fs.open(state_path, 'rb') as f:
                data = json.load(f)

            # Convert list of processed files back to a set
            if 'processed_files' in data and isinstance(data['processed_files'], list):
                data['processed_files'] = set(data['processed_files'])

            state = ServerState.parse_obj(data)
            print(f"[{FLOW_ID}] State loaded successfully. Current index: {state.current_index}")
        except Exception as e:
            # Best effort: any failure (network, bad JSON, validation) falls back
            # to an empty state rather than preventing startup.
            print(f"[{FLOW_ID}] Error loading state from HF: {e}. Starting with default state.")
            state = ServerState()

async def save_state_to_hf():
    """Upload the current ``state`` as a JSON file to the output dataset.

    Locking: this coroutine deliberately does NOT acquire ``state_lock``. Every
    call site in this module awaits it from inside an ``async with state_lock:``
    block, and ``asyncio.Lock`` is not reentrant, so acquiring it again here
    blocked forever on the first save. The body contains no ``await``, so no other
    coroutine can run between taking the snapshot and finishing the upload.

    Returns:
        True if the upload succeeded, False if it raised.

    Raises:
        ValueError: propagated from get_hf_api() when HF_TOKEN is not set.
    """
    api = get_hf_api()
    state_path = STATE_FILE_NAME

    state.last_update = datetime.now().isoformat()
    # Sets are not JSON-serialisable; store a sorted list so the file content is
    # deterministic for a given set of processed files.
    data_to_save = state.dict()
    data_to_save['processed_files'] = sorted(state.processed_files)

    json_content = json.dumps(data_to_save, indent=2, ensure_ascii=False).encode('utf-8')

    try:
        print(f"[{FLOW_ID}] Saving state to {state_path} in {HF_OUTPUT_DATASET_ID}...")
        api.upload_file(
            path_or_fileobj=io.BytesIO(json_content),
            path_in_repo=state_path,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"[{FLOW_ID}] Update server state. Index: {state.current_index}"
        )
        print(f"[{FLOW_ID}] State saved successfully.")
        return True
    except Exception as e:
        # Saving is best effort: report the failure to the caller instead of raising.
        print(f"[{FLOW_ID}] Error saving state to HF: {e}")
        return False

async def update_file_list():
    """Refresh ``state.all_zip_files`` / ``state.total_files`` from the source dataset.

    Only repository paths under ``frames/`` that end in ``.zip`` are kept, in
    sorted order. A failure is recorded in ``state.status``; the state is saved
    either way.

    Raises:
        ValueError: propagated from get_hf_api() when HF_TOKEN is not set.
    """
    api = get_hf_api()

    async with state_lock:
        try:
            state.status = "Updating file list..."
            print(f"[{FLOW_ID}] Fetching file list from {HF_DATASET_ID}...")
            listing = api.list_repo_files(repo_id=HF_DATASET_ID, repo_type="dataset")

            # Filter for zip files in the 'frames/' directory
            frame_archives = sorted(
                name for name in listing
                if name.endswith('.zip') and name.startswith("frames/")
            )

            state.all_zip_files = frame_archives
            state.total_files = len(frame_archives)
            state.status = "File list updated."
            print(f"[{FLOW_ID}] Found {state.total_files} zip files.")
        except Exception as err:
            state.status = f"Error updating file list: {err}"
            print(f"[{FLOW_ID}] Error updating file list: {err}")

        # NOTE(review): awaited while state_lock is held, so save_state_to_hf()
        # must not try to acquire state_lock itself.
        await save_state_to_hf()

# --- Core Processing Functions (Modified) ---

async def get_available_server(timeout: float = 300.0) -> CaptionServer:
    """Return the next caption server whose ``busy`` flag is clear.

    Scans ``servers`` round-robin starting at the global ``server_index``. If a
    full pass finds every server busy, sleeps 0.5 s and scans again.

    Args:
        timeout: seconds to keep waiting before giving up.

    Raises:
        TimeoutError: if no server became available within ``timeout`` seconds.
    """
    global server_index
    began_waiting = time.time()
    while True:
        scans_left = len(servers)
        while scans_left:
            candidate = servers[server_index]
            # Advance the cursor regardless of the outcome so successive calls
            # start from different servers.
            server_index = (server_index + 1) % len(servers)
            if candidate.busy:
                scans_left -= 1
                continue
            return candidate

        await asyncio.sleep(0.5)

        waited = time.time() - began_waiting
        if waited > timeout:
            raise TimeoutError(f"Timeout ({timeout}s) waiting for an available caption server.")

async def send_image_for_captioning(image_path: Path, course_name: str, progress_tracker: Dict) -> Optional[Dict]:
    """Send one image to a caption server, retrying up to three times.

    Each attempt picks a free server, marks it busy, POSTs the image as multipart
    form data together with ``model_choice`` and expects a JSON body with a
    ``caption`` key. On success the shared ``progress_tracker`` and
    ``state.current_file_progress`` are updated.

    Args:
        image_path: image file to upload.
        course_name: copied into the returned record under ``"course"``.
        progress_tracker: dict with integer ``'completed'`` and ``'total'`` keys,
            shared between all concurrent calls for the same archive.

    Returns:
        A dict with ``course``, ``image_path`` (file name only), ``caption`` and
        ``timestamp``; or None when every attempt failed.
    """
    import mimetypes

    MAX_RETRIES = 3
    # Declare the real media type; fall back to the previous fixed value.
    content_type = mimetypes.guess_type(image_path.name)[0] or 'image/jpeg'
    request_timeout = aiohttp.ClientTimeout(total=600)
    # Read lazily on the first attempt and reuse for retries. Passing bytes
    # (instead of an open file object) means no file handle is left unclosed.
    image_bytes = None

    for attempt in range(MAX_RETRIES):
        server = None
        start_time = None
        succeeded = False
        try:
            server = await get_available_server()
            server.busy = True
            start_time = time.time()

            if attempt == 0:
                print(f"[{FLOW_ID}] Starting attempt on {image_path.name}...")

            if image_bytes is None:
                image_bytes = image_path.read_bytes()

            form_data = aiohttp.FormData()
            form_data.add_field('file',
                                image_bytes,
                                filename=image_path.name,
                                content_type=content_type)
            form_data.add_field('model_choice', MODEL_TYPE)

            async with aiohttp.ClientSession() as session:
                async with session.post(server.url, data=form_data, timeout=request_timeout) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        print(f"[{FLOW_ID}] Error from server {server.url} for {image_path.name}: {resp.status} - {error_text}. Retrying...")
                        continue

                    result = await resp.json()
                    caption = result.get("caption")
                    if not caption:
                        print(f"[{FLOW_ID}] Server {server.url} returned success but no caption for {image_path.name}. Retrying...")
                        continue

                    succeeded = True
                    # Update progress counter and global state
                    progress_tracker['completed'] += 1
                    async with state_lock:
                        state.current_file_progress = f"{progress_tracker['completed']}/{progress_tracker['total']}"

                    if progress_tracker['completed'] % 50 == 0:
                        print(f"[{FLOW_ID}] PROGRESS: {progress_tracker['completed']}/{progress_tracker['total']} captions completed.")

                    return {
                        "course": course_name,
                        "image_path": image_path.name,
                        "caption": caption,
                        "timestamp": datetime.now().isoformat()
                    }

        except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError) as e:
            print(f"[{FLOW_ID}] Connection/Timeout error for {image_path.name} on {server.url if server else 'unknown server'}: {e}. Retrying...")
            continue
        except Exception as e:
            print(f"[{FLOW_ID}] Unexpected error during captioning for {image_path.name}: {e}. Retrying...")
            continue
        finally:
            # Runs on return and on every `continue`: always release the server.
            if server is not None:
                server.busy = False
                server.total_time += time.time() - start_time
                # Only successful captions count as processed; failed attempts
                # still add their elapsed time, so fps is not inflated by errors.
                if succeeded:
                    server.total_processed += 1

    print(f"[{FLOW_ID}] FAILED after {MAX_RETRIES} attempts for {image_path.name}.")
    return None

def _remove_downloaded_file(path: str) -> None:
    """Best-effort removal of a downloaded file and, if it is a symlink, its target."""
    link = Path(path)
    # hf_hub_download may return a symlink pointing at the real data file in its
    # cache; removing only the link would leave the data on disk. Resolve first,
    # then delete both (the second unlink is a no-op when they are the same file).
    target = link.resolve()
    for candidate in (link, target):
        try:
            candidate.unlink()
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"[{FLOW_ID}] Could not remove downloaded file {candidate}: {e}")


async def download_and_extract_zip(repo_file_full_path: str) -> Optional[tuple[Path, str]]:
    """Download one zip from the source dataset and extract it under TEMP_DIR.

    The blocking download and extraction run in worker threads so the event loop
    stays responsive. The downloaded archive is removed afterwards whether or not
    extraction succeeded; a partially extracted directory is removed on failure.

    Args:
        repo_file_full_path: path of the zip inside the dataset repository.

    Returns:
        ``(extract_dir, zip_file_name)`` on success, or None if the download or
        the extraction failed (the error is printed, not raised).
    """
    zip_full_name = Path(repo_file_full_path).name
    course_name = zip_full_name.split('_')[0]  # Assuming course name is the prefix before the first underscore
    extract_dir = TEMP_DIR / course_name / zip_full_name.replace('.', '_')
    zip_path = None

    try:
        print(f"[{FLOW_ID}] Downloading file: {repo_file_full_path}. Full name: {zip_full_name}")

        # Use hf_hub_download to get the file path
        zip_path = await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_DATASET_ID,
            filename=repo_file_full_path,  # Use the full path in the repo
            repo_type="dataset",
            token=HF_TOKEN,
        )

        print(f"[{FLOW_ID}] Downloaded to {zip_path}. Extracting...")

        # Start from an empty directory so leftovers from an earlier, interrupted
        # run are not picked up as images of this archive.
        shutil.rmtree(extract_dir, ignore_errors=True)
        extract_dir.mkdir(parents=True, exist_ok=True)

        def _extract_all() -> None:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)

        await asyncio.to_thread(_extract_all)

        print(f"[{FLOW_ID}] Extraction complete to {extract_dir}.")

        # Return the extraction directory and the full zip file name
        return extract_dir, zip_full_name

    except Exception as e:
        print(f"[{FLOW_ID}] Error downloading or extracting zip for {repo_file_full_path}: {e}")
        # The caller only cleans up directories it was handed, so remove any
        # partial extraction here.
        shutil.rmtree(extract_dir, ignore_errors=True)
        return None

    finally:
        # Clean up the downloaded zip file
        if zip_path is not None:
            _remove_downloaded_file(zip_path)

async def upload_captions_to_hf(zip_full_name: str, captions: List[Dict]) -> bool:
    """Upload ``captions`` as a JSON file to the output dataset.

    The file is named after the zip with its suffix replaced by ``.json`` and is
    placed at the repository root.

    Returns:
        True when the upload went through, False when anything raised.
    """
    target_name = Path(zip_full_name).with_suffix('.json').name

    try:
        print(f"[{FLOW_ID}] Uploading {len(captions)} captions for {zip_full_name} as {target_name} to {HF_OUTPUT_DATASET_ID}...")

        serialized = json.dumps(captions, indent=2, ensure_ascii=False)
        buffer = io.BytesIO(serialized.encode('utf-8'))

        get_hf_api().upload_file(
            path_or_fileobj=buffer,
            path_in_repo=target_name,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"[{FLOW_ID}] Captions for {zip_full_name}"
        )

        print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
        return True

    except Exception as err:
        # Covers a missing token as well as upload failures.
        print(f"[{FLOW_ID}] Error uploading captions for {zip_full_name}: {err}")
        return False

async def _mark_file_done(repo_file_full_path: str) -> None:
    """Record a file as processed, advance the index, reset progress and persist."""
    async with state_lock:
        state.processed_files.add(repo_file_full_path)
        state.current_index += 1
        state.current_file = None
        state.current_file_progress = "0/0"
        state.status = "Idle"
        # NOTE(review): awaited while state_lock is held, so save_state_to_hf()
        # must not try to acquire state_lock itself.
        await save_state_to_hf()


async def process_next_file_task():
    """Process zip files one after another, starting at ``state.current_index``.

    For each file: download and extract it, caption every .jpg/.jpeg/.png image
    concurrently (at most one in-flight request per caption server), upload the
    captions, then mark the file processed and advance the index. The loop ends
    when ``state.is_running`` becomes False or the end of the list is reached.

    A file is advanced past when its captions were uploaded (even if only some
    images were captioned) or when the archive contains no images. In every
    other case the index is left alone and the same file is retried after a
    60-second pause.
    """
    if not state.is_running:
        print(f"[{FLOW_ID}] Processing loop is not running. Exiting task.")
        return

    while state.is_running:
        async with state_lock:
            current_index = state.current_index
            if current_index >= state.total_files:
                state.status = "Finished processing all files."
                state.is_running = False
                print(f"[{FLOW_ID}] Reached end of file list. Stopping processing.")
                await save_state_to_hf()
                break

            repo_file_full_path = state.all_zip_files[current_index]
            file_name = Path(repo_file_full_path).name

            if repo_file_full_path in state.processed_files:
                state.current_index += 1
                state.status = f"Skipping processed file: {file_name}"
                state.current_file = file_name
                print(f"[{FLOW_ID}] Skipping already processed file: {repo_file_full_path}")
                await save_state_to_hf()
                continue

            # Mark the file as in-progress in the state
            state.status = f"Processing file {current_index + 1}/{state.total_files}"
            state.current_file = file_name
            state.current_file_progress = "0/0"
            await save_state_to_hf()

        # --- Start Processing ---
        extract_dir = None

        try:
            download_result = await download_and_extract_zip(repo_file_full_path)
            if download_result is None:
                raise Exception("Failed to download or extract zip file.")

            extract_dir, zip_full_name = download_result
            course_name = zip_full_name.split('_')[0]

            # Find images
            image_paths = [p for p in extract_dir.glob("**/*") if p.is_file() and p.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            print(f"[{FLOW_ID}] Found {len(image_paths)} images to process in {zip_full_name}.")

            if not image_paths:
                # Nothing to caption. The file must still be advanced past;
                # previously it was only logged as complete, so the loop picked
                # the same archive again forever.
                print(f"[{FLOW_ID}] No images found in {zip_full_name}. Marking as complete.")
                await _mark_file_done(repo_file_full_path)
                continue  # the `finally` below still runs

            # Initialize progress tracker
            progress_tracker = {
                'total': len(image_paths),
                'completed': 0
            }
            async with state_lock:
                state.current_file_progress = f"0/{len(image_paths)}"
                await save_state_to_hf()

            # Create and run captioning tasks, bounded by the number of servers
            semaphore = asyncio.Semaphore(len(servers))

            async def limited_send_image_for_captioning(image_path):
                async with semaphore:
                    return await send_image_for_captioning(image_path, course_name, progress_tracker)

            results = await asyncio.gather(*(limited_send_image_for_captioning(p) for p in image_paths))
            all_captions = [r for r in results if r is not None]

            # Final progress report
            global_success = len(all_captions) == len(image_paths)
            if global_success:
                print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Successfully completed all {len(all_captions)} captions.")
            else:
                print(f"[{FLOW_ID}] FINAL PROGRESS for {zip_full_name}: Completed with partial result: {len(all_captions)}/{len(image_paths)} captions.")

            if not all_captions:
                print(f"[{FLOW_ID}] No captions generated or zip_full_name is missing. Skipping upload for {zip_full_name}. Will retry later.")
                # Do NOT increment index or mark as processed
                async with state_lock:
                    state.status = f"No captions generated for {zip_full_name}. Retrying later."
                    await save_state_to_hf()
                await asyncio.sleep(60)
                continue

            # Upload results
            if not await upload_captions_to_hf(zip_full_name, all_captions):
                print(f"[{FLOW_ID}] Failed to upload captions for {zip_full_name}. Will retry this file later.")
                # Do NOT increment index or mark as processed, so it will be retried
                async with state_lock:
                    state.status = f"Error uploading captions for {zip_full_name}. Retrying later."
                    await save_state_to_hf()
                # Wait before retrying to avoid immediate re-attempt on a transient error
                await asyncio.sleep(60)
                continue

            print(f"[{FLOW_ID}] Successfully uploaded captions for {zip_full_name}.")
            # Once the upload succeeded the file counts as processed, even with a
            # partial result; the uploaded JSON reflects the actual caption count.
            if global_success:
                print(f"[{FLOW_ID}] Fully processed and uploaded: {zip_full_name}")
            else:
                print(f"[{FLOW_ID}] Partially processed but uploaded: {zip_full_name}. Needs manual review.")

            await _mark_file_done(repo_file_full_path)

        except Exception as e:
            error_message = str(e)
            print(f"[{FLOW_ID}] Critical error in process_next_file_task for {repo_file_full_path}: {error_message}")
            async with state_lock:
                state.status = f"CRITICAL ERROR for {Path(repo_file_full_path).name}. Retrying later. Error: {error_message[:50]}..."
                await save_state_to_hf()
            # Wait before retrying
            await asyncio.sleep(60)

        finally:
            # Cleanup temporary files (report only after the removal was attempted)
            if extract_dir and extract_dir.exists():
                shutil.rmtree(extract_dir, ignore_errors=True)
                print(f"[{FLOW_ID}] Cleaned up temporary directory {extract_dir}.")

            # If the loop is still running, wait a short time before checking for the next file
            if state.is_running:
                await asyncio.sleep(5)

# --- FastAPI App and Endpoints ---

# Application instance that the route and event decorators below attach to.
app = FastAPI(
    version="2.0.0",
    title=f"Flow Server {FLOW_ID} API",
    description="Fetches, extracts, and captions images for a given course.",
)

# Jinja2 templates for the UI, loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")

@app.on_event("startup")
async def startup_event():
    """Restore state, refresh the file list and start the processing loop.

    The loop is started only when ``state.current_index`` still points inside
    the file list; otherwise the server stays idle.
    """
    print(f"Flow Server {FLOW_ID} started on port {FLOW_PORT}. Manager URL: {MANAGER_URL}")
    # 1. Load state from persistence (HF)
    await load_state_from_hf()
    # 2. Update the list of all files from the dataset
    await update_file_list()
    # 3. Start the continuous processing task if the index is valid
    if state.current_index < state.total_files:
        state.is_running = True
        # A BackgroundTasks instance created by hand is never executed: its queue
        # only runs when FastAPI attaches it to a response. Outside a request the
        # coroutine has to be scheduled on the event loop directly. The task is
        # stored on app.state so a reference to it is kept while it runs.
        app.state.processing_task = asyncio.create_task(process_next_file_task())
    else:
        state.is_running = False
        print(f"[{FLOW_ID}] Index {state.current_index} is out of bounds. Starting in Idle mode.")


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the ``index.html`` status page from a snapshot of the current state."""
    async with state_lock:
        done = len(state.processed_files)

        # Per-server rows plus running totals for the overall rate.
        per_server = []
        processed_sum = 0
        seconds_sum = 0
        for srv in servers:
            per_server.append({
                "url": srv.url,
                "busy": srv.busy,
                "processed": srv.total_processed,
                "fps": f"{srv.fps:.2f}"
            })
            processed_sum += srv.total_processed
            seconds_sum += srv.total_time

        if seconds_sum > 0:
            overall_fps = processed_sum / seconds_sum
        else:
            overall_fps = 0

        context = dict(
            request=request,
            flow_id=FLOW_ID,
            status=state.status,
            is_running=state.is_running,
            total_files=state.total_files,
            processed_count=done,
            remaining_count=state.total_files - done,
            current_index=state.current_index,
            # The page shows "N/A" while no file is being processed.
            current_file=state.current_file if state.current_file else "N/A",
            current_file_progress=state.current_file_progress,
            last_update=state.last_update,
            overall_fps=f"{overall_fps:.2f}",
            server_stats=per_server,
        )
    return templates.TemplateResponse("index.html", context)

@app.post("/set_index")
async def set_index(request: Request, background_tasks: BackgroundTasks):
    """Set ``state.current_index`` from the ``start_index`` form field.

    An index inside the list (re)starts processing if it is not running; an
    index equal to the list length stops processing.

    Raises:
        HTTPException: 400 when the value is not an integer or is outside
            ``0..total_files``.
    """
    form = await request.form()
    try:
        new_index = int(form.get("start_index"))
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="Invalid index value.")

    async with state_lock:
        if new_index < 0 or new_index > state.total_files:
            raise HTTPException(status_code=400, detail=f"Index {new_index} is out of bounds (0 to {state.total_files}).")

        state.current_index = new_index

        if new_index == state.total_files:
            # Pointing one past the last file means "nothing left to do".
            state.is_running = False
            state.status = "Finished processing all files."
            await save_state_to_hf()
            return {"status": "success", "message": "Index set to end of list. Processing stopped."}

        state.status = f"Index set to {new_index}. Restarting processing."

        # If the loop is not running, start it
        if not state.is_running:
            state.is_running = True
            background_tasks.add_task(process_next_file_task)

        # NOTE(review): awaited while state_lock is held, so save_state_to_hf()
        # must not try to acquire state_lock itself.
        await save_state_to_hf()
        print(f"[{FLOW_ID}] Index manually set to {new_index}.")
        return {"status": "success", "message": f"Start index set to {new_index}. Processing will resume from this point."}

@app.post("/control_processing")
async def control_processing(request: Request, background_tasks: BackgroundTasks):
    """Start or stop the processing loop according to the ``action`` form field.

    Returns:
        A dict with ``status`` (``success``, ``info`` or ``error``) and a
        human-readable ``message``.

    Raises:
        HTTPException: 400 when ``action`` is neither ``start`` nor ``stop``.
    """
    form = await request.form()
    action = form.get("action")

    async with state_lock:
        if action not in ("start", "stop"):
            raise HTTPException(status_code=400, detail="Invalid action.")

        if action == "stop":
            if not state.is_running:
                return {"status": "info", "message": "Processing is already stopped."}
            state.is_running = False
            state.status = "Processing stopped by user."
            await save_state_to_hf()
            return {"status": "success", "message": "Processing loop stopped."}

        # action == "start" from here on.
        if state.current_index >= state.total_files:
            return {"status": "error", "message": "Cannot start. All files have been processed."}
        if state.is_running:
            return {"status": "info", "message": "Processing is already running."}

        state.is_running = True
        state.status = "Processing started."
        background_tasks.add_task(process_next_file_task)
        # NOTE(review): awaited while state_lock is held, so save_state_to_hf()
        # must not try to acquire state_lock itself.
        await save_state_to_hf()
        return {"status": "success", "message": "Processing loop started."}

@app.get("/status")
async def get_status():
    """Return a JSON snapshot of the processing state and per-server statistics."""
    async with state_lock:
        done = len(state.processed_files)

        # Per-server rows plus running totals for the overall rate.
        per_server = []
        processed_sum = 0
        seconds_sum = 0
        for srv in servers:
            per_server.append({
                "url": srv.url,
                "busy": srv.busy,
                "processed": srv.total_processed,
                "fps": f"{srv.fps:.2f}"
            })
            processed_sum += srv.total_processed
            seconds_sum += srv.total_time

        if seconds_sum > 0:
            overall_fps = processed_sum / seconds_sum
        else:
            overall_fps = 0

        return dict(
            flow_id=FLOW_ID,
            status=state.status,
            is_running=state.is_running,
            total_files=state.total_files,
            processed_count=done,
            remaining_count=state.total_files - done,
            current_index=state.current_index,
            current_file=state.current_file,
            current_file_progress=state.current_file_progress,
            last_update=state.last_update,
            overall_fps=f"{overall_fps:.2f}",
            server_stats=per_server,
        )

# The original /process_course endpoint is now obsolete as the server manages its own queue
# @app.post("/process_course")
# async def process_course(request: ProcessCourseRequest, background_tasks: BackgroundTasks):
#     return {"status": "obsolete", "message": "The server now manages its own processing queue based on the index."}


if __name__ == "__main__":
    # Serve on the configured FLOW_PORT (default 8001) instead of a hard-coded
    # port, so the port printed at startup is the one actually listened on.
    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)