Fred808 commited on
Commit
ae693ae
·
verified ·
1 Parent(s): 37ae538

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -63
app.py CHANGED
@@ -5,6 +5,8 @@ import os
5
  import threading
6
  import time
7
  import urllib.parse
 
 
8
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
9
  from fastapi.responses import JSONResponse
10
  import json
@@ -24,11 +26,44 @@ app = FastAPI(
24
 
25
  # --- Environment Configuration ---
26
  HF_TOKEN = os.getenv("HF_TOKEN", "")
27
- HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/data") # Dataset to store results
 
28
  HF_STATE_FILE = os.getenv("HF_STATE_FILE", "processing_state_cursors.json")
29
  TEMP_DATASET_DIR = Path("temp_cursor_detection")
30
  TEMP_DATASET_DIR.mkdir(exist_ok=True)
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Global variable to store loaded templates
33
  CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}
34
  CURSOR_TEMPLATES_DIR = Path("cursors")
@@ -201,35 +236,25 @@ def _unlock_file_as_processed(image_name: str, state: dict, next_index: int) ->
201
  state['next_download_index'] = next_index
202
  return _upload_hf_state(state)
203
 
204
- def _get_image_list_from_hf() -> list:
205
- """Return sorted list of image file paths from HF_DATASET_ID."""
206
- try:
207
- api = HfApi(token=HF_TOKEN)
208
- files = api.list_repo_files(repo_id=HF_DATASET_ID, repo_type="dataset")
209
- image_files = sorted([f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
210
- print(f"[DATASET] Found {len(image_files)} image files in {HF_DATASET_ID}.")
211
- return image_files
212
- except Exception as e:
213
- print(f"[DATASET] Error listing HF dataset files: {e}")
214
- return []
215
 
216
- def _upload_cursor_results(image_name: str, results: dict) -> bool:
217
- """Upload cursor detection results JSON to dataset."""
 
218
  try:
219
- filename = Path(image_name).with_suffix('.json').name
220
  content = json.dumps(results, indent=2, ensure_ascii=False).encode('utf-8')
221
  api = HfApi(token=HF_TOKEN)
222
  api.upload_file(
223
  path_or_fileobj=io.BytesIO(content),
224
  path_in_repo=f"cursor_results/{filename}",
225
- repo_id=HF_DATASET_ID,
226
  repo_type="dataset",
227
- commit_message=f"Cursor detection results for {image_name}"
228
  )
229
- print(f"[DATASET] Uploaded results for {image_name} to {HF_DATASET_ID}.")
230
  return True
231
  except Exception as e:
232
- print(f"[DATASET] Failed to upload results for {image_name}: {e}")
233
  return False
234
 
235
  class DatasetProgress:
@@ -281,7 +306,7 @@ async def process_image(image_bytes: bytes, threshold: float = 0.8) -> dict:
281
  raise ValueError(f"Error processing image: {str(e)}")
282
 
283
  async def dataset_task(start_index: int = 1):
284
- """Main dataset processing loop."""
285
  global dataset_progress
286
 
287
  dataset_progress = DatasetProgress()
@@ -299,88 +324,115 @@ async def dataset_task(start_index: int = 1):
299
 
300
  try:
301
  state = await asyncio.to_thread(_load_hf_state)
302
- image_list = await asyncio.to_thread(_get_image_list_from_hf)
303
 
304
- if not image_list:
305
- err = "No images found in dataset"
306
  dataset_progress.status = "error"
307
  dataset_progress.error = err
308
  print(f"[DATASET] {err}")
309
  return False
310
 
311
- dataset_progress.total_images = len(image_list)
312
  dataset_progress.status = "processing"
313
 
314
  if start_index < 1:
315
  start_index = 1
316
 
317
- for idx in range(start_index-1, len(image_list)):
318
  try:
319
- image_path = image_list[idx]
320
- image_name = Path(image_path).name
321
- print(f"[DATASET] Processing image {idx + 1}/{len(image_list)}: {image_name}")
322
 
323
- file_state = state.get('file_states', {}).get(image_name)
324
  if file_state == 'processed':
325
- print(f"[DATASET] Skipping {image_name}: already processed.")
326
  dataset_progress.processed_images += 1
327
  continue
328
  if file_state == 'processing':
329
- print(f"[DATASET] Skipping {image_name}: currently processing by another worker.")
330
  continue
331
 
332
  # Try to lock
333
- locked = await asyncio.to_thread(_lock_file_for_processing, image_name, state)
334
  if not locked:
335
- print(f"[DATASET] Could not lock {image_name}; skipping.")
336
  continue
337
 
338
  try:
339
- # Download and process image
340
- print(f"[DATASET] Downloading {image_name}...")
341
- image_bytes = await asyncio.to_thread(
342
- lambda: hf_hub_download(
343
- repo_id=HF_DATASET_ID,
344
- filename=image_path,
345
- repo_type="dataset",
346
- token=HF_TOKEN
347
- )
348
- )
349
-
350
- with open(image_bytes, 'rb') as f:
351
- content = f.read()
352
-
353
- # Process image
354
- results = await process_image(content)
355
- results['image_name'] = image_name
356
- results['image_path'] = image_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
- # Upload results
359
- uploaded = await asyncio.to_thread(_upload_cursor_results, image_name, results)
360
 
361
  if uploaded:
362
  next_index = idx + 2 # next 1-based index
363
- ok = await asyncio.to_thread(_unlock_file_as_processed, image_name, state, next_index)
364
  if not ok:
365
- print(f"[DATASET] Warning: processed but failed to update state for {image_name}.")
366
  dataset_progress.processed_images += 1
367
- print(f"[DATASET] Successfully processed {image_name}")
368
  else:
369
- print(f"[DATASET] Failed to upload results for {image_name}")
370
- state['file_states'][image_name] = 'failed'
371
  await asyncio.to_thread(_upload_hf_state, state)
372
-
373
  except Exception as e:
374
- print(f"[DATASET] Error processing {image_name}: {e}")
375
- state['file_states'][image_name] = 'failed'
376
  await asyncio.to_thread(_upload_hf_state, state)
377
  continue
 
 
 
 
 
 
 
378
 
379
  except Exception as e:
380
- print(f"[DATASET] Error in image loop: {e}")
381
  continue
382
 
383
- print(f"[DATASET] Task completed. Processed {dataset_progress.processed_images}/{len(image_list)} images.")
384
  dataset_progress.status = "completed"
385
  return True
386
 
 
5
  import threading
6
  import time
7
  import urllib.parse
8
+ import zipfile
9
+ import shutil
10
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
11
  from fastapi.responses import JSONResponse
12
  import json
 
26
 
27
  # --- Environment Configuration ---
28
  HF_TOKEN = os.getenv("HF_TOKEN", "")
29
+ HF_DATASET_ID = os.getenv("HF_DATASET_ID", "Fred808/BG3") # Source dataset with zips
30
+ HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "Fred808/data") # Results dataset
31
  HF_STATE_FILE = os.getenv("HF_STATE_FILE", "processing_state_cursors.json")
32
  TEMP_DATASET_DIR = Path("temp_cursor_detection")
33
  TEMP_DATASET_DIR.mkdir(exist_ok=True)
34
 
35
+ def _get_zip_file_list_from_hf() -> list:
36
+ """Return sorted list of zip file paths from HF_DATASET_ID."""
37
+ try:
38
+ api = HfApi(token=HF_TOKEN)
39
+ files = api.list_repo_files(repo_id=HF_DATASET_ID, repo_type="dataset")
40
+ zip_files = sorted([f for f in files if f.startswith('frames_zips/') and f.lower().endswith('.zip')])
41
+ print(f"[DATASET] Found {len(zip_files)} zip files in {HF_DATASET_ID}.")
42
+ return zip_files
43
+ except Exception as e:
44
+ print(f"[DATASET] Error listing HF dataset files: {e}")
45
+ return []
46
+
47
+ def _download_and_extract_zip(repo_path: str) -> Optional[Path]:
48
+ """Download zip from HF dataset and extract into a temp subfolder."""
49
+ try:
50
+ zip_local = hf_hub_download(repo_id=HF_DATASET_ID, filename=repo_path, repo_type="dataset", token=HF_TOKEN)
51
+ zip_name = Path(repo_path).name
52
+ extract_dir = TEMP_DATASET_DIR / zip_name.replace('.zip','')
53
+ if extract_dir.exists():
54
+ shutil.rmtree(extract_dir)
55
+ extract_dir.mkdir(parents=True, exist_ok=True)
56
+ with zipfile.ZipFile(zip_local, 'r') as z:
57
+ z.extractall(extract_dir)
58
+ try:
59
+ os.remove(zip_local)
60
+ except Exception:
61
+ pass
62
+ return extract_dir
63
+ except Exception as e:
64
+ print(f"[DATASET] Error downloading/extracting {repo_path}: {e}")
65
+ return None
66
+
67
  # Global variable to store loaded templates
68
  CURSOR_TEMPLATES: Dict[str, np.ndarray] = {}
69
  CURSOR_TEMPLATES_DIR = Path("cursors")
 
236
  state['next_download_index'] = next_index
237
  return _upload_hf_state(state)
238
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
+
241
+ def _upload_cursor_results(zip_name: str, results: dict) -> bool:
242
+ """Upload cursor detection results JSON to output dataset."""
243
  try:
244
+ filename = Path(zip_name).with_suffix('.json').name
245
  content = json.dumps(results, indent=2, ensure_ascii=False).encode('utf-8')
246
  api = HfApi(token=HF_TOKEN)
247
  api.upload_file(
248
  path_or_fileobj=io.BytesIO(content),
249
  path_in_repo=f"cursor_results/{filename}",
250
+ repo_id=HF_OUTPUT_DATASET_ID, # Using output dataset
251
  repo_type="dataset",
252
+ commit_message=f"Cursor detection results for {zip_name}"
253
  )
254
+ print(f"[DATASET] Uploaded results for {zip_name} to {HF_OUTPUT_DATASET_ID}.")
255
  return True
256
  except Exception as e:
257
+ print(f"[DATASET] Failed to upload results for {zip_name}: {e}")
258
  return False
259
 
260
  class DatasetProgress:
 
306
  raise ValueError(f"Error processing image: {str(e)}")
307
 
308
  async def dataset_task(start_index: int = 1):
309
+ """Main dataset processing loop for processing zip files."""
310
  global dataset_progress
311
 
312
  dataset_progress = DatasetProgress()
 
324
 
325
  try:
326
  state = await asyncio.to_thread(_load_hf_state)
327
+ zip_list = await asyncio.to_thread(_get_zip_file_list_from_hf)
328
 
329
+ if not zip_list:
330
+ err = "No zip files found in dataset"
331
  dataset_progress.status = "error"
332
  dataset_progress.error = err
333
  print(f"[DATASET] {err}")
334
  return False
335
 
336
+ dataset_progress.total_images = len(zip_list)
337
  dataset_progress.status = "processing"
338
 
339
  if start_index < 1:
340
  start_index = 1
341
 
342
+ for idx in range(start_index-1, len(zip_list)):
343
  try:
344
+ zip_path = zip_list[idx]
345
+ zip_name = Path(zip_path).name
346
+ print(f"[DATASET] Processing zip {idx + 1}/{len(zip_list)}: {zip_name}")
347
 
348
+ file_state = state.get('file_states', {}).get(zip_name)
349
  if file_state == 'processed':
350
+ print(f"[DATASET] Skipping {zip_name}: already processed.")
351
  dataset_progress.processed_images += 1
352
  continue
353
  if file_state == 'processing':
354
+ print(f"[DATASET] Skipping {zip_name}: currently processing by another worker.")
355
  continue
356
 
357
  # Try to lock
358
+ locked = await asyncio.to_thread(_lock_file_for_processing, zip_name, state)
359
  if not locked:
360
+ print(f"[DATASET] Could not lock {zip_name}; skipping.")
361
  continue
362
 
363
  try:
364
+ # Download and extract zip
365
+ print(f"[DATASET] Downloading and extracting {zip_name}...")
366
+ extract_dir = await asyncio.to_thread(_download_and_extract_zip, zip_path)
367
+ if not extract_dir:
368
+ print(f"[DATASET] Failed to download/extract {zip_name}; marking failed.")
369
+ state['file_states'][zip_name] = 'failed'
370
+ await asyncio.to_thread(_upload_hf_state, state)
371
+ continue
372
+
373
+ # Find all images in extracted directory
374
+ image_paths = [p for p in extract_dir.rglob('*') if p.is_file() and p.suffix.lower() in ('.jpg','.jpeg','.png')]
375
+ print(f"[DATASET] Found {len(image_paths)} images in {zip_name}")
376
+
377
+ # Process all images in the zip
378
+ results = []
379
+ for image_path in image_paths:
380
+ try:
381
+ with open(image_path, 'rb') as f:
382
+ content = f.read()
383
+
384
+ # Process image for cursor detection
385
+ image_result = await process_image(content)
386
+ image_result['image_name'] = image_path.name
387
+ image_result['image_path'] = str(image_path.relative_to(extract_dir))
388
+ results.append(image_result)
389
+
390
+ except Exception as e:
391
+ print(f"[DATASET] Error processing {image_path.name}: {e}")
392
+ continue
393
+
394
+ # Create combined results for the zip
395
+ zip_results = {
396
+ 'zip_name': zip_name,
397
+ 'zip_path': zip_path,
398
+ 'total_images': len(image_paths),
399
+ 'processed_images': len(results),
400
+ 'results': results
401
+ }
402
 
403
+ # Upload combined results
404
+ uploaded = await asyncio.to_thread(_upload_cursor_results, zip_name, zip_results)
405
 
406
  if uploaded:
407
  next_index = idx + 2 # next 1-based index
408
+ ok = await asyncio.to_thread(_unlock_file_as_processed, zip_name, state, next_index)
409
  if not ok:
410
+ print(f"[DATASET] Warning: processed but failed to update state for {zip_name}.")
411
  dataset_progress.processed_images += 1
412
+ print(f"[DATASET] Successfully processed {zip_name}")
413
  else:
414
+ print(f"[DATASET] Failed to upload results for {zip_name}")
415
+ state['file_states'][zip_name] = 'failed'
416
  await asyncio.to_thread(_upload_hf_state, state)
417
+
418
  except Exception as e:
419
+ print(f"[DATASET] Error processing zip {zip_name}: {e}")
420
+ state['file_states'][zip_name] = 'failed'
421
  await asyncio.to_thread(_upload_hf_state, state)
422
  continue
423
+ finally:
424
+ # Cleanup extracted directory
425
+ try:
426
+ if extract_dir and extract_dir.exists():
427
+ shutil.rmtree(extract_dir)
428
+ except Exception as e:
429
+ print(f"[DATASET] Warning: Failed to clean up {extract_dir}: {e}")
430
 
431
  except Exception as e:
432
+ print(f"[DATASET] Error in zip processing loop: {e}")
433
  continue
434
 
435
+ print(f"[DATASET] Task completed. Processed {dataset_progress.processed_images}/{len(zip_list)} zip files.")
436
  dataset_progress.status = "completed"
437
  return True
438