Darknsu commited on
Commit
ef2f1f1
Β·
verified Β·
1 Parent(s): 807bbf8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +200 -91
main.py CHANGED
@@ -372,7 +372,6 @@
372
 
373
 
374
 
375
-
376
  import os
377
  import json
378
  import torch
@@ -386,7 +385,8 @@ from eval import evaluation_detection
386
  from iou_utils import non_max_suppression, check_overlap_proposal
387
  from typing import List, Dict, Optional
388
  from huggingface_hub import hf_hub_download, list_repo_files
389
- from pathlib import Path
 
390
 
391
  # Configuration
392
  VIS_CONFIG = {
@@ -394,40 +394,107 @@ VIS_CONFIG = {
394
  'min_segment_duration': 1.0,
395
  }
396
 
397
- # Cache directory for downloaded .npz files
398
- CACHE_DIR = Path("./data/I3D")
399
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
400
-
401
- # Hugging Face dataset repository
402
  HF_DATASET_REPO = "Darknsu/EGTEA_Dataset"
403
- HF_NPZ_SUBFOLDER = "features" # Adjust if .npz files are in a different subfolder
404
 
405
  # Determine device
406
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
407
  print(f"Using device: {device}")
408
 
 
 
 
 
409
  def download_npz_file(video_name: str) -> str:
410
- """Download .npz file from HF dataset repo or return cached path"""
411
- npz_filename = f"{video_name}.npz"
412
- cache_path = CACHE_DIR / npz_filename
413
-
414
- if cache_path.exists():
415
- return str(cache_path)
416
-
417
  try:
418
- # Download the .npz file to cache
 
 
 
 
 
 
 
 
 
 
419
  downloaded_path = hf_hub_download(
420
  repo_id=HF_DATASET_REPO,
421
- filename=f"{npz_filename}",
422
  repo_type="dataset",
423
- cache_dir=CACHE_DIR,
424
- local_dir=CACHE_DIR,
425
- local_dir_use_symlinks=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  )
427
- return downloaded_path
 
 
 
 
 
 
 
 
 
 
 
428
  except Exception as e:
429
- print(f"Error downloading {npz_filename}: {str(e)}")
430
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
  def eval_frame(opt, model, dataset):
433
  """Evaluate model frame by frame"""
@@ -565,28 +632,38 @@ def load_ground_truth(opt, video_name):
565
 
566
  return gt_segments, duration
567
 
568
- def process_video(video_name, split_number):
569
  """Process a single video for action localization"""
570
  try:
 
 
571
  # Parse options
572
  opt = opts.parse_opt()
573
  opt = vars(opt)
574
  opt['mode'] = 'test'
575
  opt['split'] = str(split_number)
576
  opt['checkpoint_path'] = './checkpoint'
577
- opt['video_feature_all_test'] = str(CACHE_DIR) # Use cache directory
578
  opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
579
  opt['batch_size'] = 1
580
 
 
 
581
  # Check if required files exist
582
  checkpoint_path = './checkpoint/01_ckp_best.pth.tar'
583
  if not os.path.exists(checkpoint_path):
584
- return "Error: Model checkpoint not found at ./checkpoint/01_ckp_best.pth.tar"
 
 
 
 
 
 
 
 
 
585
 
586
- # Download or get cached .npz file
587
- npz_path = download_npz_file(video_name)
588
- if not npz_path:
589
- return f"Error: Could not download feature file for {video_name}"
590
 
591
  # Load model
592
  model = MYNET(opt).to(device)
@@ -600,14 +677,21 @@ def process_video(video_name, split_number):
600
 
601
  model.eval()
602
 
603
- # Create dataset
604
- dataset = VideoDataSet(opt, subset='test', video_name=video_name)
 
 
605
 
606
  if len(dataset.video_list) == 0:
607
- return f"Error: No video found with name '{video_name}' in dataset"
 
 
608
 
609
  # Run inference
610
  output_cls, output_reg, labels_cls, labels_reg = eval_frame(opt, model, dataset)
 
 
 
611
  result_dict = eval_map_nms(opt, dataset, output_cls, output_reg)
612
 
613
  # Load ground truth
@@ -625,6 +709,8 @@ def process_video(video_name, split_number):
625
  'score': pred['score']
626
  })
627
 
 
 
628
  # Generate output text
629
  output_text = f"Predicted Actions for Video: {video_name}\n"
630
  output_text += "=" * 50 + "\n\n"
@@ -688,79 +774,102 @@ def process_video(video_name, split_number):
688
  output_text += f"Recall: {recall:.3f}\n"
689
  output_text += f"F1-Score: {f1:.3f}\n"
690
 
 
691
  return output_text
692
 
693
  except Exception as e:
694
- return f"Error processing video: {str(e)}\n\nPlease check:\n1. Model checkpoint exists\n2. Feature file exists in HF dataset\n3. All dependencies are installed"
695
 
696
- def get_available_videos():
697
- """Get list of available videos from HF dataset repo"""
698
- try:
699
- # List all files in the features subfolder
700
- repo_files = list_repo_files(
701
- repo_id=HF_DATASET_REPO,
702
- repo_type="dataset",
703
- )
704
- # Filter for .npz files and extract video names
705
- videos = [file.replace('.npz', '') for file in repo_files if file.endswith('.npz')]
706
- return sorted(videos) if videos else ["No videos found"]
707
- except Exception as e:
708
- print(f"Error listing videos: {str(e)}")
709
- return ["No videos found"]
710
 
711
  # Initialize available videos
712
- available_videos = get_available_videos()
 
 
 
 
 
713
 
714
  # Gradio Interface
715
- iface = gr.Interface(
716
- fn=process_video,
717
- inputs=[
718
- gr.Dropdown(
719
- label="Select Video",
720
- choices=available_videos,
721
- value=available_videos[0] if available_videos else None,
722
- info="Choose from videos in HF dataset: Darknsu/EGTEA_Dataset"
723
- ),
724
- gr.Dropdown(
725
- label="Split Number",
726
- choices=["1", "2", "3"],
727
- value="1",
728
- info="Dataset split for annotations"
729
- )
730
- ],
731
- outputs=[
732
- gr.Textbox(
733
- label="Action Predictions",
734
- lines=20,
735
- max_lines=50,
736
- show_copy_button=True
737
- )
738
- ],
739
- title="🎬 Temporal Action Localization",
740
- description="""
741
- This app performs temporal action localization on videos using I3D features from the EGTEA dataset.
742
 
743
- **How to use:**
744
- 1. Select a video from the dropdown (videos are loaded from HF dataset: Darknsu/EGTEA_Dataset)
745
- 2. Choose the annotation split number
746
- 3. Click Submit to get action predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
747
 
748
  **Requirements:**
749
- - Model checkpoint: `01_ckp_best.pth.tar` in root directory
750
- - Video features: Downloaded from HF dataset at runtime
751
- """,
752
- examples=[
753
- [available_videos[0] if available_videos and available_videos[0] != "No videos found" else "example_video", "1"],
754
- ] if available_videos and available_videos[0] != "No videos found" else None,
755
- cache_examples=False,
756
- theme=gr.themes.Soft()
757
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
 
759
  if __name__ == '__main__':
760
- print(f"Available videos: {available_videos}")
761
  print(f"Using device: {device}")
 
762
  iface.launch(
763
  server_name="0.0.0.0",
764
- server_port=4444,
765
  share=False
766
  )
 
372
 
373
 
374
 
 
375
  import os
376
  import json
377
  import torch
 
385
  from iou_utils import non_max_suppression, check_overlap_proposal
386
  from typing import List, Dict, Optional
387
  from huggingface_hub import hf_hub_download, list_repo_files
388
+ import tempfile
389
+ import shutil
390
 
391
  # Configuration
392
  VIS_CONFIG = {
 
394
  'min_segment_duration': 1.0,
395
  }
396
 
397
+ # Hugging Face Dataset Configuration
 
 
 
 
398
  HF_DATASET_REPO = "Darknsu/EGTEA_Dataset"
399
+ HF_DATASET_SUBFOLDER = "I3D" # Adjust this based on your dataset structure
400
 
401
  # Determine device
402
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
403
  print(f"Using device: {device}")
404
 
405
+ # Create local cache directory for downloaded files
406
+ CACHE_DIR = "./hf_cache"
407
+ os.makedirs(CACHE_DIR, exist_ok=True)
408
+
409
  def download_npz_file(video_name: str) -> str:
410
+ """
411
+ Download .npz file from Hugging Face dataset repository
412
+ Returns: Local path to the downloaded file
413
+ """
 
 
 
414
  try:
415
+ # Construct the file path in the dataset repo
416
+ file_path = f"{HF_DATASET_SUBFOLDER}/{video_name}.npz"
417
+
418
+ # Check if file already exists in cache
419
+ local_path = os.path.join(CACHE_DIR, f"{video_name}.npz")
420
+ if os.path.exists(local_path):
421
+ print(f"Using cached file: {local_path}")
422
+ return local_path
423
+
424
+ # Download from Hugging Face dataset
425
+ print(f"Downloading {file_path} from {HF_DATASET_REPO}...")
426
  downloaded_path = hf_hub_download(
427
  repo_id=HF_DATASET_REPO,
428
+ filename=file_path,
429
  repo_type="dataset",
430
+ cache_dir=CACHE_DIR
431
+ )
432
+
433
+ # Copy to our expected location for easier access
434
+ shutil.copy2(downloaded_path, local_path)
435
+ print(f"File downloaded and cached: {local_path}")
436
+ return local_path
437
+
438
+ except Exception as e:
439
+ raise Exception(f"Failed to download {video_name}.npz: {str(e)}")
440
+
441
+ def get_available_videos_from_hf():
442
+ """Get list of available videos from Hugging Face dataset repository"""
443
+ try:
444
+ print("Fetching available videos from Hugging Face dataset...")
445
+ files = list_repo_files(
446
+ repo_id=HF_DATASET_REPO,
447
+ repo_type="dataset"
448
  )
449
+
450
+ # Filter for .npz files in the I3D subfolder
451
+ videos = []
452
+ for file in files:
453
+ if file.startswith(f"{HF_DATASET_SUBFOLDER}/") and file.endswith('.npz'):
454
+ video_name = os.path.basename(file).replace('.npz', '')
455
+ videos.append(video_name)
456
+
457
+ videos = sorted(videos)
458
+ print(f"Found {len(videos)} videos in dataset")
459
+ return videos
460
+
461
  except Exception as e:
462
+ print(f"Error fetching videos from HF dataset: {str(e)}")
463
+ return ["Error loading videos"]
464
+
465
+ class HFVideoDataSet(VideoDataSet):
466
+ """
467
+ Modified VideoDataSet that downloads files from Hugging Face on demand
468
+ """
469
+ def __init__(self, opt, subset='test', video_name=None):
470
+ # Store the original video_feature_all_test path
471
+ self.original_feature_path = opt['video_feature_all_test']
472
+
473
+ # Create temporary directory for this session
474
+ self.temp_dir = tempfile.mkdtemp(prefix="hf_video_")
475
+ opt['video_feature_all_test'] = self.temp_dir
476
+
477
+ # Download the specific video file if video_name is provided
478
+ if video_name:
479
+ try:
480
+ downloaded_path = download_npz_file(video_name)
481
+ # Copy to temp directory with expected structure
482
+ temp_file_path = os.path.join(self.temp_dir, f"{video_name}.npz")
483
+ shutil.copy2(downloaded_path, temp_file_path)
484
+ print(f"Video file ready: {temp_file_path}")
485
+ except Exception as e:
486
+ print(f"Warning: Could not download video {video_name}: {str(e)}")
487
+
488
+ # Initialize parent class
489
+ super().__init__(opt, subset, video_name)
490
+
491
+ def __del__(self):
492
+ # Clean up temporary directory
493
+ try:
494
+ if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
495
+ shutil.rmtree(self.temp_dir)
496
+ except:
497
+ pass
498
 
499
  def eval_frame(opt, model, dataset):
500
  """Evaluate model frame by frame"""
 
632
 
633
  return gt_segments, duration
634
 
635
+ def process_video(video_name, split_number, progress=gr.Progress()):
636
  """Process a single video for action localization"""
637
  try:
638
+ progress(0.1, desc="Initializing...")
639
+
640
  # Parse options
641
  opt = opts.parse_opt()
642
  opt = vars(opt)
643
  opt['mode'] = 'test'
644
  opt['split'] = str(split_number)
645
  opt['checkpoint_path'] = './checkpoint'
646
+ opt['video_feature_all_test'] = './data/I3D/' # This will be overridden by HFVideoDataSet
647
  opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
648
  opt['batch_size'] = 1
649
 
650
+ progress(0.2, desc="Checking model checkpoint...")
651
+
652
  # Check if required files exist
653
  checkpoint_path = './checkpoint/01_ckp_best.pth.tar'
654
  if not os.path.exists(checkpoint_path):
655
+ # Try alternative locations
656
+ alt_paths = ['./01_ckp_best.pth.tar', '01_ckp_best.pth.tar']
657
+ checkpoint_path = None
658
+ for alt_path in alt_paths:
659
+ if os.path.exists(alt_path):
660
+ checkpoint_path = alt_path
661
+ break
662
+
663
+ if checkpoint_path is None:
664
+ return "Error: Model checkpoint not found. Please ensure '01_ckp_best.pth.tar' is in the repository."
665
 
666
+ progress(0.3, desc="Loading model...")
 
 
 
667
 
668
  # Load model
669
  model = MYNET(opt).to(device)
 
677
 
678
  model.eval()
679
 
680
+ progress(0.4, desc=f"Downloading video features for {video_name}...")
681
+
682
+ # Create dataset with HF integration
683
+ dataset = HFVideoDataSet(opt, subset='test', video_name=video_name)
684
 
685
  if len(dataset.video_list) == 0:
686
+ return f"Error: No video found with name '{video_name}' in dataset or failed to download"
687
+
688
+ progress(0.6, desc="Running inference...")
689
 
690
  # Run inference
691
  output_cls, output_reg, labels_cls, labels_reg = eval_frame(opt, model, dataset)
692
+
693
+ progress(0.8, desc="Processing results...")
694
+
695
  result_dict = eval_map_nms(opt, dataset, output_cls, output_reg)
696
 
697
  # Load ground truth
 
709
  'score': pred['score']
710
  })
711
 
712
+ progress(0.9, desc="Generating output...")
713
+
714
  # Generate output text
715
  output_text = f"Predicted Actions for Video: {video_name}\n"
716
  output_text += "=" * 50 + "\n\n"
 
774
  output_text += f"Recall: {recall:.3f}\n"
775
  output_text += f"F1-Score: {f1:.3f}\n"
776
 
777
+ progress(1.0, desc="Complete!")
778
  return output_text
779
 
780
  except Exception as e:
781
+ return f"Error processing video: {str(e)}\n\nPlease check:\n1. Model checkpoint exists\n2. Video exists in HF dataset\n3. All dependencies are installed"
782
 
783
+ def refresh_video_list():
784
+ """Refresh the list of available videos"""
785
+ return gr.Dropdown(choices=get_available_videos_from_hf())
 
 
 
 
 
 
 
 
 
 
 
786
 
787
  # Initialize available videos
788
+ print("Loading available videos from Hugging Face dataset...")
789
+ available_videos = get_available_videos_from_hf()
790
+ if not available_videos or available_videos == ["Error loading videos"]:
791
+ available_videos = ["Error: Could not load videos from HF dataset"]
792
+
793
+ print(f"Available videos: {len(available_videos)} videos found")
794
 
795
  # Gradio Interface
796
+ with gr.Blocks(theme=gr.themes.Soft(), title="🎬 Temporal Action Localization") as iface:
797
+ gr.Markdown("""
798
+ # 🎬 Temporal Action Localization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799
 
800
+ This app performs temporal action localization on videos using I3D features loaded dynamically from Hugging Face datasets.
801
+
802
+ **Features:**
803
+ - βœ… Dynamic loading from HF dataset repository
804
+ - βœ… Real-time inference with progress tracking
805
+ - βœ… Ground truth comparison when available
806
+ - βœ… Detailed action predictions with confidence scores
807
+ """)
808
+
809
+ with gr.Row():
810
+ with gr.Column(scale=1):
811
+ video_dropdown = gr.Dropdown(
812
+ label="Select Video",
813
+ choices=available_videos,
814
+ value=available_videos[0] if available_videos and "Error" not in available_videos[0] else None,
815
+ info="Videos loaded from Hugging Face dataset"
816
+ )
817
+
818
+ split_dropdown = gr.Dropdown(
819
+ label="Split Number",
820
+ choices=["1", "2", "3"],
821
+ value="1",
822
+ info="Dataset split for annotations"
823
+ )
824
+
825
+ refresh_btn = gr.Button("πŸ”„ Refresh Video List", variant="secondary")
826
+ submit_btn = gr.Button("πŸš€ Run Action Localization", variant="primary")
827
+
828
+ with gr.Column(scale=2):
829
+ output_text = gr.Textbox(
830
+ label="Action Predictions",
831
+ lines=25,
832
+ max_lines=50,
833
+ show_copy_button=True,
834
+ placeholder="Results will appear here..."
835
+ )
836
+
837
+ gr.Markdown(f"""
838
+ **Dataset Source:** [{HF_DATASET_REPO}](https://huggingface.co/datasets/{HF_DATASET_REPO})
839
 
840
  **Requirements:**
841
+ - Model checkpoint: `01_ckp_best.pth.tar` in repository root
842
+ - Video features: Automatically downloaded from HF dataset
843
+ """)
844
+
845
+ # Event handlers
846
+ refresh_btn.click(
847
+ fn=lambda: gr.Dropdown(choices=get_available_videos_from_hf()),
848
+ outputs=video_dropdown
849
+ )
850
+
851
+ submit_btn.click(
852
+ fn=process_video,
853
+ inputs=[video_dropdown, split_dropdown],
854
+ outputs=output_text
855
+ )
856
+
857
+ # Example
858
+ if available_videos and "Error" not in available_videos[0]:
859
+ gr.Examples(
860
+ examples=[[available_videos[0], "1"]],
861
+ inputs=[video_dropdown, split_dropdown],
862
+ fn=process_video,
863
+ outputs=output_text,
864
+ cache_examples=False
865
+ )
866
 
867
  if __name__ == '__main__':
868
+ print(f"Available videos: {len(available_videos)}")
869
  print(f"Using device: {device}")
870
+ print(f"HF Dataset: {HF_DATASET_REPO}")
871
  iface.launch(
872
  server_name="0.0.0.0",
873
+ server_port=7860,
874
  share=False
875
  )