mustafa2ak commited on
Commit
6e605a5
Β·
verified Β·
1 Parent(s): b4346ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +456 -56
app.py CHANGED
@@ -1,15 +1,17 @@
1
  """
2
- Simplified Dog Tracking for Training Dataset Collection
3
  - Process video with adjustable threshold
4
  - Temporary storage with discard option
5
- - Manual validation (grid of thumbnails with per-image checkboxes)
6
  - Export to folder structure for fine-tuning
 
7
  - Automatic HuggingFace backup/restore
8
  """
9
  import os
10
  os.environ["OMP_NUM_THREADS"] = "1"
11
 
12
  import zipfile
 
13
  import gradio as gr
14
  import cv2
15
  import numpy as np
@@ -53,8 +55,8 @@ class DatasetCollectionApp:
53
  self.current_video_path = None
54
  self.is_processing = False
55
 
56
- # Validation state: list of (temp_id, [checkbox components])
57
- self.validation_checkboxes: List[tuple] = []
58
 
59
  print("Dataset Collection App initialized")
60
  print(f"Database has {len(self.db.get_all_dogs())} dogs")
@@ -73,7 +75,7 @@ class DatasetCollectionApp:
73
  self.tracker.reset()
74
  self.reid.reset_session()
75
  self.current_video_path = None
76
- self.validation_checkboxes = []
77
 
78
  gc.collect()
79
  if torch.cuda.is_available():
@@ -83,7 +85,8 @@ class DatasetCollectionApp:
83
  None, # Clear video input
84
  "<p style='text-align:center; color:#868e96;'>Session cleared. Upload a new video to start.</p>",
85
  "",
86
- ""
 
87
  )
88
 
89
  def discard_session(self):
@@ -92,17 +95,16 @@ class DatasetCollectionApp:
92
  self.temp_session.clear()
93
  self.tracker.reset()
94
  self.reid.reset_session()
95
- self.validation_checkboxes = []
96
 
97
  gc.collect()
98
  if torch.cuda.is_available():
99
  torch.cuda.empty_cache()
100
 
101
- # Return UI updates for validation container + status + database display
102
  return (
103
- gr.update(visible=False), # hide validation container
104
  f"Discarded {count} temporary dogs. Try different threshold.",
105
- gr.update(visible=False) # hide database display
106
  )
107
 
108
  def process_video(self, video_path: str, reid_threshold: float,
@@ -110,11 +112,12 @@ class DatasetCollectionApp:
110
  """Process video and store in temporary session"""
111
 
112
  if not video_path:
113
- return None, "Please upload a video", ""
114
 
115
  self.is_processing = True
116
  self.current_video_path = video_path
117
  self.temp_session.clear()
 
118
 
119
  # Set threshold
120
  self.reid.set_threshold(reid_threshold)
@@ -127,7 +130,7 @@ class DatasetCollectionApp:
127
  try:
128
  cap = cv2.VideoCapture(video_path)
129
  if not cap.isOpened():
130
- return None, "Cannot open video", ""
131
 
132
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
133
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
@@ -248,6 +251,10 @@ class DatasetCollectionApp:
248
  # Store in temp session
249
  self.temp_session = temp_dogs
250
 
 
 
 
 
251
  # Generate summary
252
  summary = f"Processing complete!\n"
253
  summary += f"Detected {original_count} dogs initially\n"
@@ -261,9 +268,11 @@ class DatasetCollectionApp:
261
  if len(temp_dogs) == 0:
262
  summary += "No dogs met the minimum requirement of 14 images.\n"
263
  summary += "Try adjusting the ReID threshold or using a longer video."
 
264
  else:
265
  summary += "Results stored in TEMPORARY session\n"
266
- summary += "Review and validate before saving to database"
 
267
 
268
  gallery_html = self._create_temp_gallery()
269
 
@@ -271,12 +280,17 @@ class DatasetCollectionApp:
271
  if torch.cuda.is_available():
272
  torch.cuda.empty_cache()
273
 
274
- return gallery_html, summary, "Ready for validation" if len(temp_dogs) > 0 else "No valid dogs"
 
 
 
 
 
275
 
276
  except Exception as e:
277
  import traceback
278
  error = f"Error: {str(e)}\n{traceback.format_exc()}"
279
- return None, error, ""
280
  finally:
281
  self.is_processing = False
282
 
@@ -322,56 +336,442 @@ class DatasetCollectionApp:
322
  html += "</div></div>"
323
  return html
324
 
325
- def create_validation_interface(self):
326
- """Create simplified validation interface for visual review (fallback)
327
- Note: This function returns HTML used by the simple load button.
328
- We keep it for backward compatibility; the interactive grid is built
329
- by render_validation in the Gradio Blocks UI below."""
330
  if not self.temp_session:
331
- return "<p>No temporary session to validate</p>"
332
-
 
 
 
 
 
 
 
333
  html = "<div style='padding: 20px;'>"
334
- html += "<h2 style='text-align:center;'>Validation - Visual Review</h2>"
335
- html += "<p style='text-align:center; color:#666;'>Review images before saving. All images will be saved when you click 'Save to Database'</p>"
 
336
 
337
- for temp_id in sorted(self.temp_session.keys()):
338
- dog_data = self.temp_session[temp_id]
339
- images = dog_data['images']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  html += f"""
342
- <div style='border: 2px solid #495057; border-radius: 10px;
343
- padding: 15px; margin: 20px 0; background: #f8f9fa;'>
344
- <h3 style='margin: 0 0 15px 0;'>Temp Dog #{temp_id} - {len(images)} images</h3>
345
- <div style='display: grid; grid-template-columns: repeat(6, 1fr); gap: 10px;'>
 
 
 
 
 
346
  """
347
-
348
- for idx, img in enumerate(images):
 
349
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
350
  img_base64 = self._img_to_base64(img_rgb)
351
-
352
  html += f"""
353
- <div style='position: relative;'>
354
- <img src='data:image/jpeg;base64,{img_base64}'
355
- style='width: 100%; aspect-ratio: 1; object-fit: cover;
356
- border-radius: 5px; border: 2px solid #dee2e6;'>
357
- <div style='position: absolute; bottom: 5px; right: 5px;
358
- background: rgba(0,0,0,0.7); color: white;
359
- padding: 2px 6px; border-radius: 3px; font-size: 10px;'>
360
- {idx+1}
361
- </div>
362
- </div>
363
  """
364
-
365
- html += """
366
- </div>
367
- </div>
368
- """
369
-
370
- html += "<p style='text-align:center; color:#868e96; margin-top: 30px;'>"
371
- html += "If results look good, click 'Save to Database' below.<br>"
372
- html += "If not satisfied, go back to Tab 1 and click 'Discard & Retry' with different threshold."
373
- html += "</p>"
374
- html += "</div>"
375
  return html
376
 
377
- # (rest of code unchanged, includes save_validated_to_database, _backup_database, _restore_database, _show_database, export_dataset, _img_to_base64, create_interface, launch, __main__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Simplified Dog Tracking for Training Dataset Collection
3
  - Process video with adjustable threshold
4
  - Temporary storage with discard option
5
+ - Manual validation with checkbox selection per image
6
  - Export to folder structure for fine-tuning
7
+ - Download to laptop as ZIP
8
  - Automatic HuggingFace backup/restore
9
  """
10
  import os
11
  os.environ["OMP_NUM_THREADS"] = "1"
12
 
13
  import zipfile
14
+ import tempfile
15
  import gradio as gr
16
  import cv2
17
  import numpy as np
 
55
  self.current_video_path = None
56
  self.is_processing = False
57
 
58
+ # Validation state: stores checkbox states for each temp_id
59
+ self.validation_data = {} # {temp_id: [bool, bool, ...]}
60
 
61
  print("Dataset Collection App initialized")
62
  print(f"Database has {len(self.db.get_all_dogs())} dogs")
 
75
  self.tracker.reset()
76
  self.reid.reset_session()
77
  self.current_video_path = None
78
+ self.validation_data = {}
79
 
80
  gc.collect()
81
  if torch.cuda.is_available():
 
85
  None, # Clear video input
86
  "<p style='text-align:center; color:#868e96;'>Session cleared. Upload a new video to start.</p>",
87
  "",
88
+ "",
89
+ gr.update(visible=False) # Hide validation area
90
  )
91
 
92
  def discard_session(self):
 
95
  self.temp_session.clear()
96
  self.tracker.reset()
97
  self.reid.reset_session()
98
+ self.validation_data = {}
99
 
100
  gc.collect()
101
  if torch.cuda.is_available():
102
  torch.cuda.empty_cache()
103
 
 
104
  return (
105
+ gr.update(visible=False), # Hide validation container
106
  f"Discarded {count} temporary dogs. Try different threshold.",
107
+ gr.update(visible=False) # Hide database display
108
  )
109
 
110
  def process_video(self, video_path: str, reid_threshold: float,
 
112
  """Process video and store in temporary session"""
113
 
114
  if not video_path:
115
+ return None, "Please upload a video", "", gr.update(visible=False)
116
 
117
  self.is_processing = True
118
  self.current_video_path = video_path
119
  self.temp_session.clear()
120
+ self.validation_data = {}
121
 
122
  # Set threshold
123
  self.reid.set_threshold(reid_threshold)
 
130
  try:
131
  cap = cv2.VideoCapture(video_path)
132
  if not cap.isOpened():
133
+ return None, "Cannot open video", "", gr.update(visible=False)
134
 
135
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
136
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
251
  # Store in temp session
252
  self.temp_session = temp_dogs
253
 
254
+ # Initialize validation data (all images selected by default)
255
+ for temp_id in temp_dogs.keys():
256
+ self.validation_data[temp_id] = [True] * len(temp_dogs[temp_id]['images'])
257
+
258
  # Generate summary
259
  summary = f"Processing complete!\n"
260
  summary += f"Detected {original_count} dogs initially\n"
 
268
  if len(temp_dogs) == 0:
269
  summary += "No dogs met the minimum requirement of 14 images.\n"
270
  summary += "Try adjusting the ReID threshold or using a longer video."
271
+ show_validation = False
272
  else:
273
  summary += "Results stored in TEMPORARY session\n"
274
+ summary += "Go to Tab 2 to review and select images before saving"
275
+ show_validation = True
276
 
277
  gallery_html = self._create_temp_gallery()
278
 
 
280
  if torch.cuda.is_available():
281
  torch.cuda.empty_cache()
282
 
283
+ return (
284
+ gallery_html,
285
+ summary,
286
+ "Ready for validation" if len(temp_dogs) > 0 else "No valid dogs",
287
+ gr.update(visible=show_validation)
288
+ )
289
 
290
  except Exception as e:
291
  import traceback
292
  error = f"Error: {str(e)}\n{traceback.format_exc()}"
293
+ return None, error, "", gr.update(visible=False)
294
  finally:
295
  self.is_processing = False
296
 
 
336
  html += "</div></div>"
337
  return html
338
 
339
+ def load_validation_interface(self):
340
+ """Load validation interface with checkbox selection"""
 
 
 
341
  if not self.temp_session:
342
+ return (
343
+ gr.update(visible=False),
344
+ "No temporary session to validate. Process a video first.",
345
+ ""
346
+ )
347
+
348
+ # Create components list for dynamic rendering
349
+ validation_components = []
350
+
351
  html = "<div style='padding: 20px;'>"
352
+ html += "<h2 style='text-align:center;'>Review and Select Images</h2>"
353
+ html += "<p style='text-align:center; color:#666;'>Check/uncheck images to keep/discard. All are selected by default.</p>"
354
+ html += "</div>"
355
 
356
+ status = f"Loaded {len(self.temp_session)} dogs for validation. Review and click 'Save Selected to Database' when ready."
357
+
358
+ return (
359
+ gr.update(visible=True),
360
+ status,
361
+ html
362
+ )
363
+
364
+ def save_validated_to_database(self, *checkbox_states):
365
+ """Save validated images to permanent database"""
366
+ if not self.temp_session:
367
+ return "No temporary session to save", gr.update()
368
+
369
+ try:
370
+ saved_count = 0
371
+ total_images_saved = 0
372
+
373
+ # Collect checkbox states
374
+ checkbox_idx = 0
375
+
376
+ for temp_id in sorted(self.temp_session.keys()):
377
+ dog_data = self.temp_session[temp_id]
378
+ num_images = len(dog_data['images'])
379
+
380
+ # Get checkbox states for this dog
381
+ selected_indices = []
382
+ for i in range(num_images):
383
+ if checkbox_idx < len(checkbox_states) and checkbox_states[checkbox_idx]:
384
+ selected_indices.append(i)
385
+ checkbox_idx += 1
386
+
387
+ # Skip if no images selected
388
+ if not selected_indices:
389
+ continue
390
+
391
+ # Add dog to database
392
+ dog_id = self.db.add_dog(
393
+ name=f"Dog_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{temp_id}"
394
+ )
395
+
396
+ # Add only selected images
397
+ for idx in selected_indices:
398
+ self.db.add_dog_image(
399
+ dog_id=dog_id,
400
+ image=dog_data['images'][idx],
401
+ timestamp=dog_data['timestamps'][idx],
402
+ confidence=dog_data['confidences'][idx],
403
+ bbox=dog_data['bboxes'][idx]
404
+ )
405
+ total_images_saved += 1
406
+
407
+ saved_count += 1
408
+
409
+ # Clear temporary session after saving
410
+ self.temp_session.clear()
411
+ self.validation_data = {}
412
+
413
+ # Backup to HuggingFace
414
+ self._backup_database()
415
+
416
+ # Show updated database
417
+ db_html = self._show_database()
418
+
419
+ summary = f"βœ… Successfully saved {saved_count} dogs with {total_images_saved} selected images to permanent database!"
420
+
421
+ gc.collect()
422
+ if torch.cuda.is_available():
423
+ torch.cuda.empty_cache()
424
+
425
+ return summary, gr.update(value=db_html, visible=True)
426
+
427
+ except Exception as e:
428
+ import traceback
429
+ error = f"Error saving: {str(e)}\n{traceback.format_exc()}"
430
+ return error, gr.update()
431
+
432
+ def _backup_database(self):
433
+ """Backup database to HuggingFace"""
434
+ try:
435
+ from huggingface_hub import HfApi
436
+
437
+ hf_token = os.getenv('HF_TOKEN')
438
+ if not hf_token:
439
+ print("Warning: HF_TOKEN not found, skipping backup")
440
+ return
441
+
442
+ api = HfApi()
443
+ repo_id = "mustafa2ak/dog-dataset-backup"
444
+
445
+ # Upload database file
446
+ api.upload_file(
447
+ path_or_fileobj='dog_monitoring.db',
448
+ path_in_repo='dog_monitoring.db',
449
+ repo_id=repo_id,
450
+ repo_type='dataset',
451
+ token=hf_token
452
+ )
453
+
454
+ print(f"βœ… Database backed up to {repo_id}")
455
+
456
+ except Exception as e:
457
+ print(f"Backup failed: {str(e)}")
458
 
459
+ def _restore_database(self):
460
+ """Restore database from HuggingFace"""
461
+ try:
462
+ from huggingface_hub import hf_hub_download
463
+
464
+ hf_token = os.getenv('HF_TOKEN')
465
+ if not hf_token:
466
+ print("No HF_TOKEN found, starting with fresh database")
467
+ return
468
+
469
+ repo_id = "mustafa2ak/dog-dataset-backup"
470
+
471
+ # Download database
472
+ db_path = hf_hub_download(
473
+ repo_id=repo_id,
474
+ filename='dog_monitoring.db',
475
+ repo_type='dataset',
476
+ token=hf_token
477
+ )
478
+
479
+ # Copy to current directory
480
+ import shutil
481
+ shutil.copy(db_path, 'dog_monitoring.db')
482
+
483
+ print(f"βœ… Database restored from {repo_id}")
484
+
485
+ except Exception as e:
486
+ print(f"No backup found or restore failed: {str(e)}")
487
+
488
+ def _show_database(self) -> str:
489
+ """Show current database contents"""
490
+ dogs = self.db.get_all_dogs()
491
+
492
+ if not dogs:
493
+ return "<p style='text-align:center; color:#868e96;'>No dogs in database yet</p>"
494
+
495
+ html = "<div style='padding: 20px;'>"
496
+ html += f"<h2 style='text-align:center; color:#228be6;'>Permanent Database ({len(dogs)} dogs)</h2>"
497
+ html += "<div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px;'>"
498
+
499
+ for _, dog in dogs.iterrows():
500
+ images = self.db.get_dog_images(dog['dog_id'])
501
+ display_count = min(6, len(images))
502
+
503
  html += f"""
504
+ <div style='border: 2px solid #228be6; border-radius: 10px;
505
+ padding: 15px; background: #e7f5ff;'>
506
+ <h3 style='margin: 0 0 10px 0; color:#1971c2;'>{dog['name']}</h3>
507
+ <p style='color: #666; margin: 5px 0;'>ID: {dog['dog_id']}</p>
508
+ <p style='color: #666; margin: 5px 0;'>Images: {len(images)}</p>
509
+ <p style='color: #666; margin: 5px 0; font-size: 12px;'>
510
+ First seen: {dog['first_seen']}
511
+ </p>
512
+ <div style='display: grid; grid-template-columns: repeat(3, 1fr); gap: 5px; margin-top: 10px;'>
513
  """
514
+
515
+ for img_data in images[:display_count]:
516
+ img = img_data['image']
517
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
518
  img_base64 = self._img_to_base64(img_rgb)
 
519
  html += f"""
520
+ <img src='data:image/jpeg;base64,{img_base64}'
521
+ style='width: 100%; aspect-ratio: 1; object-fit: cover;
522
+ border-radius: 5px;'>
 
 
 
 
 
 
 
523
  """
524
+
525
+ html += "</div></div>"
526
+
527
+ html += "</div></div>"
 
 
 
 
 
 
 
528
  return html
529
 
530
+ def export_dataset(self):
531
+ """Export dataset as downloadable ZIP file"""
532
+ try:
533
+ dogs = self.db.get_all_dogs()
534
+
535
+ if dogs.empty:
536
+ return "No dogs in database to export", None
537
+
538
+ # Create in-memory ZIP file
539
+ zip_buffer = BytesIO()
540
+
541
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
542
+ total_images = 0
543
+ export_info = []
544
+
545
+ for _, dog in dogs.iterrows():
546
+ dog_id = dog['dog_id']
547
+ dog_name = dog['name'] or f"dog_{dog_id}"
548
+ safe_name = "".join(c if c.isalnum() or c in ('_', '-') else '_' for c in dog_name)
549
+
550
+ images = self.db.get_dog_images(dog_id)
551
+
552
+ if not images:
553
+ continue
554
+
555
+ # Add each image to ZIP
556
+ for idx, img_data in enumerate(images):
557
+ image = img_data['image']
558
+
559
+ # Convert to PIL Image
560
+ img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
561
+ pil_image = Image.fromarray(img_rgb)
562
+
563
+ # Save to bytes
564
+ img_buffer = BytesIO()
565
+ pil_image.save(img_buffer, format='JPEG', quality=95)
566
+ img_bytes = img_buffer.getvalue()
567
+
568
+ # Add to ZIP
569
+ filename = f"training_dataset/{safe_name}/image_{idx+1:04d}.jpg"
570
+ zipf.writestr(filename, img_bytes)
571
+ total_images += 1
572
+
573
+ export_info.append({
574
+ 'dog_id': int(dog_id),
575
+ 'name': dog_name,
576
+ 'image_count': len(images)
577
+ })
578
+
579
+ # Add metadata
580
+ metadata = {
581
+ 'export_date': datetime.now().isoformat(),
582
+ 'total_dogs': len(dogs),
583
+ 'total_images': total_images,
584
+ 'dogs': export_info
585
+ }
586
+
587
+ zipf.writestr('training_dataset/metadata.json', json.dumps(metadata, indent=2))
588
+
589
+ # Save to temporary file
590
+ zip_buffer.seek(0)
591
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip', prefix='dog_dataset_')
592
+ temp_file.write(zip_buffer.getvalue())
593
+ temp_file.close()
594
+
595
+ summary = f"βœ… Dataset exported successfully!\n\n"
596
+ summary += f"πŸ“¦ Total dogs: {len(dogs)}\n"
597
+ summary += f"πŸ–ΌοΈ Total images: {total_images}\n\n"
598
+ summary += "Click the download button below to save to your laptop."
599
+
600
+ return summary, temp_file.name
601
+
602
+ except Exception as e:
603
+ import traceback
604
+ error = f"Export error: {str(e)}\n{traceback.format_exc()}"
605
+ return error, None
606
+
607
+ def _img_to_base64(self, img_array: np.ndarray) -> str:
608
+ """Convert image array to base64 string"""
609
+ img_pil = Image.fromarray(img_array)
610
+ buffered = BytesIO()
611
+ img_pil.save(buffered, format="JPEG", quality=85)
612
+ return base64.b64encode(buffered.getvalue()).decode()
613
+
614
+ def create_interface(self):
615
+ """Create Gradio interface with validation checkboxes"""
616
+
617
+ with gr.Blocks(title="Dog Dataset Collection", theme=gr.themes.Soft()) as app:
618
+ gr.Markdown("""
619
+ # πŸ• Dog Training Dataset Collection
620
+ **Process β†’ Validate β†’ Save β†’ Export**
621
+ """)
622
+
623
+ with gr.Tabs():
624
+ # TAB 1: Process Video
625
+ with gr.Tab("1. Process Video"):
626
+ gr.Markdown("### Upload and process video to detect dogs")
627
+
628
+ with gr.Row():
629
+ with gr.Column():
630
+ video_input = gr.Video(label="Upload Video")
631
+
632
+ with gr.Row():
633
+ reid_threshold = gr.Slider(
634
+ minimum=0.1, maximum=0.9, value=0.3, step=0.05,
635
+ label="ReID Threshold (lower = more dogs)"
636
+ )
637
+ sample_rate = gr.Slider(
638
+ minimum=1, maximum=10, value=3, step=1,
639
+ label="Frame Sampling Rate"
640
+ )
641
+
642
+ flip_camera = gr.Checkbox(label="Flip Camera Horizontally", value=False)
643
+
644
+ with gr.Row():
645
+ process_btn = gr.Button("🎬 Process Video", variant="primary", size="lg")
646
+ stop_btn = gr.Button("⏹️ Stop", variant="stop")
647
+ clear_btn = gr.Button("πŸ—‘οΈ Clear & Reset")
648
+
649
+ progress_text = gr.Textbox(label="Progress", lines=1)
650
+ status_text = gr.Textbox(label="Status", lines=8)
651
+
652
+ with gr.Column():
653
+ gallery_output = gr.HTML(label="Detection Results")
654
+
655
+ with gr.Row():
656
+ discard_btn = gr.Button("❌ Discard & Retry with Different Threshold", variant="secondary")
657
+
658
+ # TAB 2: Validate & Save
659
+ with gr.Tab("2. Validate & Save"):
660
+ gr.Markdown("### Review detected dogs and select images to keep")
661
+
662
+ with gr.Column(visible=False) as validation_container:
663
+ validation_status = gr.Textbox(label="Status", lines=2)
664
+
665
+ load_btn = gr.Button("πŸ“‹ Load Validation Interface", variant="primary", size="lg")
666
+
667
+ # Dynamic validation area
668
+ @gr.render(inputs=[], triggers=[load_btn.click])
669
+ def render_validation():
670
+ if not self.temp_session:
671
+ gr.Markdown("No temporary session. Process a video first.")
672
+ return
673
+
674
+ checkboxes = []
675
+
676
+ for temp_id in sorted(self.temp_session.keys()):
677
+ dog_data = self.temp_session[temp_id]
678
+ images = dog_data['images']
679
+
680
+ with gr.Group():
681
+ gr.Markdown(f"### πŸ• Dog #{temp_id} - {len(images)} images")
682
+
683
+ # Create grid of images with checkboxes
684
+ for i in range(0, len(images), 6):
685
+ with gr.Row():
686
+ for j in range(6):
687
+ if i + j < len(images):
688
+ img = images[i + j]
689
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
690
+
691
+ with gr.Column(scale=1, min_width=120):
692
+ gr.Image(
693
+ value=img_rgb,
694
+ label=f"#{i+j+1}",
695
+ interactive=False,
696
+ height=150,
697
+ show_download_button=False
698
+ )
699
+ cb = gr.Checkbox(
700
+ label="Keep",
701
+ value=True,
702
+ elem_id=f"cb_{temp_id}_{i+j}"
703
+ )
704
+ checkboxes.append(cb)
705
+
706
+ # Save button
707
+ save_btn = gr.Button("πŸ’Ύ Save Selected to Database", variant="primary", size="lg")
708
+ save_status = gr.Textbox(label="Save Status", lines=3)
709
+
710
+ # Connect save button
711
+ save_btn.click(
712
+ fn=self.save_validated_to_database,
713
+ inputs=checkboxes,
714
+ outputs=[save_status, validation_container]
715
+ )
716
+
717
+ # TAB 3: Database & Export
718
+ with gr.Tab("3. Database & Export"):
719
+ gr.Markdown("### View database and export for fine-tuning")
720
+
721
+ refresh_db_btn = gr.Button("πŸ”„ Refresh Database", variant="secondary")
722
+ database_display = gr.HTML(label="Database Contents", visible=False)
723
+
724
+ gr.Markdown("---")
725
+
726
+ export_btn = gr.Button("πŸ“¦ Export Dataset", variant="primary", size="lg")
727
+ export_status = gr.Textbox(label="Export Status", lines=5)
728
+ download_btn = gr.File(label="Download Exported Dataset", interactive=False)
729
+
730
+ # Event handlers
731
+ process_btn.click(
732
+ fn=self.process_video,
733
+ inputs=[video_input, reid_threshold, flip_camera, sample_rate],
734
+ outputs=[gallery_output, status_text, progress_text, validation_container]
735
+ )
736
+
737
+ stop_btn.click(
738
+ fn=self.stop_processing,
739
+ outputs=[status_text, progress_text, gallery_output]
740
+ )
741
+
742
+ clear_btn.click(
743
+ fn=self.clear_reset,
744
+ outputs=[video_input, gallery_output, status_text, progress_text, validation_container]
745
+ )
746
+
747
+ discard_btn.click(
748
+ fn=self.discard_session,
749
+ outputs=[validation_container, status_text, database_display]
750
+ )
751
+
752
+ load_btn.click(
753
+ fn=self.load_validation_interface,
754
+ outputs=[validation_container, validation_status, gr.HTML()]
755
+ )
756
+
757
+ refresh_db_btn.click(
758
+ fn=lambda: gr.update(value=self._show_database(), visible=True),
759
+ outputs=[database_display]
760
+ )
761
+
762
+ export_btn.click(
763
+ fn=self.export_dataset,
764
+ outputs=[export_status, download_btn]
765
+ )
766
+
767
+ return app
768
+
769
+ def launch(self):
770
+ """Launch the Gradio app"""
771
+ app = self.create_interface()
772
+ app.launch(share=False, server_name="0.0.0.0", server_port=7860)
773
+
774
+
775
+ if __name__ == "__main__":
776
+ app = DatasetCollectionApp()
777
+ app.launch()