Darknsu committed (verified)
Commit a51395e · 1 Parent(s): f7a6c33

Upload 24 files

.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Sakib Reza
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,91 @@
- ---
- title: HATTAL
- emoji: 🐠
- colorFrom: indigo
- colorTo: green
- sdk: gradio
- sdk_version: 5.34.1
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # HAT: History-Augmented Anchor Transformer for Online Temporal Action Localization (ECCV 2024)
+ ### Sakib Reza, Yuexi Zhang, Mohsen Moghaddam, Octavia Camps
+ #### Northeastern University, Boston, United States
+ {reza.s,zhang.yuex,mohsen,o.camps}@northeastern.edu
+
+ ## [arXiv Preprint](https://arxiv.org/abs/2408.06437)
+
+ ## Updates
+ - Aug 22, 2024 - EGTEA pre-extracted features and config files for other datasets added
+ - Aug 14, 2024 - arXiv preprint added
+ - July 7, 2024 - Initial code release
+
+ ## Installation
+
+ ### Prerequisites
+ - Ubuntu 20.04
+ - Python 3.10.9
+ - CUDA 12.0
+
+ ### Requirements
+ - pytorch==2.0.0
+ - numpy==1.23.5
+ - h5py==3.9.0
+ - ...
+
+ To install all required libraries, execute the pip command below.
+ ```
+ pip install -r requirement.txt
+ ```
+
+ ## Training
+
+ ### Input Features
+ The Kinetics-pretrained I3D features for the EGTEA dataset can be downloaded from this [GDrive link](https://drive.google.com/drive/folders/1Zj1B2UZnjPgLrylhKOfu7m_9rkQFa14T?usp=sharing).
+ Place the downloaded files in `data/`.
+ Features for the other datasets are available from the following links:
+ - [EPIC-Kitchens 100](https://github.com/happyharrycn/actionformer_release)
+ - [THUMOS'14](https://github.com/YHKimGithub/OAT-OSN/)
+ - [MUSES](https://songbai.site/muses/)
+
+ ### Config Files
+ The configuration files for EGTEA are already provided in the repository. Configs for the other datasets can be downloaded from this [GDrive link](https://drive.google.com/drive/folders/19__GnM2HZCCDshED9kadsLNAI9XBvrFd?usp=sharing).
+
+ ### Training Model
+ To train the main HAT model, execute the command below.
+ ```
+ python main.py --mode=train --split=[split #]*
+ ```
+ For example:
+ ```
+ !python main.py --mode=train --batch_size=256 --epoch=1
+ ```
+ *If the dataset has splits (e.g., EGTEA has 4 splits)
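+
+ A minimal sketch for training all EGTEA splits in sequence (assuming splits are numbered 1-4; check the dataset config/opts file for the exact numbering):
+ ```
+ for s in 1 2 3 4; do
+     python main.py --mode=train --split=$s
+ done
+ ```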
+
+ To train the post-processing network (OSN), execute the commands below.
+ ```
+ python supnet.py --mode=make --inference_subset=train --split=[split #]
+ python supnet.py --mode=make --inference_subset=test --split=[split #]
+ python supnet.py --mode=train --split=[split #]
+ ```
+
+ ## Testing
+ To test HAT, execute the command below.
+ ```
+ python main.py --mode=test --split=[split #]
+ ```
+ For example:
+ ```
+ !python main.py --mode=test --batch_size=256 --epoch=1
+ ```
+
+ ## Citing HAT
+ Please cite our paper in your publications if it helps your research:
+
+ ```BibTeX
+ @inproceedings{reza2022history,
+ title={HAT: History-Augmented Anchor Transformer for Online Temporal Action Localization},
+ author={Reza, Sakib and Zhang, Yuexi and Moghaddam, Mohsen and Camps, Octavia},
+ booktitle={European Conference on Computer Vision},
+ pages={XXX--XXX},
+ year={2024},
+ organization={Springer}
+ }
+ ```
+
+ ## Acknowledgment
+ This repository is built on top of the repository of the baseline work [OAT-OSN](https://github.com/YHKimGithub/OAT-OSN/).
annotated video generate main.py ADDED
@@ -0,0 +1,924 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet, calc_iou
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.patches as patches
25
+ import cv2
26
+ from typing import List, Dict, Optional
27
+
28
+ # Visualization Configuration
32
+ VIS_CONFIG = {
33
+ 'frame_interval': 1.0,
34
+ 'max_frames': 20,
35
+ 'save_dir': './output/visualizations',
36
+ 'video_save_dir': './output/videos',
37
+ 'gt_color': '#1f77b4', # Blue for ground truth (RGB: 31, 119, 180)
38
+ 'pred_color': '#ff7f0e', # Orange for predictions (RGB: 255, 127, 14)
39
+ 'fontsize_label': 10,
40
+ 'fontsize_title': 14,
41
+ 'frame_highlight_both': 'green',
42
+ 'frame_highlight_gt': 'red',
43
+ 'frame_highlight_pred': 'black',
44
+ 'iou_threshold': 0.3,
45
+ 'frame_scale_factor': 0.8,
46
+ 'video_text_scale': 0.5, # Smaller text size
47
+ 'video_gt_text_color': (180, 119, 31), # BGR for OpenCV
48
+ 'video_pred_text_color': (14, 127, 255), # BGR for OpenCV
49
+ 'video_text_thickness': 1, # Thinner for smaller text
50
+ 'video_font_path': './fonts/Roboto-Regular.ttf', # Path to TrueType font
51
+ 'video_pred_text_y': 0.45, # Fraction of frame height (slightly above middle)
52
+ 'video_gt_text_y': 0.55, # Fraction of frame height (slightly below middle)
53
+ }
54
+
55
+ from PIL import Image, ImageDraw, ImageFont
56
+ import warnings
57
+
58
+ def annotate_video_with_actions(
59
+ video_id: str,
60
+ pred_segments: List[Dict],
61
+ gt_segments: List[Dict],
62
+ video_path: str,
63
+ save_dir: str = VIS_CONFIG['video_save_dir'],
64
+ text_scale: float = VIS_CONFIG['video_text_scale'],
65
+ gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
66
+ pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
67
+ text_thickness: int = VIS_CONFIG['video_text_thickness']
68
+ ) -> None:
69
+ """
70
+ Annotate a video with predicted and ground truth action labels overlaid on frames using a stylish font.
71
+
72
+ Args:
73
+ video_id: Video identifier (e.g., 'my_video').
74
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
75
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
76
+ video_path: Path to the input video file.
77
+ save_dir: Directory to save the annotated video.
78
+ text_scale: Scale factor for text size.
79
+ gt_text_color: BGR color tuple for ground truth text.
80
+ pred_text_color: BGR color tuple for predicted text.
81
+ text_thickness: Thickness of text strokes.
82
+ """
83
+ os.makedirs(save_dir, exist_ok=True)
84
+
85
+ # Open input video
86
+ cap = cv2.VideoCapture(video_path)
87
+ if not cap.isOpened():
88
+ print(f"Error: Could not open video {video_path}. Skipping video annotation.")
89
+ return
90
+
91
+ # Get video properties
92
+ fps = cap.get(cv2.CAP_PROP_FPS)
93
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
94
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
95
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
96
+ print(f"Input Video: FPS={fps:.2f}, Resolution={frame_width}x{frame_height}, Total Frames={total_frames}")
97
+
98
+ # Define output video
99
+ output_path = os.path.join(save_dir, f"annotated_{video_id}_{opt['exp']}.avi")
100
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
101
+ out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
102
+
103
+ if not out.isOpened():
104
+ print(f"Error: Could not initialize video writer for {output_path}. Check codec availability.")
105
+ cap.release()
106
+ return
107
+
108
+ # Load font
109
+ font_path = VIS_CONFIG['video_font_path']
110
+ font_size = int(20 * text_scale) # Base size adjusted by scale
111
+ try:
112
+ font = ImageFont.truetype(font_path, font_size)
113
+ except IOError:
114
+ print(f"Warning: Font {font_path} not found. Falling back to OpenCV default font.")
115
+ font = None
116
+
117
+ frame_idx = 0
118
+ written_frames = 0
119
+ while cap.isOpened():
120
+ ret, frame = cap.read()
121
+ if not ret:
122
+ break
123
+
124
+ # Calculate current timestamp
125
+ timestamp = frame_idx / fps
126
+
127
+ # Find active GT actions
128
+ gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
129
+ gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else "GT: None"
130
+
131
+ # Find active predicted actions
132
+ pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
133
+ pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else "Pred: None"
134
+
135
+ if font:
136
+ # Convert frame to PIL image
137
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
138
+ pil_image = Image.fromarray(frame_rgb)
139
+ draw = ImageDraw.Draw(pil_image)
140
+
141
+ # Draw GT text (left-middle, slightly below center)
142
+ gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
143
+ draw.text((10, gt_y), gt_text, font=font, fill=(gt_text_color[2], gt_text_color[1], gt_text_color[0]))
144
+
145
+ # Draw predicted text (left-middle, slightly above center)
146
+ pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
147
+ draw.text((10, pred_y), pred_text, font=font, fill=(pred_text_color[2], pred_text_color[1], pred_text_color[0]))
148
+
149
+ # Convert back to OpenCV frame
150
+ frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
151
+ else:
152
+ # Fallback to OpenCV font
153
+ cv2.putText(
154
+ frame,
155
+ gt_text,
156
+ (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
157
+ cv2.FONT_HERSHEY_DUPLEX, # Slightly more stylish than SIMPLEX
158
+ text_scale,
159
+ gt_text_color,
160
+ text_thickness,
161
+ cv2.LINE_AA
162
+ )
163
+ cv2.putText(
164
+ frame,
165
+ pred_text,
166
+ (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
167
+ cv2.FONT_HERSHEY_DUPLEX,
168
+ text_scale,
169
+ pred_text_color,
170
+ text_thickness,
171
+ cv2.LINE_AA
172
+ )
173
+
174
+ # Write frame to output video
175
+ out.write(frame)
176
+ written_frames += 1
177
+ frame_idx += 1
178
+
179
+ # Release resources
180
+ cap.release()
181
+ out.release()
182
+ print(f"[✅ Saved Annotated Video]: {output_path}, Written Frames={written_frames}")
183
+ print("Note: If .avi is not playable, convert to .mp4 using FFmpeg:")
184
+ print(f"ffmpeg -i {output_path} -vcodec libx264 -acodec aac {output_path.replace('.avi', '.mp4')}")
185
+
186
+ def visualize_action_lengths(
187
+ video_id: str,
188
+ pred_segments: List[Dict],
189
+ gt_segments: List[Dict],
190
+ video_path: str,
191
+ duration: float,
192
+ save_dir: str = VIS_CONFIG['save_dir'],
193
+ frame_interval: float = VIS_CONFIG['frame_interval']
194
+ ) -> None:
195
+ """
196
+ Generate a visualization plot comparing ground truth and predicted action lengths with video frames.
197
+
198
+ Args:
199
+ video_id: Video identifier (e.g., 'my_video').
200
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
201
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
202
+ video_path: Path to the input video file.
203
+ duration: Total duration of the video in seconds.
204
+ save_dir: Directory to save the output image.
205
+ frame_interval: Time interval between sampled frames (seconds).
206
+ """
207
+ os.makedirs(save_dir, exist_ok=True)
208
+
209
+ # Calculate frame sampling times
210
+ num_frames = int(duration / frame_interval) + 1
211
+ if num_frames > VIS_CONFIG['max_frames']:
212
+ frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
213
+ num_frames = VIS_CONFIG['max_frames']
214
+ print(f"Warning: Video duration ({duration:.1f}s) requires {num_frames} frames. Adjusted frame_interval to {frame_interval:.2f}s.")
215
+
216
+ frame_times = np.linspace(0, duration, num_frames, endpoint=False)
217
+
218
+ # Load video frames
219
+ frames = []
220
+ cap = cv2.VideoCapture(video_path)
221
+ if not cap.isOpened():
222
+ print(f"Warning: Could not open video {video_path}. Using placeholder frames.")
223
+ frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
224
+ else:
225
+ for t in frame_times:
226
+ cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
227
+ ret, frame = cap.read()
228
+ if ret:
229
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
230
+ # Resize frame to reduce memory usage
231
+ frame = cv2.resize(frame, (int(frame.shape[1] * 0.5), int(frame.shape[0] * 0.5)))
232
+ frames.append(frame)
233
+ else:
234
+ frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
235
+ cap.release()
236
+
237
+ # Initialize figure
238
+ fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
239
+ gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
240
+
241
+ # Plot frames
242
+ for i, (t, frame) in enumerate(zip(frame_times, frames)):
243
+ ax = fig.add_subplot(gs[0, i])
244
+
245
+ # Check if frame falls within GT or predicted segments
246
+ gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
247
+ pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
248
+
249
+ # Set border color
250
+ border_color = None
251
+ if gt_hit and pred_hit:
252
+ border_color = VIS_CONFIG['frame_highlight_both']
253
+ elif gt_hit:
254
+ border_color = VIS_CONFIG['frame_highlight_gt']
255
+ elif pred_hit:
256
+ border_color = VIS_CONFIG['frame_highlight_pred']
257
+
258
+ ax.imshow(frame)
259
+ ax.axis('off')
260
+ if border_color:
261
+ for spine in ax.spines.values():
262
+ spine.set_edgecolor(border_color)
263
+ spine.set_linewidth(2)
264
+
265
+ ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
266
+ color=border_color if border_color else 'black')
267
+
268
+ # Plot ground truth bar
269
+ ax_gt = fig.add_subplot(gs[1, :])
270
+ ax_gt.set_xlim(0, duration)
271
+ ax_gt.set_ylim(0, 1)
272
+ ax_gt.axis('off')
273
+ ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'],
274
+ va='center', ha='right', weight='bold')
275
+
276
+ for seg in gt_segments:
277
+ start, end = seg['start'], seg['end']
278
+ width = end - start
279
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
280
+ ax_gt.add_patch(patches.Rectangle(
281
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'],
282
+ edgecolor='black', alpha=0.8
283
+ ))
284
+ ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center',
285
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
286
+ ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
287
+ ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')
288
+
289
+ # Plot prediction bar
290
+ ax_pred = fig.add_subplot(gs[2, :])
291
+ ax_pred.set_xlim(0, duration)
292
+ ax_pred.set_ylim(0, 1)
293
+ ax_pred.axis('off')
294
+ ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'],
295
+ va='center', ha='right', weight='bold')
296
+
297
+ for seg in pred_segments:
298
+ start, end = seg['start'], seg['end']
299
+ width = end - start
300
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
301
+ ax_pred.add_patch(patches.Rectangle(
302
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'],
303
+ edgecolor='black', alpha=0.8
304
+ ))
305
+ ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center',
306
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
307
+ ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
308
+ ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')
309
+
310
+ # Save plot
311
+ jpg_path = os.path.join(save_dir, f"viz_{video_id}_{opt['exp']}.png") # Use PNG
312
+ plt.savefig(jpg_path, dpi=100, bbox_inches='tight') # Lower DPI
313
+ print(f"[✅ Saved Visualization]: {jpg_path}")
314
+ plt.close()
315
+
316
+
317
+
318
+ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
319
+ train_loader = torch.utils.data.DataLoader(train_dataset,
320
+ batch_size=opt['batch_size'], shuffle=True,
321
+ num_workers=0, pin_memory=True, drop_last=False)
322
+ epoch_cost = 0
323
+ epoch_cost_cls = 0
324
+ epoch_cost_reg = 0
325
+ epoch_cost_snip = 0
326
+
327
+ total_iter = len(train_dataset) // opt['batch_size']
328
+ cls_loss = MultiCrossEntropyLoss(focal=True)
329
+ snip_loss = MultiCrossEntropyLoss(focal=True)
330
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
331
+ if warmup:
332
+ for g in optimizer.param_groups:
333
+ g['lr'] = n_iter * (opt['lr']) / total_iter
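+ # (Warmup ramps the learning rate linearly from 0 up to opt['lr'] over the iterations of the first epoch.)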
334
+
335
+ act_cls, act_reg, snip_cls = model(input_data.float().cuda())
336
+
337
+ act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
338
+ snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))
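+ # These backward hooks hand the gradients flowing into the logits to each
+ # MultiCrossEntropyLoss via collect_grad; presumably this drives its focal
+ # re-weighting (see loss_func.py for the exact behavior).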
339
+
340
+ cost_reg = 0
341
+ cost_cls = 0
342
+
343
+ loss = cls_loss_func_(cls_loss, cls_label, act_cls)
344
+ cost_cls = loss
345
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
346
+
347
+ loss = regress_loss_func(reg_label, act_reg)
348
+ cost_reg = loss
349
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
350
+
351
+ loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
352
+ cost_snip = loss
353
+ epoch_cost_snip += cost_snip.detach().cpu().numpy()
354
+
355
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
356
+ epoch_cost += cost.detach().cpu().numpy()
357
+
358
+ optimizer.zero_grad()
359
+ cost.backward()
360
+ optimizer.step()
361
+
362
+ return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
363
+
364
+ def eval_one_epoch(opt, model, test_dataset):
365
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, test_dataset)
366
+
367
+ result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
368
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
369
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
370
+ json.dump(output_dict, outfile, indent=2)
371
+ outfile.close()
372
+
373
+ IoUmAP = evaluation_detection(opt, verbose=False)
374
+ IoUmAP_5 = sum(IoUmAP[0:]) / len(IoUmAP[0:])
375
+
376
+ return cls_loss, reg_loss, tot_loss, IoUmAP_5
377
+
378
+ def train(opt):
379
+ writer = SummaryWriter()
380
+ model = MYNET(opt).cuda()
381
+
382
+ rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
383
+ optimizer = optim.Adam([{'params': model.history_unit.parameters(), 'lr': 1e-6}, {'params': rest_of_model_params}], lr=opt["lr"], weight_decay=opt["weight_decay"])
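+ # Two parameter groups: the history unit uses a small fixed learning rate (1e-6), the rest of the model uses opt['lr'].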
384
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["lr_step"])
385
+
386
+ train_dataset = VideoDataSet(opt, subset="train")
387
+ test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
388
+
389
+ warmup = False
390
+
391
+ for n_epoch in range(opt['epoch']):
392
+ if n_epoch >= 1:
393
+ warmup = False
394
+
395
+ n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
396
+
397
+ writer.add_scalars('data/cost', {'train': epoch_cost / (n_iter + 1)}, n_epoch)
398
+ print("training loss(epoch %d): %.03f, cls - %f, reg - %f, snip - %f, lr - %f" % (n_epoch,
399
+ epoch_cost / (n_iter + 1),
400
+ epoch_cost_cls / (n_iter + 1),
401
+ epoch_cost_reg / (n_iter + 1),
402
+ epoch_cost_snip / (n_iter + 1),
403
+ optimizer.param_groups[-1]["lr"]))
404
+
405
+ scheduler.step()
406
+ model.eval()
407
+
408
+ cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
409
+
410
+ writer.add_scalars('data/mAP', {'test': IoUmAP_5}, n_epoch)
411
+ print("testing loss(epoch %d): %.03f, cls - %f, reg - %f, mAP Avg - %f" % (n_epoch, tot_loss, cls_loss, reg_loss, IoUmAP_5))
412
+
413
+ state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
414
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_checkpoint_" + str(n_epoch + 1) + ".pth.tar")
415
+ if IoUmAP_5 > model.best_map:
416
+ model.best_map = IoUmAP_5
417
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_ckp_best.pth.tar")
418
+
419
+ model.train()
420
+
421
+ writer.close()
422
+ return model.best_map
423
+
424
+ def eval_frame(opt, model, dataset):
425
+ test_loader = torch.utils.data.DataLoader(dataset,
426
+ batch_size=opt['batch_size'], shuffle=False,
427
+ num_workers=0, pin_memory=True, drop_last=False)
428
+
429
+ labels_cls = {}
430
+ labels_reg = {}
431
+ output_cls = {}
432
+ output_reg = {}
433
+ for video_name in dataset.video_list:
434
+ labels_cls[video_name] = []
435
+ labels_reg[video_name] = []
436
+ output_cls[video_name] = []
437
+ output_reg[video_name] = []
438
+
439
+ start_time = time.time()
440
+ total_frames = 0
441
+ epoch_cost = 0
442
+ epoch_cost_cls = 0
443
+ epoch_cost_reg = 0
444
+
445
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
446
+ act_cls, act_reg, _ = model(input_data.float().cuda())
447
+ cost_reg = 0
448
+ cost_cls = 0
449
+
450
+ loss = cls_loss_func(cls_label, act_cls)
451
+ cost_cls = loss
452
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
453
+
454
+ loss = regress_loss_func(reg_label, act_reg)
455
+ cost_reg = loss
456
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
457
+
458
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
459
+ epoch_cost += cost.detach().cpu().numpy()
460
+
461
+ act_cls = torch.softmax(act_cls, dim=-1)
462
+
463
+ total_frames += input_data.size(0)
464
+
465
+ for b in range(0, input_data.size(0)):
466
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
467
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
468
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
469
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
470
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
471
+
472
+ end_time = time.time()
473
+ working_time = end_time - start_time
474
+
475
+ for video_name in dataset.video_list:
476
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
477
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
478
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
479
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
480
+
481
+ cls_loss = epoch_cost_cls / n_iter
482
+ reg_loss = epoch_cost_reg / n_iter
483
+ tot_loss = epoch_cost / n_iter
484
+
485
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
486
+
487
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
488
+ result_dict = {}
489
+ proposal_dict = []
490
+
491
+ num_class = opt["num_of_class"]
492
+ unit_size = opt['segment_size']
493
+ threshold = opt['threshold']
494
+ anchors = opt['anchors']
495
+
496
+ for video_name in dataset.video_list:
497
+ duration = dataset.video_len[video_name]
498
+ video_time = float(dataset.video_dict[video_name]["duration"])
499
+ frame_to_time = 100.0 * video_time / duration
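+ # Scale factor for converting feature-frame indices to seconds; applied below as idx * frame_to_time / 100.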
500
+
501
+ for idx in range(0, duration):
502
+ cls_anc = output_cls[video_name][idx]
503
+ reg_anc = output_reg[video_name][idx]
504
+
505
+ proposal_anc_dict = []
506
+ for anc_idx in range(0, len(anchors)):
507
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
508
+
509
+ if len(cls) == 0:
510
+ continue
511
+
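+ # Decode the anchor regression: reg[0] shifts the segment end relative to the current frame index,
+ # reg[1] rescales the anchor length log-linearly (length = anchor * exp(reg[1])), and start = end - length.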
512
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
513
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
514
+ st = ed - length
515
+
516
+ for cidx in range(0, len(cls)):
517
+ label = cls[cidx]
518
+ tmp_dict = {}
519
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
520
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
521
+ tmp_dict["label"] = dataset.label_name[label]
522
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
523
+ proposal_anc_dict.append(tmp_dict)
524
+
525
+ proposal_dict += proposal_anc_dict
526
+
527
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
528
+ result_dict[video_name] = proposal_dict
529
+ proposal_dict = []
530
+
531
+ return result_dict
532
+
533
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
534
+ model = SuppressNet(opt).cuda()
535
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
536
+ base_dict = checkpoint['state_dict']
537
+ model.load_state_dict(base_dict)
538
+ model.eval()
539
+
540
+ result_dict = {}
541
+ proposal_dict = []
542
+
543
+ num_class = opt["num_of_class"]
544
+ unit_size = opt['segment_size']
545
+ threshold = opt['threshold']
546
+ anchors = opt['anchors']
547
+
548
+ for video_name in dataset.video_list:
549
+ duration = dataset.video_len[video_name]
550
+ video_time = float(dataset.video_dict[video_name]["duration"])
551
+ frame_to_time = 100.0 * video_time / duration
552
+ conf_queue = torch.zeros((unit_size, num_class - 1))
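+ # Rolling window of per-class proposal confidences over the last `segment_size` frames; it is shifted
+ # every frame and fed to SuppressNet, whose output gates which proposals are kept.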
553
+
554
+ for idx in range(0, duration):
555
+ cls_anc = output_cls[video_name][idx]
556
+ reg_anc = output_reg[video_name][idx]
557
+
558
+ proposal_anc_dict = []
559
+ for anc_idx in range(0, len(anchors)):
560
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
561
+
562
+ if len(cls) == 0:
563
+ continue
564
+
565
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
566
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
567
+ st = ed - length
568
+
569
+ for cidx in range(0, len(cls)):
570
+ label = cls[cidx]
571
+ tmp_dict = {}
572
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
573
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
574
+ tmp_dict["label"] = dataset.label_name[label]
575
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
576
+ proposal_anc_dict.append(tmp_dict)
577
+
578
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
579
+
580
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
581
+ conf_queue[-1, :] = 0
582
+ for proposal in proposal_anc_dict:
583
+ cls_idx = dataset.label_name.index(proposal['label'])
584
+ conf_queue[-1, cls_idx] = proposal["score"]
585
+
586
+ minput = conf_queue.unsqueeze(0)
587
+ suppress_conf = model(minput.cuda())
588
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
589
+
590
+ for cls in range(0, num_class - 1):
591
+ if suppress_conf[cls] > opt['sup_threshold']:
592
+ for proposal in proposal_anc_dict:
593
+ if proposal['label'] == dataset.label_name[cls]:
594
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
595
+ proposal_dict.append(proposal)
596
+
597
+ result_dict[video_name] = proposal_dict
598
+ proposal_dict = []
599
+
600
+ return result_dict
601
+
602
+ def test_frame(opt, video_name=None):
603
+ model = MYNET(opt).cuda()
604
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
605
+ base_dict = checkpoint['state_dict']
606
+ model.load_state_dict(base_dict)
607
+ model.eval()
608
+
609
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
610
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
611
+
612
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
613
+
614
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
615
+
616
+ for video_name in dataset.video_list:
617
+ o_cls = output_cls[video_name]
618
+ o_reg = output_reg[video_name]
619
+ l_cls = labels_cls[video_name]
620
+ l_reg = labels_reg[video_name]
621
+
622
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
623
+ dset_predcls[:, :] = o_cls[:, :]
624
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
625
+ dset_predreg[:, :] = o_reg[:, :]
626
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
627
+ dset_labelcls[:, :] = l_cls[:, :]
628
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
629
+ dset_labelreg[:, :] = l_reg[:, :]
630
+ outfile.close()
631
+
632
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
633
+ return cls_loss, reg_loss, tot_loss
634
+
635
+ def patch_attention(m):
636
+ forward_orig = m.forward
637
+
638
+ def wrap(*args, **kwargs):
639
+ kwargs["need_weights"] = True
640
+ kwargs["average_attn_weights"] = False
641
+ return forward_orig(*args, **kwargs)
642
+
643
+ m.forward = wrap
644
+
645
+ class SaveOutput:
646
+ def __init__(self):
647
+ self.outputs = []
648
+
649
+ def __call__(self, module, module_in, module_out):
650
+ self.outputs.append(module_out[1])
651
+
652
+ def clear(self):
653
+ self.outputs = []
654
+
655
+ def test(opt, video_name=None):
656
+ model = MYNET(opt).cuda()
657
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
658
+ base_dict = checkpoint['state_dict']
659
+ model.load_state_dict(base_dict)
660
+ model.eval()
661
+
662
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
663
+
664
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
665
+
666
+ if opt["pptype"] == "nms":
667
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
668
+ if opt["pptype"] == "net":
669
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
670
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
671
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
672
+ json.dump(output_dict, outfile, indent=2)
673
+ outfile.close()
674
+
675
+ mAP = evaluation_detection(opt)
676
+
677
+ # Compare predicted and ground truth action lengths
678
+ if video_name:
679
+ print("\nComparing Predicted and Ground Truth Action Lengths for Video:", video_name)
680
+ with open(opt["video_anno"].format(opt["split"]), 'r') as f:
681
+ anno_data = json.load(f)
682
+ gt_annotations = anno_data['database'][video_name]['annotations']
683
+ duration = anno_data['database'][video_name]['duration']
684
+
685
+ gt_segments = []
686
+ for anno in gt_annotations:
687
+ start, end = anno['segment']
688
+ label = anno['label']
689
+ duration_seg = end - start
690
+ gt_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg})
691
+
692
+ pred_segments = []
693
+ for pred in result_dict[video_name]:
694
+ start, end = pred['segment']
695
+ label = pred['label']
696
+ score = pred['score']
697
+ duration_seg = end - start
698
+ pred_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg, 'score': score})
699
+
700
+ # Print comparison table
701
+ matches = []
702
+ iou_threshold = VIS_CONFIG['iou_threshold']
703
+ used_gt_indices = set()
704
+ for pred in pred_segments:
705
+ best_iou = 0
706
+ best_gt_idx = None
707
+ for gt_idx, gt in enumerate(gt_segments):
708
+ if gt_idx in used_gt_indices:
709
+ continue
710
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
711
+ if iou > best_iou and iou >= iou_threshold:
712
+ best_iou = iou
713
+ best_gt_idx = gt_idx
714
+ if best_gt_idx is not None:
715
+ matches.append({
716
+ 'pred': pred,
717
+ 'gt': gt_segments[best_gt_idx],
718
+ 'iou': best_iou
719
+ })
720
+ used_gt_indices.add(best_gt_idx)
721
+ else:
722
+ matches.append({'pred': pred, 'gt': None, 'iou': 0})
723
+
724
+ for gt_idx, gt in enumerate(gt_segments):
725
+ if gt_idx not in used_gt_indices:
726
+ matches.append({'pred': None, 'gt': gt, 'iou': 0})
727
+
728
+ print("\n{:<20} {:<30} {:<30} {:<15} {:<10}".format(
729
+ "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"))
730
+ print("-" * 105)
731
+ for match in matches:
732
+ pred = match['pred']
733
+ gt = match['gt']
734
+ iou = match['iou']
735
+ if pred and gt:
736
+ label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
737
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
738
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
739
+ duration_diff = pred['duration'] - gt['duration']
740
+ print("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}".format(
741
+ label, pred_str, gt_str, duration_diff, iou))
742
+ elif pred:
743
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
744
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
745
+ pred['label'], pred_str, "None", "N/A", iou))
746
+ elif gt:
747
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
748
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
749
+ gt['label'], "None", gt_str, "N/A", iou))
750
+
751
+ # Summarize
752
+ matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
753
+ avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
754
+ avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
755
+ print(f"\nSummary:")
756
+ print(f"- Total Predictions: {len(pred_segments)}")
757
+ print(f"- Total Ground Truth: {len(gt_segments)}")
758
+ print(f"- Matched Segments: {matched_count}")
759
+ print(f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s")
760
+ print(f"- Average IoU (Matched): {avg_iou:.2f}")
761
+
762
+ # Generate static visualization
763
+ video_path = opt.get('video_path', '')
764
+ if os.path.exists(video_path):
765
+ visualize_action_lengths(
766
+ video_id=video_name,
767
+ pred_segments=pred_segments,
768
+ gt_segments=gt_segments,
769
+ video_path=video_path,
770
+ duration=duration
771
+ )
772
+ # Generate annotated video
773
+ annotate_video_with_actions(
774
+ video_id=video_name,
775
+ pred_segments=pred_segments,
776
+ gt_segments=gt_segments,
777
+ video_path=video_path
778
+ )
779
+ else:
780
+ print(f"Warning: Video path {video_path} not found. Skipping visualization and video annotation.")
781
+
782
+ return mAP
783
+
784
+ def test_online(opt, video_name=None):
785
+ model = MYNET(opt).cuda()
786
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
787
+ base_dict = checkpoint['state_dict']
788
+ model.load_state_dict(base_dict)
789
+ model.eval()
790
+
791
+ sup_model = SuppressNet(opt).cuda()
792
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
793
+ base_dict = checkpoint['state_dict']
794
+ sup_model.load_state_dict(base_dict)
795
+ sup_model.eval()
796
+
797
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
798
+ test_loader = torch.utils.data.DataLoader(dataset,
799
+ batch_size=1, shuffle=False,
800
+ num_workers=0, pin_memory=True, drop_last=False)
801
+
802
+ result_dict = {}
803
+ proposal_dict = []
804
+
805
+ num_class = opt["num_of_class"]
806
+ unit_size = opt['segment_size']
807
+ threshold = opt['threshold']
808
+ anchors = opt['anchors']
809
+
810
+ start_time = time.time()
811
+ total_frames = 0
812
+
813
+ for video_name in dataset.video_list:
814
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
815
+ sup_queue = torch.zeros(((unit_size, num_class - 1)))
816
+
817
+ duration = dataset.video_len[video_name]
818
+ video_time = float(dataset.video_dict[video_name]["duration"])
819
+ frame_to_time = 100.0 * video_time / duration
820
+
821
+ for idx in range(0, duration):
822
+ total_frames += 1
823
+ input_queue[:-1, :] = input_queue[1:, :].clone()
824
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
825
+
826
+ minput = input_queue.unsqueeze(0)
827
+ act_cls, act_reg, _ = model(minput.cuda())
828
+ act_cls = torch.softmax(act_cls, dim=-1)
829
+
830
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
831
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
832
+
833
+ proposal_anc_dict = []
834
+ for anc_idx in range(0, len(anchors)):
835
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
836
+
837
+ if len(cls) == 0:
838
+ continue
839
+
840
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
841
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
842
+ st = ed - length
843
+
844
+ for cidx in range(0, len(cls)):
845
+ label = cls[cidx]
846
+ tmp_dict = {}
847
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
848
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
849
+ tmp_dict["label"] = dataset.label_name[label]
850
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
851
+ proposal_anc_dict.append(tmp_dict)
852
+
853
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
854
+
855
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
856
+ sup_queue[-1, :] = 0
857
+ for proposal in proposal_anc_dict:
858
+ cls_idx = dataset.label_name.index(proposal['label'])
859
+ sup_queue[-1, cls_idx] = proposal["score"]
860
+
861
+ minput = sup_queue.unsqueeze(0)
862
+ suppress_conf = sup_model(minput.cuda())
863
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
864
+
865
+ for cls in range(0, num_class - 1):
866
+ if suppress_conf[cls] > opt['sup_threshold']:
867
+ for proposal in proposal_anc_dict:
868
+ if proposal['label'] == dataset.label_name[cls]:
869
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
870
+ proposal_dict.append(proposal)
871
+
872
+ result_dict[video_name] = proposal_dict
873
+ proposal_dict = []
874
+
875
+ end_time = time.time()
876
+ working_time = end_time - start_time
877
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
878
+
879
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
880
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
881
+ json.dump(output_dict, outfile, indent=2)
882
+ outfile.close()
883
+
884
+ mAP = evaluation_detection(opt)
885
+ return mAP
886
+
887
+ def main(opt, video_name=None):
888
+ max_perf = 0
889
+ if not video_name and 'video_name' in opt:
890
+ video_name = opt['video_name']
891
+
892
+ if opt['mode'] == 'train':
893
+ max_perf = train(opt)
894
+ if opt['mode'] == 'test':
895
+ max_perf = test(opt, video_name=video_name)
896
+ if opt['mode'] == 'test_frame':
897
+ max_perf = test_frame(opt, video_name=video_name)
898
+ if opt['mode'] == 'test_online':
899
+ max_perf = test_online(opt, video_name=video_name)
900
+ if opt['mode'] == 'eval':
901
+ max_perf = evaluation_detection(opt)
902
+
903
+ return max_perf
904
+
905
+ if __name__ == '__main__':
906
+ opt = opts.parse_opt()
907
+ opt = vars(opt)
908
+ if not os.path.exists(opt["checkpoint_path"]):
909
+ os.makedirs(opt["checkpoint_path"])
910
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
911
+ json.dump(opt, opt_file)
912
+ opt_file.close()
913
+
914
+ if opt['seed'] >= 0:
915
+ seed = opt['seed']
916
+ torch.manual_seed(seed)
917
+ np.random.seed(seed)
918
+
919
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
920
+
921
+ video_name = opt.get('video_name', None)
922
+ main(opt, video_name=video_name)
923
+ while(opt['wterm']):
924
+ pass
annotated video with bar main.py ADDED
The diff for this file is too large to render. See raw diff
 
dataset.py ADDED
@@ -0,0 +1,533 @@
1
+ import numpy as np
2
+ import h5py
3
+ import json
4
+ import torch
5
+ import torch.utils.data as data
6
+ import os
7
+ import pickle
8
+ from multiprocessing import Pool
9
+
10
+ def load_json(file):
11
+ with open(file) as json_file:
12
+ data = json.load(json_file)
13
+ return data
14
+
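+ # calc_iou: segments are encoded as [end, length] pairs (e.g. [10.0, 4.0] spans 6.0-10.0).
+ # Returns intersection over union of the two spans, with the union length floored at 1;
+ # the value goes negative when the segments do not overlap.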
15
+ def calc_iou(a, b):
16
+ st = a[0] - a[1]
17
+ ed = a[0]
18
+ target_st = b[0] - b[1]
19
+ target_ed = b[0]
20
+ sst = min(st, target_st)
21
+ led = max(ed, target_ed)
22
+ lst = max(st, target_st)
23
+ sed = min(ed, target_ed)
24
+ iou = (sed - lst) / max(led - sst, 1)
25
+ return iou
26
+
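+ # box_include: returns True when the target segment strictly contains y
+ # (both segments given as [end, length] pairs), i.e. the candidate lies entirely inside the target action.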
27
+ def box_include(y, target):
28
+ st = y[0] - y[1]
29
+ ed = y[0]
30
+ target_st = target[0] - target[1]
31
+ target_ed = target[0]
32
+ detection_point = target_st
33
+ if ed > detection_point and target_st < st and target_ed > ed:
34
+ return True
35
+ return False
36
+
37
+ class VideoDataSet(data.Dataset):
38
+ def __init__(self, opt, subset="train", video_name=None):
39
+ self.subset = subset
40
+ self.mode = opt["mode"]
41
+ self.predefined_fps = opt["predefined_fps"]
42
+ self.video_anno_path = opt["video_anno"].format(opt["split"])
43
+ self.video_len_path = opt["video_len_file"].format(self.subset + '_' + opt["setup"])
44
+ self.num_of_class = opt["num_of_class"]
45
+ self.segment_size = opt["segment_size"]
46
+ self.label_name = []
47
+ self.match_score = {}
48
+ self.match_score_end = {}
49
+ self.match_length = {}
50
+ self.gt_action = {}
51
+ self.cls_label = {}
52
+ self.reg_label = {}
53
+ self.snip_label = {}
54
+ self.inputs = []
55
+ self.inputs_all = []
56
+ self.data_rescale = opt["data_rescale"]
57
+ self.anchors = opt["anchors"]
58
+ self.pos_threshold = opt["pos_threshold"]
59
+ self.single_video_name = video_name
60
+
61
+ self._getDatasetDict()
62
+ self._loadFeaturelen(opt)
63
+ self._getMatchScore()
64
+ self._makeInputSeq()
65
+ self._loadPropLabel(opt['proposal_label_file'].format(self.subset + '_' + opt["setup"]))
66
+
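+ # Pre-extracted features can be stored in several formats (h5, pickle, npz, npz_i3d, pt); each branch
+ # below loads the per-video RGB features and, where available, the flow features for the split.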
67
+ if self.subset == "train":
68
+ if opt['data_format'] == "h5":
69
+ feature_rgb_file = h5py.File(opt["video_feature_rgb_train"], 'r')
70
+ self.feature_rgb_file = {}
71
+ keys = self.video_list
72
+ for vidx in range(len(keys)):
73
+ if keys[vidx] not in feature_rgb_file:
74
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_rgb_train']}")
75
+ self.feature_rgb_file[keys[vidx]] = np.array(feature_rgb_file[keys[vidx]][:])
76
+ if opt['rgb_only']:
77
+ self.feature_flow_file = None
78
+ else:
79
+ self.feature_flow_file = {}
80
+ feature_flow_file = h5py.File(opt["video_feature_flow_train"], 'r')
81
+ for vidx in range(len(keys)):
82
+ if keys[vidx] not in feature_flow_file:
83
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_flow_train']}")
84
+ self.feature_flow_file[keys[vidx]] = np.array(feature_flow_file[keys[vidx]][:])
85
+ elif opt['data_format'] == "pickle":
86
+ feature_All = pickle.load(open(opt["video_feature_all_train"], 'rb'))
87
+ self.feature_rgb_file = {}
88
+ self.feature_flow_file = {}
89
+ keys = self.video_list
90
+ for vidx in range(len(keys)):
91
+ if keys[vidx] not in feature_All:
92
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_all_train']}")
93
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
94
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
95
+ elif opt['data_format'] == "npz":
96
+ feature_All = {}
97
+ self.feature_rgb_file = {}
98
+ self.feature_flow_file = {}
99
+ for file in self.video_list:
100
+ feature_path = opt["video_feature_all_train"] + file + '.npz'
101
+ if not os.path.exists(feature_path):
102
+ raise ValueError(f"Feature file {feature_path} not found")
103
+ feature_All[file] = np.load(feature_path)['feats']
104
+ keys = self.video_list
105
+ for vidx in range(len(keys)):
106
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
107
+ self.feature_flow_file = None
108
+ elif opt['data_format'] == "npz_i3d":
109
+ feature_All = {}
110
+ self.feature_rgb_file = {}
111
+ self.feature_flow_file = {}
112
+ for file in self.video_list:
113
+ feature_path = opt["video_feature_all_train"] + file + '.npz'
114
+ if not os.path.exists(feature_path):
115
+ raise ValueError(f"Feature file {feature_path} not found")
116
+ feature_All[file] = np.load(feature_path)
117
+ keys = self.video_list
118
+ for vidx in range(len(keys)):
119
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
120
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
121
+ elif opt['data_format'] == "pt":
122
+ feature_All = {}
123
+ self.feature_rgb_file = {}
124
+ self.feature_flow_file = {}
125
+ for file in self.video_list:
126
+ feature_path = opt["video_feature_all_train"] + file + '.pt'
127
+ if not os.path.exists(feature_path):
128
+ raise ValueError(f"Feature file {feature_path} not found")
129
+ feature_All[file] = torch.load(feature_path)
130
+ keys = self.video_list
131
+ for vidx in range(len(keys)):
132
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
133
+ self.feature_flow_file = None
134
+ else:
135
+ if opt['data_format'] == "h5":
136
+ feature_rgb_file = h5py.File(opt["video_feature_rgb_test"], 'r')
137
+ self.feature_rgb_file = {}
138
+ keys = self.video_list
139
+ for vidx in range(len(keys)):
140
+ if keys[vidx] not in feature_rgb_file:
141
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_rgb_test']}")
142
+ self.feature_rgb_file[keys[vidx]] = np.array(feature_rgb_file[keys[vidx]][:])
143
+ if opt['rgb_only']:
144
+ self.feature_flow_file = None
145
+ else:
146
+ self.feature_flow_file = {}
147
+ feature_flow_file = h5py.File(opt["video_feature_flow_test"], 'r')
148
+ for vidx in range(len(keys)):
149
+ if keys[vidx] not in feature_flow_file:
150
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_flow_test']}")
151
+ self.feature_flow_file[keys[vidx]] = np.array(feature_flow_file[keys[vidx]][:])
152
+ elif opt['data_format'] == "pickle":
153
+ feature_All = pickle.load(open(opt["video_feature_all_test"], 'rb'))
154
+ self.feature_rgb_file = {}
155
+ self.feature_flow_file = {}
156
+ keys = self.video_list
157
+ for vidx in range(len(keys)):
158
+ if keys[vidx] not in feature_All:
159
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_all_test']}")
160
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
161
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
162
+ elif opt['data_format'] == "npz":
163
+ feature_All = {}
164
+ self.feature_rgb_file = {}
165
+ self.feature_flow_file = {}
166
+ for file in self.video_list:
167
+ feature_path = opt["video_feature_all_test"] + file + '.npz'
168
+ if not os.path.exists(feature_path):
169
+ raise ValueError(f"Feature file {feature_path} not found")
170
+ feature_All[file] = np.load(feature_path)['feats']
171
+ keys = self.video_list
172
+ for vidx in range(len(keys)):
173
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
174
+ self.feature_flow_file = None
175
+ elif opt['data_format'] == "npz_i3d":
176
+ feature_All = {}
177
+ self.feature_rgb_file = {}
178
+ self.feature_flow_file = {}
179
+ for file in self.video_list:
180
+ feature_path = opt["video_feature_all_test"] + file + '.npz'
181
+ if not os.path.exists(feature_path):
182
+ raise ValueError(f"Feature file {feature_path} not found")
183
+ feature_All[file] = np.load(feature_path)
184
+ keys = self.video_list
185
+ for vidx in range(len(keys)):
186
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
187
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
188
+ elif opt['data_format'] == "pt":
189
+ feature_All = {}
190
+ self.feature_rgb_file = {}
191
+ self.feature_flow_file = {}
192
+ for file in self.video_list:
193
+ feature_path = opt["video_feature_all_test"] + file + '.pt'
194
+ if not os.path.exists(feature_path):
195
+ raise ValueError(f"Feature file {feature_path} not found")
196
+ feature_All[file] = torch.load(feature_path)
197
+ keys = self.video_list
198
+ for vidx in range(len(keys)):
199
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
200
+ self.feature_flow_file = None
201
+
202
+ def _loadFeaturelen(self, opt):
203
+ if os.path.exists(self.video_len_path):
204
+ self.video_len = load_json(self.video_len_path)
205
+ return
206
+
207
+ self.video_len = {}
208
+ if self.subset == "train":
209
+ if opt['data_format'] == "h5":
210
+ feature_file = h5py.File(opt["video_feature_rgb_train"], 'r')
211
+ elif opt['data_format'] == "pickle":
212
+ feature_file = pickle.load(open(opt["video_feature_all_train"], 'rb'))
213
+ elif opt['data_format'] == "npz":
214
+ feature_file = {}
215
+ for file in self.video_list:
216
+ feature_file[file] = np.load(opt["video_feature_all_train"] + file + '.npz')['feats']
217
+ elif opt['data_format'] == "npz_i3d":
218
+ feature_file = {}
219
+ for file in self.video_list:
220
+ feature_file[file] = np.load(opt["video_feature_all_train"] + file + '.npz')
221
+ elif opt['data_format'] == "pt":
222
+ feature_file = {}
223
+ for file in self.video_list:
224
+ feature_file[file] = torch.load(opt["video_feature_all_train"] + file + '.pt')
225
+ else:
226
+ if opt['data_format'] == "h5":
227
+ feature_file = h5py.File(opt["video_feature_rgb_test"], 'r')
228
+ elif opt['data_format'] == "pickle":
229
+ feature_file = pickle.load(open(opt["video_feature_all_test"], 'rb'))
230
+ elif opt['data_format'] == "npz":
231
+ feature_file = {}
232
+ for file in self.video_list:
233
+ feature_file[file] = np.load(opt["video_feature_all_test"] + file + '.npz')['feats']
234
+ elif opt['data_format'] == "npz_i3d":
235
+ feature_file = {}
236
+ for file in self.video_list:
237
+ feature_file[file] = np.load(opt["video_feature_all_test"] + file + '.npz')
238
+ elif opt['data_format'] == "pt":
239
+ feature_file = {}
240
+ for file in self.video_list:
241
+ feature_file[file] = torch.load(opt["video_feature_all_test"] + file + '.pt')
242
+
243
+ keys = self.video_list
244
+ if opt['data_format'] == "h5":
245
+ for vidx in range(len(keys)):
246
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
247
+ elif opt['data_format'] == "pickle":
248
+ for vidx in range(len(keys)):
249
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
250
+ elif opt['data_format'] == "npz":
251
+ for vidx in range(len(keys)):
252
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
253
+ elif opt['data_format'] == "npz_i3d":
254
+ for vidx in range(len(keys)):
255
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
256
+ elif opt['data_format'] == "pt":
257
+ for vidx in range(len(keys)):
258
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
259
+ outfile = open(self.video_len_path, "w")
260
+ json.dump(self.video_len, outfile, indent=2)
261
+ outfile.close()
262
+
263
+ def _getDatasetDict(self):
264
+ anno_database = load_json(self.video_anno_path)
265
+ anno_database = anno_database['database']
266
+ self.video_dict = {}
267
+ if self.single_video_name:
268
+ if self.single_video_name in anno_database:
269
+ video_info = anno_database[self.single_video_name]
270
+ video_subset = video_info['subset']
271
+ if self.subset == "full" or self.subset in video_subset:
272
+ self.video_dict[self.single_video_name] = video_info
273
+ for seg in video_info['annotations']:
274
+ if not seg['label'] in self.label_name:
275
+ self.label_name.append(seg['label'])
276
+ else:
277
+ raise ValueError(f"Video {self.single_video_name} not found in annotation database")
278
+ else:
279
+ for video_name in anno_database:
280
+ video_info = anno_database[video_name]
281
+ video_subset = anno_database[video_name]['subset']
282
+ if self.subset == "full" or self.subset in video_subset:
283
+ self.video_dict[video_name] = video_info
284
+ for seg in video_info['annotations']:
285
+ if not seg['label'] in self.label_name:
286
+ self.label_name.append(seg['label'])
287
+
288
+ # Ensure all 22 EGTEA action classes are included
289
+ expected_labels = [
290
+ 'Clean/Wipe', 'Close', 'Compress', 'Crack', 'Cut', 'Divide/Pull Apart',
291
+ 'Dry', 'Inspect/Read', 'Mix', 'Move Around', 'Open', 'Operate', 'Other',
292
+ 'Pour', 'Put', 'Squeeze', 'Take', 'Transfer', 'Turn off', 'Turn on', 'Wash',
293
+ 'Spread' # Assumed missing label; replace with actual label if known
294
+ ]
295
+ for label in expected_labels:
296
+ if label not in self.label_name:
297
+ self.label_name.append(label)
298
+
299
+ self.label_name.sort()
300
+ self.video_list = list(self.video_dict.keys())
301
+ print(f"Labels in dataset.label_name: {self.label_name}")
302
+ print(f"Number of labels: {len(self.label_name)}, Expected: {self.num_of_class-1}")
303
+ print(f"{self.subset} subset video numbers: {len(self.video_list)}")
304
+
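Note: _getDatasetDict above reads an ActivityNet-style annotation file. A minimal sketch of the expected JSON layout, with a hypothetical video id and segment times in seconds:

    import json

    anno = {
        "database": {
            "VIDEO_0001": {
                "subset": "train",
                "duration": 843.2,
                "annotations": [
                    {"segment": [12.4, 15.9], "label": "Take"},
                    {"segment": [16.0, 21.3], "label": "Cut"},
                ],
            }
        }
    }
    json.dump(anno, open("annotations_sketch.json", "w"), indent=2)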
305
+ def _getMatchScore(self):
306
+ self.action_end_count = torch.zeros(2)
307
+ for index in range(0, len(self.video_list)):
308
+ video_name = self.video_list[index]
309
+ video_info = self.video_dict[video_name]
310
+ video_labels = video_info['annotations']
311
+ gt_bbox = []
312
+ gt_edlen = []
313
+
314
+ second_to_frame = self.video_len[video_name] / float(video_info['duration'])
315
+ for j in range(len(video_labels)):
316
+ tmp_info = video_labels[j]
317
+ tmp_start = tmp_info['segment'][0] * second_to_frame
318
+ tmp_end = tmp_info['segment'][1] * second_to_frame
319
+ tmp_label = self.label_name.index(tmp_info['label'])
320
+ gt_bbox.append([tmp_start, tmp_end, tmp_label])
321
+ gt_edlen.append([gt_bbox[-1][1], gt_bbox[-1][1] - gt_bbox[-1][0], tmp_label])
322
+
323
+ gt_bbox = np.array(gt_bbox)
324
+ gt_edlen = np.array(gt_edlen)
325
+ self.gt_action[video_name] = gt_edlen
326
+
327
+ match_score = np.zeros((self.video_len[video_name], self.num_of_class - 1), dtype=np.float32)
328
+ for idx in range(gt_bbox.shape[0]):
329
+ ed = int(gt_bbox[idx, 1]) + 1
330
+ st = int(gt_bbox[idx, 0])
331
+ match_score[st:ed, int(gt_bbox[idx, 2])] = idx + 1
332
+ self.match_score[video_name] = match_score
333
+
334
+ def _makeInputSeq(self):
335
+ data_idx = 0
336
+ for index in range(0, len(self.video_list)):
337
+ video_name = self.video_list[index]
338
+ duration = self.match_score[video_name].shape[0]
339
+ for i in range(1, duration + 1):
340
+ st = i - self.segment_size
341
+ ed = i
342
+ self.inputs_all.append([video_name, st, ed, data_idx])
343
+ data_idx += 1
344
+
345
+ self.inputs = self.inputs_all.copy()
346
+ print(f"{self.subset} subset seg numbers: {len(self.inputs)}")
347
+
348
+ def _makePropLabelUnit(self, i):
349
+ video_name = self.inputs_all[i][0]
350
+ st = self.inputs_all[i][1]
351
+ ed = self.inputs_all[i][2]
352
+ cls_anc = []
353
+ reg_anc = []
354
+
355
+ for j in range(0, len(self.anchors)):
356
+ v1 = np.zeros(self.num_of_class)
357
+ v1[-1] = 1
358
+ v2 = np.zeros(2)
359
+ v2[-1] = -1e3
360
+ y_box = [ed - 1, self.anchors[j]]
361
+
362
+ subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[j], ed)
363
+ idx_list = []
364
+ for ii in range(0, subset_label.shape[0]):
365
+ for jj in range(0, subset_label.shape[1]):
366
+ idx = int(subset_label[ii, jj])
367
+ if idx > 0 and idx - 1 not in idx_list:
368
+ idx_list.append(idx - 1)
369
+
370
+ for idx in idx_list:
371
+ target_box = self.gt_action[video_name][idx]
372
+ cls = int(target_box[2])
373
+ iou = calc_iou(y_box, target_box)
374
+ if iou >= self.pos_threshold or (j == len(self.anchors) - 1 and box_include(y_box, target_box)) or (j == 0 and box_include(target_box, y_box)):
375
+ v1[cls] = 1
376
+ v1[-1] = 0
377
+ v2[0] = 1.0 * (target_box[0] - y_box[0]) / self.anchors[j]
378
+ v2[1] = np.log(1.0 * max(1, target_box[1]) / y_box[1])
379
+
380
+ cls_anc.append(v1)
381
+ reg_anc.append(v2)
382
+
383
+ v0 = np.zeros(self.num_of_class)
384
+ v0[-1] = 1
385
+ segment_size = ed - st
386
+ y_box = [ed - 1, self.anchors[-1]]
387
+ subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[-1], ed)
388
+ idx_list = []
389
+ for ii in range(0, subset_label.shape[0]):
390
+ for jj in range(0, subset_label.shape[1]):
391
+ idx = int(subset_label[ii, jj])
392
+ if idx > 0 and idx - 1 not in idx_list:
393
+ idx_list.append(idx - 1)
394
+
395
+ for idx in idx_list:
396
+ target_box = self.gt_action[video_name][idx]
397
+ cls = int(target_box[2])
398
+ iou = calc_iou(y_box, target_box)
399
+ if iou >= 0:
400
+ v0[cls] = 1
401
+ v0[-1] = 0
402
+
403
+ cls_anc = np.stack(cls_anc, axis=0)
404
+ reg_anc = np.stack(reg_anc, axis=0)
405
+ cls_snip = np.array(v0)
406
+ return cls_anc, reg_anc, cls_snip
407
+
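Note: a compact restatement of the anchor offset encoding built in _makePropLabelUnit and the matching decoding applied in eval_map_nms / eval_map_supnet (main.py). All quantities are in feature frames; anchor_end plays the role of y_box[0], and the inference code uses the current frame index in its place:

    import numpy as np

    def encode_offsets(anchor_end, anchor_len, gt_end, gt_len):
        # v2 assigned above for a positive anchor
        return np.array([(gt_end - anchor_end) / anchor_len,
                         np.log(max(1, gt_len) / anchor_len)])

    def decode_offsets(anchor_end, anchor_len, reg):
        # inverse mapping used when proposals are generated at test time
        end = anchor_end + anchor_len * reg[0]
        length = anchor_len * np.exp(reg[1])
        return end - length, end          # (start, end) in feature frames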
408
+ def _loadPropLabel(self, filename):
409
+ if os.path.exists(filename):
410
+ prop_label_file = h5py.File(filename, 'r')
411
+ self.cls_label = np.array(prop_label_file['cls_label'][:])
412
+ self.reg_label = np.array(prop_label_file['reg_label'][:])
413
+ self.snip_label = np.array(prop_label_file['snip_label'][:])
414
+ prop_label_file.close()
415
+ self.action_frame_count = np.sum(self.cls_label.reshape((-1, self.cls_label.shape[-1])), axis=0)
416
+ self.action_frame_count = torch.Tensor(self.action_frame_count)
417
+ return
418
+
419
+ pool = Pool(os.cpu_count() // 2)
420
+ labels = pool.map(self._makePropLabelUnit, range(0, len(self.inputs_all)))
421
+ pool.close()
422
+ pool.join()
423
+
424
+ cls_label = []
425
+ reg_label = []
426
+ snip_label = []
427
+ for i in range(0, len(labels)):
428
+ cls_label.append(labels[i][0])
429
+ reg_label.append(labels[i][1])
430
+ snip_label.append(labels[i][2])
431
+ self.cls_label = np.stack(cls_label, axis=0)
432
+ self.reg_label = np.stack(reg_label, axis=0)
433
+ self.snip_label = np.stack(snip_label, axis=0)
434
+
435
+ outfile = h5py.File(filename, 'w')
436
+ dset_cls = outfile.create_dataset('/cls_label', self.cls_label.shape, maxshape=self.cls_label.shape, chunks=True, dtype=np.float32)
437
+ dset_cls[:, :] = self.cls_label[:, :]
438
+ dset_reg = outfile.create_dataset('/reg_label', self.reg_label.shape, maxshape=self.reg_label.shape, chunks=True, dtype=np.float32)
439
+ dset_reg[:, :] = self.reg_label[:, :]
440
+ dset_snip = outfile.create_dataset('/snip_label', self.snip_label.shape, maxshape=self.snip_label.shape, chunks=True, dtype=np.float32)
441
+ dset_snip[:, :] = self.snip_label[:, :]
442
+ outfile.close()
443
+
444
+ return
445
+
446
+ def __getitem__(self, index):
447
+ video_name, st, ed, data_idx = self.inputs[index]
448
+ if st >= 0:
449
+ feature = self._get_base_data(video_name, st, ed)
450
+ else:
451
+ feature = self._get_base_data(video_name, 0, ed)
452
+ padfunc2d = torch.nn.ConstantPad2d((0, 0, -st, 0), 0)
453
+ feature = padfunc2d(feature)
454
+
455
+ cls_label = torch.Tensor(self.cls_label[data_idx])
456
+ reg_label = torch.Tensor(self.reg_label[data_idx])
457
+ snip_label = torch.Tensor(self.snip_label[data_idx])
458
+
459
+ return feature, cls_label, reg_label, snip_label
460
+
461
+ def _get_base_data(self, video_name, st, ed):
462
+ feature_rgb = self.feature_rgb_file[video_name]
463
+ feature_rgb = feature_rgb[st:ed, :]
464
+
465
+ if self.feature_flow_file is not None:
466
+ feature_flow = self.feature_flow_file[video_name]
467
+ feature_flow = feature_flow[st:ed, :]
468
+ feature = np.append(feature_rgb, feature_flow, axis=1)
469
+ else:
470
+ feature = feature_rgb
471
+ feature = torch.from_numpy(np.array(feature))
472
+
473
+ return feature
474
+
475
+ def _get_train_label_with_class(self, video_name, st, ed):
476
+ duration = len(self.match_score[video_name])
477
+ st_padding = 0
478
+ ed_padding = 0
479
+ if st < 0:
480
+ st_padding = -st
481
+ st = 0
482
+ if ed > duration:
483
+ ed_padding = ed - duration
484
+ ed = duration
485
+
486
+ match_score = torch.Tensor(self.match_score[video_name][st:ed])
487
+ if st_padding > 0:
488
+ padfunc2d = torch.nn.ConstantPad2d((0, 0, st_padding, 0), 0)
489
+ match_score = padfunc2d(match_score)
490
+ if ed_padding > 0:
491
+ padfunc2d = torch.nn.ConstantPad2d((0, 0, 0, ed_padding), 0)
492
+ match_score = padfunc2d(match_score)
493
+ return match_score
494
+
495
+ def __len__(self):
496
+ return len(self.inputs)
497
+
498
+ def reset_sample(self):
499
+ self.inputs = self.inputs_all.copy()
500
+
501
+ def select_sample(self, idx):
502
+ inputs = [self.inputs_all[i] for i in idx]
503
+ self.inputs = inputs.copy()
504
+ return
505
+
506
+ class SuppressDataSet(data.Dataset):
507
+ def __init__(self, opt, subset="train"):
508
+ self.subset = subset
509
+ self.mode = opt["mode"]
510
+ self.data_file = h5py.File(opt["suppress_label_file"].format(self.subset + "_" + opt['setup']), 'r')
511
+ self.video_list = list(self.data_file.keys())
512
+ self.inputs = []
513
+ for index in range(0, len(self.video_list)):
514
+ video_name = self.video_list[index]
515
+ duration = self.data_file[video_name + '/input'].shape[0]
516
+ for i in range(0, duration):
517
+ self.inputs.append([video_name, i])
518
+
519
+ print(f"{self.subset} subset seg numbers: {len(self.inputs)}")
520
+
521
+ def __getitem__(self, index):
522
+ video_name, idx = self.inputs[index]
523
+
524
+ input_seq = self.data_file[video_name + '/input'][idx]
525
+ label = self.data_file[video_name + '/label'][idx]
526
+
527
+ input_seq = torch.from_numpy(input_seq)
528
+ label = torch.from_numpy(label)
529
+
530
+ return input_seq, label
531
+
532
+ def __len__(self):
533
+ return len(self.inputs)
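Note: a minimal usage sketch for the dataset above. Here `opt` stands for the option dictionary assembled in main.py (its construction lives in opts_egtea and is not shown in this upload); the key names follow the loaders above:

    import torch
    from dataset import VideoDataSet

    def peek_one_batch(opt):
        train_set = VideoDataSet(opt, subset="train")
        loader = torch.utils.data.DataLoader(train_set, batch_size=opt['batch_size'],
                                             shuffle=True, num_workers=0, pin_memory=True)
        feature, cls_label, reg_label, snip_label = next(iter(loader))
        print(feature.shape, cls_label.shape, reg_label.shape, snip_label.shape)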
eval.py ADDED
@@ -0,0 +1,39 @@
1
+ # -*- coding: utf-8 -*-
2
+ import sys
3
+ sys.path.append('./Evaluation')
4
+ from eval_detection_gentime import ANETdetection
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+ def run_evaluation_detection(opt, ground_truth_filename, prediction_filename,
9
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
10
+ subset='validation', verbose=True):
11
+
12
+ anet_detection = ANETdetection(opt, ground_truth_filename, prediction_filename,
13
+ subset=subset, tiou_thresholds=tiou_thresholds,
14
+ verbose=verbose, check_status=False)
15
+ anet_detection.evaluate()
16
+
17
+ ap = anet_detection.ap
18
+ mAP = anet_detection.mAP
19
+ tdiff = anet_detection.tdiff
20
+
21
+ return (mAP, ap, tdiff)
22
+
23
+ def evaluation_detection(opt, verbose=True):
24
+
25
+ mAP, AP, tdiff = run_evaluation_detection(
26
+ opt,
27
+ opt["video_anno"].format(opt["split"]),
28
+ opt["result_file"].format(opt['exp']),
29
+ tiou_thresholds=np.linspace(0.1, 0.50, 5),
30
+ subset=opt['inference_subset'], verbose=verbose)
31
+
32
+ if verbose:
33
+ print('mAP')
34
+ print(mAP)
35
+ print('AEDT')
36
+ print(tdiff)
37
+
38
+ return mAP
39
+
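Note: evaluation_detection only touches a handful of option keys directly (plus whatever ANETdetection reads internally). A sketch with illustrative placeholder values, not the repository defaults:

    opt = {
        "video_anno": "./data/annotations_split{}.json",   # formatted with opt["split"]
        "split": 1,
        "result_file": "./output/result_{}.json",          # formatted with opt["exp"]
        "exp": "egtea_run1",
        "inference_subset": "validation",
    }
    # mAP = evaluation_detection(opt)   # mean AP over tIoU thresholds 0.1 .. 0.5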
feature_extractor.py ADDED
@@ -0,0 +1,29 @@
1
+ from models.i3d.extract_i3d import ExtractI3D
2
+ from utils.utils import build_cfg_path
3
+ from omegaconf import OmegaConf
4
+ import torch
5
+ from tqdm import tqdm
6
+ import os
7
+ import numpy as np
8
+
9
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ if device == 'cuda': print(torch.cuda.get_device_name(0))
11
+ # Select the feature type
12
+ feature_type = 'i3d'
13
+
14
+ # Load and patch the config
15
+ args = OmegaConf.load(build_cfg_path(feature_type))
16
+ args.step_size = 12
17
+ args.flow_type = 'raft' # 'pwc'
18
+
19
+ # Load the model
20
+ extractor = ExtractI3D(args)
21
+
22
+ args.video_paths = os.listdir('./Videos')
23
+
24
+ # Extract features
25
+ for video_path in tqdm(args.video_paths):
26
+ print(f'Extracting for {video_path}')
27
+ feature_dict = extractor.extract('./Videos/'+video_path)
28
+ np.savez('./I3D/'+video_path[:-4]+'.npz', **feature_dict)
29
+ [(print(k), print(v.shape)) for k, v in feature_dict.items()]
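Note: the per-video .npz files written above can be consumed by dataset.py with data_format == "npz_i3d" as long as they expose 'rgb' and 'flow' arrays. A quick sanity check, with a hypothetical file name:

    import numpy as np

    feats = np.load('./I3D/VIDEO_0001.npz')
    assert 'rgb' in feats and 'flow' in feats
    print(feats['rgb'].shape, feats['flow'].shape)   # typically (T, 1024) per stream for I3D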
main.py ADDED
@@ -0,0 +1,1234 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet, calc_iou
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.patches as patches
25
+ import cv2
26
+ from typing import List, Dict, Optional
27
+
28
+ from PIL import Image, ImageDraw, ImageFont
29
+ import warnings
30
+
31
+ # Visualization Configuration (Updated)
32
+ VIS_CONFIG = {
33
+ 'frame_interval': 1.0,
34
+ 'max_frames': 20,
35
+ 'save_dir': './output/visualizations',
36
+ 'video_save_dir': './output/videos',
37
+ 'gt_color': '#1f77b4', # Blue for ground truth (RGB: 31, 119, 180)
38
+ 'pred_color': '#ff7f0e', # Orange for predictions (RGB: 255, 127, 14)
39
+ 'fontsize_label': 10,
40
+ 'fontsize_title': 14,
41
+ 'frame_highlight_both': 'green',
42
+ 'frame_highlight_gt': 'red',
43
+ 'frame_highlight_pred': 'black',
44
+ 'iou_threshold': 0.3,
45
+ 'frame_scale_factor': 0.8,
46
+ 'video_text_scale': 0.5,
47
+ 'video_gt_text_color': (180, 119, 31), # BGR for OpenCV
48
+ 'video_pred_text_color': (14, 127, 255), # BGR for OpenCV
49
+ 'video_text_thickness': 1,
50
+ 'video_font_path': "./data/Poppins ExtraBold Italic 800.ttf",
51
+ 'video_font_fallback': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
52
+ 'video_pred_text_y': 0.45,
53
+ 'video_gt_text_y': 0.55,
54
+ 'video_footer_height': 150, # Increased to accommodate labels
55
+ 'video_gt_bar_y': 0.5,
56
+ 'video_pred_bar_y': 0.8,
57
+ 'video_bar_height': 0.15,
58
+ 'video_bar_text_scale': 0.4,
59
+ 'min_segment_duration': 1.0,
60
+ 'video_frame_text_y': 0.05, # Position for frame number and FPS
61
+ 'video_bar_label_x': 10, # X-position for GT/Pred labels
62
+ 'video_bar_label_scale': 0.5,
63
+ 'scroll_window_duration': 30.0, # Duration of the visible time window (seconds)
64
+ 'scroll_speed': 0.5, # Seconds to advance the window per second of video
65
+ }
66
+
67
+
68
+ def annotate_video_with_actions(
69
+ video_id: str,
70
+ pred_segments: List[Dict],
71
+ gt_segments: List[Dict],
72
+ video_path: str,
73
+ save_dir: str = VIS_CONFIG['video_save_dir'],
74
+ text_scale: float = VIS_CONFIG['video_text_scale'],
75
+ gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
76
+ pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
77
+ text_thickness: int = VIS_CONFIG['video_text_thickness']
78
+ ) -> None:
79
+ """
80
+ Annotate a video with predicted and ground truth action labels, cumulative bars, frame number, and FPS.
81
+ Use fixed 20-second windows with a progress-style bar animation (bars fill up to the current playhead), resetting at each window boundary.
82
+ Assign different colors to different actions for GT and Pred bars, with reduced vertical gap.
83
+
84
+ Args:
85
+ video_id: Video identifier (e.g., 'my_video').
86
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
87
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
88
+ video_path: Path to the input video file.
89
+ save_dir: Directory to save the annotated video.
90
+ text_scale: Scale factor for text size in video.
91
+ gt_text_color: BGR color tuple for ground truth text (fallback).
92
+ pred_text_color: BGR color tuple for predicted text (fallback).
93
+ text_thickness: Thickness of text strokes.
94
+ """
95
+ os.makedirs(save_dir, exist_ok=True)
96
+
97
+ # Open input video
98
+ cap = cv2.VideoCapture(video_path)
99
+ if not cap.isOpened():
100
+ print(f"Error: Could not open video {video_path}. Skipping video annotation.")
101
+ return
102
+
103
+ # Get video properties
104
+ fps = cap.get(cv2.CAP_PROP_FPS)
105
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
106
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
107
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
108
+ duration = total_frames / fps
109
+ print(f"Input Video: FPS={fps:.2f}, Resolution={frame_width}x{frame_height}, Total Frames={total_frames}, Duration={duration:.2f}s")
110
+
111
+ # Define output video with extended height for footer
112
+ footer_height = VIS_CONFIG['video_footer_height']
113
+ output_height = frame_height + footer_height
114
+ output_path = os.path.join(save_dir, f"annotated_{video_id}_{opt['exp']}.avi")
115
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
116
+ out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
117
+
118
+ if not out.isOpened():
119
+ print(f"Error: Could not initialize video writer for {output_path}. Check codec availability.")
120
+ cap.release()
121
+ return
122
+
123
+ # Filter short segments
124
+ min_duration = VIS_CONFIG['min_segment_duration']
125
+ gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
126
+ pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]
127
+ print(f"Filtered Segments: GT={len(gt_segments)}, Pred={len(pred_segments)} (min_duration={min_duration}s)")
128
+
129
+ # Create color mapping for actions
130
+ action_labels = set(seg['label'] for seg in gt_segments).union(seg['label'] for seg in pred_segments)
131
+ # Define a BGR color palette (20 distinct colors)
132
+ color_palette = [
133
+ (255, 0, 0), # Red
134
+ (0, 255, 0), # Green
135
+ (0, 0, 255), # Blue
136
+ (255, 255, 0), # Yellow
137
+ (255, 0, 255), # Magenta
138
+ (0, 255, 255), # Cyan
139
+ (128, 0, 0), # Maroon
140
+ (0, 128, 0), # Dark Green
141
+ (0, 0, 128), # Navy
142
+ (128, 128, 0), # Olive
143
+ (128, 0, 128), # Purple
144
+ (0, 128, 128), # Teal
145
+ (255, 165, 0), # Orange
146
+ (255, 192, 203), # Pink
147
+ (128, 128, 128), # Gray
148
+ (210, 105, 30), # Chocolate
149
+ (100, 149, 237), # Cornflower Blue
150
+ (154, 205, 50), # Yellow Green
151
+ (75, 0, 130), # Indigo
152
+ (245, 245, 220), # Beige
153
+ ]
154
+ action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
155
+ print(f"Action Color Mapping: {action_color_map}")
156
+
157
+ # Convert fallback colors to RGB for PIL
158
+ gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0]) # BGR to RGB
159
+ pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0]) # BGR to RGB
160
+
161
+ # Load font
162
+ font_path = VIS_CONFIG['video_font_path']
163
+ font_fallback = VIS_CONFIG['video_font_fallback']
164
+ font_size = int(20 * text_scale)
165
+ bar_font_size = int(20 * VIS_CONFIG['video_bar_text_scale'])
166
+ font = None
167
+ bar_font = None
168
+ if font_path:
169
+ try:
170
+ font = ImageFont.truetype(font_path, font_size)
171
+ bar_font = ImageFont.truetype(font_path, bar_font_size)
172
+ print(f"Using font: {font_path}")
173
+ except IOError:
174
+ print(f"Warning: Font {font_path} not found. Trying fallback font.")
175
+ if not font:
176
+ try:
177
+ font = ImageFont.truetype(font_fallback, font_size)
178
+ bar_font = ImageFont.truetype(font_fallback, bar_font_size)
179
+ print(f"Using fallback font: {font_fallback}")
180
+ except IOError:
181
+ print(f"Warning: Fallback font {font_fallback} not found. Using OpenCV default font.")
182
+ font = None
183
+ bar_font = None
184
+
185
+ # Fixed window configuration
186
+ window_size = 20.0 # 20-second windows
187
+ num_windows = int(np.ceil(duration / window_size))
188
+
189
+ frame_idx = 0
190
+ written_frames = 0
191
+ while cap.isOpened():
192
+ ret, frame = cap.read()
193
+ if not ret:
194
+ break
195
+
196
+ # Create extended frame with footer
197
+ extended_frame = np.zeros((output_height, frame_width, 3), dtype=np.uint8)
198
+ extended_frame[:frame_height, :, :] = frame
199
+ extended_frame[frame_height:, :, :] = 255 # White footer
200
+
201
+ # Calculate current timestamp
202
+ timestamp = frame_idx / fps
203
+
204
+ # Determine current window
205
+ window_idx = int(timestamp // window_size)
206
+ window_start = window_idx * window_size
207
+ window_end = min(window_start + window_size, duration)
208
+ window_duration = window_end - window_start
209
+ window_timestamp = timestamp - window_start # Relative timestamp within window
210
+
211
+ # Find active GT actions (for text overlay)
212
+ gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
213
+ gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else ""
214
+
215
+ # Find active predicted actions (for text overlay)
216
+ pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
217
+ pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else ""
218
+
219
+ # Draw GT and prediction bars in footer (within current window, using original animation)
220
+ footer_y = frame_height
221
+ gt_bar_y = footer_y + int(0.2 * footer_height) # Reduced gap
222
+ pred_bar_y = footer_y + int(0.5 * footer_height) # Reduced gap
223
+ bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
224
+
225
+ for seg in gt_segments:
226
+ if seg['start'] <= window_end and seg['end'] >= window_start:
227
+ start_t = max(seg['start'], window_start)
228
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
229
+ start_x = int(((start_t - window_start) / window_duration) * frame_width)
230
+ end_x = int(((end_t - window_start) / window_duration) * frame_width)
231
+ if end_x > start_x:
232
+ cv2.rectangle(
233
+ extended_frame,
234
+ (start_x, gt_bar_y),
235
+ (end_x, gt_bar_y + bar_height),
236
+ action_color_map[seg['label']], # Action-specific color
237
+ -1
238
+ )
239
+
240
+ for seg in pred_segments:
241
+ if seg['start'] <= window_end and seg['end'] >= window_start:
242
+ start_t = max(seg['start'], window_start)
243
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
244
+ start_x = int(((start_t - window_start) / window_duration) * frame_width)
245
+ end_x = int(((end_t - window_start) / window_duration) * frame_width)
246
+ if end_x > start_x:
247
+ cv2.rectangle(
248
+ extended_frame,
249
+ (start_x, pred_bar_y),
250
+ (end_x, pred_bar_y + bar_height),
251
+ action_color_map[seg['label']], # Action-specific color
252
+ -1
253
+ )
254
+
255
+ if font:
256
+ # Convert frame to PIL image
257
+ frame_rgb = cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB)
258
+ pil_image = Image.fromarray(frame_rgb)
259
+ draw = ImageDraw.Draw(pil_image)
260
+
261
+ # Draw frame number and FPS at top center
262
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
263
+ frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
264
+ frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
265
+ frame_text_x = (frame_width - frame_text_width) // 2
266
+ draw.text((frame_text_x, 10), frame_info, font=font, fill=(0, 0, 0))
267
+
268
+ # Draw window timestamp range at top of footer
269
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
270
+ window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
271
+ window_text_width = window_text_bbox[2] - window_text_bbox[0]
272
+ window_text_x = (frame_width - window_text_width) // 2
273
+ draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
274
+
275
+ # Draw GT text in video only if there are actions
276
+ if gt_text:
277
+ gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
278
+ draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
279
+
280
+ # Draw predicted text in video only if there are actions
281
+ if pred_text:
282
+ pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
283
+ draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
284
+
285
+ # Draw labels in bars
286
+ for seg in gt_segments:
287
+ if seg['start'] <= window_end and seg['end'] >= window_start:
288
+ label = seg['label'][:8] + '...' if len(seg['label']) > 8 else seg['label']
289
+ start_t = max(seg['start'], window_start)
290
+ end_t = min(seg['end'], window_start + window_timestamp)
291
+ start_x = int(((start_t - window_start) / window_duration) * frame_width)
292
+ end_x = int(((end_t - window_start) / window_duration) * frame_width)
293
+ if end_x - start_x >= 20:
294
+ draw.text(
295
+ ((start_x + end_x) / 2, gt_bar_y + bar_height / 2),
296
+ label,
297
+ font=bar_font,
298
+ fill=(255, 255, 255) # White for readability
299
+ )
300
+ action_color_rgb = (action_color_map[seg['label']][2], action_color_map[seg['label']][1], action_color_map[seg['label']][0])
301
+ draw.text((start_x, gt_bar_y - 10), f"{start_t:.1f}", font=bar_font, fill=action_color_rgb)
302
+ draw.text((end_x, gt_bar_y - 10), f"{end_t:.1f}", font=bar_font, fill=action_color_rgb)
303
+
304
+ for seg in pred_segments:
305
+ if seg['start'] <= window_end and seg['end'] >= window_start:
306
+ label = seg['label'][:8] + '...' if len(seg['label']) > 8 else seg['label']
307
+ start_t = max(seg['start'], window_start)
308
+ end_t = min(seg['end'], window_start + window_timestamp)
309
+ start_x = int(((start_t - window_start) / window_duration) * frame_width)
310
+ end_x = int(((end_t - window_start) / window_duration) * frame_width)
311
+ if end_x - start_x >= 20:
312
+ draw.text(
313
+ ((start_x + end_x) / 2, pred_bar_y + bar_height / 2),
314
+ label,
315
+ font=bar_font,
316
+ fill=(255, 255, 255) # White for readability
317
+ )
318
+ action_color_rgb = (action_color_map[seg['label']][2], action_color_map[seg['label']][1], action_color_map[seg['label']][0])
319
+ draw.text((start_x, pred_bar_y + bar_height + 10), f"{start_t:.1f}", font=bar_font, fill=action_color_rgb)
320
+ draw.text((end_x, pred_bar_y + bar_height + 10), f"{end_t:.1f}", font=bar_font, fill=action_color_rgb)
321
+
322
+ # Draw GT and Pred labels before bars
323
+ draw.text((10, gt_bar_y + bar_height / 2), "GT", font=bar_font, fill=gt_color_rgb)
324
+ draw.text((10, pred_bar_y + bar_height / 2), "Pred", font=bar_font, fill=pred_color_rgb)
325
+
326
+ # Convert back to OpenCV frame
327
+ extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
328
+ else:
329
+ # Fallback to OpenCV font
330
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
331
+ text_size, _ = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_DUPLEX, text_scale, text_thickness)
332
+ frame_text_x = (frame_width - text_size[0]) // 2
333
+ cv2.putText(
334
+ extended_frame,
335
+ frame_info,
336
+ (frame_text_x, 30),
337
+ cv2.FONT_HERSHEY_DUPLEX,
338
+ text_scale,
339
+ (0, 0, 0),
340
+ text_thickness,
341
+ cv2.LINE_AA
342
+ )
343
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
344
+ window_text_size, _ = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
345
+ window_text_x = (frame_width - window_text_size[0]) // 2
346
+ cv2.putText(
347
+ extended_frame,
348
+ window_info,
349
+ (window_text_x, footer_y + 20),
350
+ cv2.FONT_HERSHEY_DUPLEX,
351
+ VIS_CONFIG['video_bar_text_scale'],
352
+ (0, 0, 0),
353
+ 1,
354
+ cv2.LINE_AA
355
+ )
356
+ if gt_text:
357
+ cv2.putText(
358
+ extended_frame,
359
+ gt_text,
360
+ (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
361
+ cv2.FONT_HERSHEY_DUPLEX,
362
+ text_scale,
363
+ gt_text_color,
364
+ text_thickness,
365
+ cv2.LINE_AA
366
+ )
367
+ if pred_text:
368
+ cv2.putText(
369
+ extended_frame,
370
+ pred_text,
371
+ (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
372
+ cv2.FONT_HERSHEY_DUPLEX,
373
+ text_scale,
374
+ pred_text_color,
375
+ text_thickness,
376
+ cv2.LINE_AA
377
+ )
378
+ for seg in gt_segments:
379
+ if seg['start'] <= window_end and seg['end'] >= window_start:
380
+ label = seg['label'][:8] + '...' if len(seg['label']) > 8 else seg['label']
381
+ start_t = max(seg['start'], window_start)
382
+ end_t = min(seg['end'], window_start + window_timestamp)
383
+ start_x = int(((start_t - window_start) / window_duration) * frame_width)
384
+ end_x = int(((end_t - window_start) / window_duration) * frame_width)
385
+ if end_x - start_x >= 20:
386
+ cv2.putText(
387
+ extended_frame,
388
+ label,
389
+ (start_x + (end_x - start_x) // 2 - 20, gt_bar_y + bar_height // 2 + 5),
390
+ cv2.FONT_HERSHEY_DUPLEX,
391
+ VIS_CONFIG['video_bar_text_scale'],
392
+ (255, 255, 255),
393
+ 1,
394
+ cv2.LINE_AA
395
+ )
396
+ cv2.putText(
397
+ extended_frame,
398
+ f"{start_t:.1f}",
399
+ (start_x, gt_bar_y - 5),
400
+ cv2.FONT_HERSHEY_DUPLEX,
401
+ VIS_CONFIG['video_bar_text_scale'],
402
+ action_color_map[seg['label']],
403
+ 1,
404
+ cv2.LINE_AA
405
+ )
406
+ cv2.putText(
407
+ extended_frame,
408
+ f"{end_t:.1f}",
409
+ (end_x, gt_bar_y - 5),
410
+ cv2.FONT_HERSHEY_DUPLEX,
411
+ VIS_CONFIG['video_bar_text_scale'],
412
+ action_color_map[seg['label']],
413
+ 1,
414
+ cv2.LINE_AA
415
+ )
416
+ for seg in pred_segments:
417
+ if seg['start'] <= window_end and seg['end'] >= window_start:
418
+ label = seg['label'][:8] + '...' if len(seg['label']) > 8 else seg['label']
419
+ start_t = max(seg['start'], window_start)
420
+ end_t = min(seg['end'], window_start + window_timestamp)
421
+ start_x = int(((start_t - window_start) / window_duration) * frame_width)
422
+ end_x = int(((end_t - window_start) / window_duration) * frame_width)
423
+ if end_x - start_x >= 20:
424
+ cv2.putText(
425
+ extended_frame,
426
+ label,
427
+ (start_x + (end_x - start_x) // 2 - 20, pred_bar_y + bar_height // 2 + 5),
428
+ cv2.FONT_HERSHEY_DUPLEX,
429
+ VIS_CONFIG['video_bar_text_scale'],
430
+ (255, 255, 255),
431
+ 1,
432
+ cv2.LINE_AA
433
+ )
434
+ cv2.putText(
435
+ extended_frame,
436
+ f"{start_t:.1f}",
437
+ (start_x, pred_bar_y + bar_height + 15),
438
+ cv2.FONT_HERSHEY_DUPLEX,
439
+ VIS_CONFIG['video_bar_text_scale'],
440
+ action_color_map[seg['label']],
441
+ 1,
442
+ cv2.LINE_AA
443
+ )
444
+ cv2.putText(
445
+ extended_frame,
446
+ f"{end_t:.1f}",
447
+ (end_x, pred_bar_y + bar_height + 15),
448
+ cv2.FONT_HERSHEY_DUPLEX,
449
+ VIS_CONFIG['video_bar_text_scale'],
450
+ action_color_map[seg['label']],
451
+ 1,
452
+ cv2.LINE_AA
453
+ )
454
+ cv2.putText(
455
+ extended_frame,
456
+ "GT",
457
+ (10, gt_bar_y + bar_height // 2 + 5),
458
+ cv2.FONT_HERSHEY_DUPLEX,
459
+ VIS_CONFIG['video_bar_text_scale'],
460
+ gt_text_color,
461
+ 1,
462
+ cv2.LINE_AA
463
+ )
464
+ cv2.putText(
465
+ extended_frame,
466
+ "Pred",
467
+ (10, pred_bar_y + bar_height // 2 + 5),
468
+ cv2.FONT_HERSHEY_DUPLEX,
469
+ VIS_CONFIG['video_bar_text_scale'],
470
+ pred_text_color,
471
+ 1,
472
+ cv2.LINE_AA
473
+ )
474
+
475
+ # Write frame to output video
476
+ out.write(extended_frame)
477
+ written_frames += 1
478
+ frame_idx += 1
479
+
480
+ # Release resources
481
+ cap.release()
482
+ out.release()
483
+ print(f"[✅ Saved Annotated Video]: {output_path}, Written Frames={written_frames}")
484
+ print("Note: If .avi is not playable, convert to .mp4 using FFmpeg:")
485
+ print(f"ffmpeg -i {output_path} -vcodec libx264 -acodec aac {output_path.replace('.avi', '.mp4')}")
486
+
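A usage sketch for annotate_video_with_actions; the id, path, and segments below are hypothetical, and the segment dicts mirror those assembled in test() further down this file. Note that the function reads the module-level opt['exp'] when naming the output file, so it is intended to be called after the option dictionary has been built:

    gt = [{'label': 'Cut',  'start': 12.4, 'end': 15.9, 'duration': 3.5},
          {'label': 'Wash', 'start': 20.0, 'end': 27.5, 'duration': 7.5}]
    pred = [{'label': 'Cut', 'start': 12.0, 'end': 16.2, 'duration': 4.2, 'score': 0.81}]
    annotate_video_with_actions('VIDEO_0001', pred, gt,
                                video_path='./Videos/VIDEO_0001.mp4')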
487
+
488
+
489
+
490
+
491
+
492
+
493
+
494
+
495
+
496
+ def visualize_action_lengths(
497
+ video_id: str,
498
+ pred_segments: List[Dict],
499
+ gt_segments: List[Dict],
500
+ video_path: str,
501
+ duration: float,
502
+ save_dir: str = VIS_CONFIG['save_dir'],
503
+ frame_interval: float = VIS_CONFIG['frame_interval']
504
+ ) -> None:
505
+ """
506
+ Generate a visualization plot comparing ground truth and predicted action lengths with video frames.
507
+
508
+ Args:
509
+ video_id: Video identifier (e.g., 'my_video').
510
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
511
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
512
+ video_path: Path to the input video file.
513
+ duration: Total duration of the video in seconds.
514
+ save_dir: Directory to save the output image.
515
+ frame_interval: Time interval between sampled frames (seconds).
516
+ """
517
+ os.makedirs(save_dir, exist_ok=True)
518
+
519
+ # Calculate frame sampling times
520
+ num_frames = int(duration / frame_interval) + 1
521
+ if num_frames > VIS_CONFIG['max_frames']:
522
+ frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
523
+ num_frames = VIS_CONFIG['max_frames']
524
+ print(f"Warning: Video duration ({duration:.1f}s) requires {num_frames} frames. Adjusted frame_interval to {frame_interval:.2f}s.")
525
+
526
+ frame_times = np.linspace(0, duration, num_frames, endpoint=False)
527
+
528
+ # Load video frames
529
+ frames = []
530
+ cap = cv2.VideoCapture(video_path)
531
+ if not cap.isOpened():
532
+ print(f"Warning: Could not open video {video_path}. Using placeholder frames.")
533
+ frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
534
+ else:
535
+ for t in frame_times:
536
+ cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
537
+ ret, frame = cap.read()
538
+ if ret:
539
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
540
+ # Resize frame to reduce memory usage
541
+ frame = cv2.resize(frame, (int(frame.shape[1] * 0.5), int(frame.shape[0] * 0.5)))
542
+ frames.append(frame)
543
+ else:
544
+ frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
545
+ cap.release()
546
+
547
+ # Initialize figure
548
+ fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
549
+ gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
550
+
551
+ # Plot frames
552
+ for i, (t, frame) in enumerate(zip(frame_times, frames)):
553
+ ax = fig.add_subplot(gs[0, i])
554
+
555
+ # Check if frame falls within GT or predicted segments
556
+ gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
557
+ pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
558
+
559
+ # Set border color
560
+ border_color = None
561
+ if gt_hit and pred_hit:
562
+ border_color = VIS_CONFIG['frame_highlight_both']
563
+ elif gt_hit:
564
+ border_color = VIS_CONFIG['frame_highlight_gt']
565
+ elif pred_hit:
566
+ border_color = VIS_CONFIG['frame_highlight_pred']
567
+
568
+ ax.imshow(frame)
569
+ ax.axis('off')
570
+ if border_color:
571
+ for spine in ax.spines.values():
572
+ spine.set_edgecolor(border_color)
573
+ spine.set_linewidth(2)
574
+
575
+ ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
576
+ color=border_color if border_color else 'black')
577
+
578
+ # Plot ground truth bar
579
+ ax_gt = fig.add_subplot(gs[1, :])
580
+ ax_gt.set_xlim(0, duration)
581
+ ax_gt.set_ylim(0, 1)
582
+ ax_gt.axis('off')
583
+ ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'],
584
+ va='center', ha='right', weight='bold')
585
+
586
+ for seg in gt_segments:
587
+ start, end = seg['start'], seg['end']
588
+ width = end - start
589
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
590
+ ax_gt.add_patch(patches.Rectangle(
591
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'],
592
+ edgecolor='black', alpha=0.8
593
+ ))
594
+ ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center',
595
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
596
+ ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
597
+ ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')
598
+
599
+ # Plot prediction bar
600
+ ax_pred = fig.add_subplot(gs[2, :])
601
+ ax_pred.set_xlim(0, duration)
602
+ ax_pred.set_ylim(0, 1)
603
+ ax_pred.axis('off')
604
+ ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'],
605
+ va='center', ha='right', weight='bold')
606
+
607
+ for seg in pred_segments:
608
+ start, end = seg['start'], seg['end']
609
+ width = end - start
610
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
611
+ ax_pred.add_patch(patches.Rectangle(
612
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'],
613
+ edgecolor='black', alpha=0.8
614
+ ))
615
+ ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center',
616
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
617
+ ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
618
+ ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')
619
+
620
+ # Save plot
621
+ jpg_path = os.path.join(save_dir, f"viz_{video_id}_{opt['exp']}.png") # Use PNG
622
+ plt.savefig(jpg_path, dpi=100, bbox_inches='tight') # Lower DPI
623
+ print(f"[✅ Saved Visualization]: {jpg_path}")
624
+ plt.close()
625
+
626
+
627
+
628
+ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
629
+ train_loader = torch.utils.data.DataLoader(train_dataset,
630
+ batch_size=opt['batch_size'], shuffle=True,
631
+ num_workers=0, pin_memory=True, drop_last=False)
632
+ epoch_cost = 0
633
+ epoch_cost_cls = 0
634
+ epoch_cost_reg = 0
635
+ epoch_cost_snip = 0
636
+
637
+ total_iter = len(train_dataset) // opt['batch_size']
638
+ cls_loss = MultiCrossEntropyLoss(focal=True)
639
+ snip_loss = MultiCrossEntropyLoss(focal=True)
640
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
641
+ if warmup:
642
+ for g in optimizer.param_groups:
643
+ g['lr'] = n_iter * (opt['lr']) / total_iter
644
+
645
+ act_cls, act_reg, snip_cls = model(input_data.float().cuda())
646
+
647
+ act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
648
+ snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))
649
+
650
+ cost_reg = 0
651
+ cost_cls = 0
652
+
653
+ loss = cls_loss_func_(cls_loss, cls_label, act_cls)
654
+ cost_cls = loss
655
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
656
+
657
+ loss = regress_loss_func(reg_label, act_reg)
658
+ cost_reg = loss
659
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
660
+
661
+ loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
662
+ cost_snip = loss
663
+ epoch_cost_snip += cost_snip.detach().cpu().numpy()
664
+
665
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
666
+ epoch_cost += cost.detach().cpu().numpy()
667
+
668
+ optimizer.zero_grad()
669
+ cost.backward()
670
+ optimizer.step()
671
+
672
+ return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
673
+
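For reference, the per-iteration objective assembled above and the warmup schedule are

    cost = alpha * L_cls + beta * L_reg + gamma * L_snip    (focal multi-class CE for L_cls and L_snip)
    lr(n_iter) = n_iter * opt['lr'] / total_iter            (warmup branch only)

Note that train() below initializes warmup to False and never enables it, so the linear ramp is inactive unless that flag is changed.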
674
+ def eval_one_epoch(opt, model, test_dataset):
675
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, test_dataset)
676
+
677
+ result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
678
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
679
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
680
+ json.dump(output_dict, outfile, indent=2)
681
+ outfile.close()
682
+
683
+ IoUmAP = evaluation_detection(opt, verbose=False)
684
+ IoUmAP_5 = sum(IoUmAP) / len(IoUmAP)
685
+
686
+ return cls_loss, reg_loss, tot_loss, IoUmAP_5
687
+
688
+ def train(opt):
689
+ writer = SummaryWriter()
690
+ model = MYNET(opt).cuda()
691
+
692
+ rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
693
+ optimizer = optim.Adam([{'params': model.history_unit.parameters(), 'lr': 1e-6}, {'params': rest_of_model_params}], lr=opt["lr"], weight_decay=opt["weight_decay"])
694
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["lr_step"])
695
+
696
+ train_dataset = VideoDataSet(opt, subset="train")
697
+ test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
698
+
699
+ warmup = False
700
+
701
+ for n_epoch in range(opt['epoch']):
702
+ if n_epoch >= 1:
703
+ warmup = False
704
+
705
+ n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
706
+
707
+ writer.add_scalars('data/cost', {'train': epoch_cost / (n_iter + 1)}, n_epoch)
708
+ print("training loss(epoch %d): %.03f, cls - %f, reg - %f, snip - %f, lr - %f" % (n_epoch,
709
+ epoch_cost / (n_iter + 1),
710
+ epoch_cost_cls / (n_iter + 1),
711
+ epoch_cost_reg / (n_iter + 1),
712
+ epoch_cost_snip / (n_iter + 1),
713
+ optimizer.param_groups[-1]["lr"]))
714
+
715
+ scheduler.step()
716
+ model.eval()
717
+
718
+ cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
719
+
720
+ writer.add_scalars('data/mAP', {'test': IoUmAP_5}, n_epoch)
721
+ print("testing loss(epoch %d): %.03f, cls - %f, reg - %f, mAP Avg - %f" % (n_epoch, tot_loss, cls_loss, reg_loss, IoUmAP_5))
722
+
723
+ state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
724
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_checkpoint_" + str(n_epoch + 1) + ".pth.tar")
725
+ if IoUmAP_5 > model.best_map:
726
+ model.best_map = IoUmAP_5
727
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_ckp_best.pth.tar")
728
+
729
+ model.train()
730
+
731
+ writer.close()
732
+ return model.best_map
733
+
734
+ def eval_frame(opt, model, dataset):
735
+ test_loader = torch.utils.data.DataLoader(dataset,
736
+ batch_size=opt['batch_size'], shuffle=False,
737
+ num_workers=0, pin_memory=True, drop_last=False)
738
+
739
+ labels_cls = {}
740
+ labels_reg = {}
741
+ output_cls = {}
742
+ output_reg = {}
743
+ for video_name in dataset.video_list:
744
+ labels_cls[video_name] = []
745
+ labels_reg[video_name] = []
746
+ output_cls[video_name] = []
747
+ output_reg[video_name] = []
748
+
749
+ start_time = time.time()
750
+ total_frames = 0
751
+ epoch_cost = 0
752
+ epoch_cost_cls = 0
753
+ epoch_cost_reg = 0
754
+
755
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
756
+ act_cls, act_reg, _ = model(input_data.float().cuda())
757
+ cost_reg = 0
758
+ cost_cls = 0
759
+
760
+ loss = cls_loss_func(cls_label, act_cls)
761
+ cost_cls = loss
762
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
763
+
764
+ loss = regress_loss_func(reg_label, act_reg)
765
+ cost_reg = loss
766
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
767
+
768
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
769
+ epoch_cost += cost.detach().cpu().numpy()
770
+
771
+ act_cls = torch.softmax(act_cls, dim=-1)
772
+
773
+ total_frames += input_data.size(0)
774
+
775
+ for b in range(0, input_data.size(0)):
776
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
777
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
778
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
779
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
780
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
781
+
782
+ end_time = time.time()
783
+ working_time = end_time - start_time
784
+
785
+ for video_name in dataset.video_list:
786
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
787
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
788
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
789
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
790
+
791
+ cls_loss = epoch_cost_cls / (n_iter + 1)
792
+ reg_loss = epoch_cost_reg / (n_iter + 1)
793
+ tot_loss = epoch_cost / (n_iter + 1)
794
+
795
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
796
+
797
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
798
+ result_dict = {}
799
+ proposal_dict = []
800
+
801
+ num_class = opt["num_of_class"]
802
+ unit_size = opt['segment_size']
803
+ threshold = opt['threshold']
804
+ anchors = opt['anchors']
805
+
806
+ for video_name in dataset.video_list:
807
+ duration = dataset.video_len[video_name]
808
+ video_time = float(dataset.video_dict[video_name]["duration"])
809
+ frame_to_time = 100.0 * video_time / duration
810
+
811
+ for idx in range(0, duration):
812
+ cls_anc = output_cls[video_name][idx]
813
+ reg_anc = output_reg[video_name][idx]
814
+
815
+ proposal_anc_dict = []
816
+ for anc_idx in range(0, len(anchors)):
817
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
818
+
819
+ if len(cls) == 0:
820
+ continue
821
+
822
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
823
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
824
+ st = ed - length
825
+
826
+ for cidx in range(0, len(cls)):
827
+ label = cls[cidx]
828
+ tmp_dict = {}
829
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
830
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
831
+ tmp_dict["label"] = dataset.label_name[label]
832
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
833
+ proposal_anc_dict.append(tmp_dict)
834
+
835
+ proposal_dict += proposal_anc_dict
836
+
837
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
838
+ result_dict[video_name] = proposal_dict
839
+ proposal_dict = []
840
+
841
+ return result_dict
842
+
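Note: the frame-to-seconds conversion used when emitting proposals above reduces to a simple proportion; the factor of 100 carried by frame_to_time cancels out:

    def frame_to_seconds(frame_idx, video_seconds, num_feature_frames):
        # equivalent to frame_idx * frame_to_time / 100.0 in eval_map_nms
        return frame_idx * video_seconds / num_feature_frames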
843
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
844
+ model = SuppressNet(opt).cuda()
845
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
846
+ base_dict = checkpoint['state_dict']
847
+ model.load_state_dict(base_dict)
848
+ model.eval()
849
+
850
+ result_dict = {}
851
+ proposal_dict = []
852
+
853
+ num_class = opt["num_of_class"]
854
+ unit_size = opt['segment_size']
855
+ threshold = opt['threshold']
856
+ anchors = opt['anchors']
857
+
858
+ for video_name in dataset.video_list:
859
+ duration = dataset.video_len[video_name]
860
+ video_time = float(dataset.video_dict[video_name]["duration"])
861
+ frame_to_time = 100.0 * video_time / duration
862
+ conf_queue = torch.zeros((unit_size, num_class - 1))
863
+
864
+ for idx in range(0, duration):
865
+ cls_anc = output_cls[video_name][idx]
866
+ reg_anc = output_reg[video_name][idx]
867
+
868
+ proposal_anc_dict = []
869
+ for anc_idx in range(0, len(anchors)):
870
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
871
+
872
+ if len(cls) == 0:
873
+ continue
874
+
875
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
876
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
877
+ st = ed - length
878
+
879
+ for cidx in range(0, len(cls)):
880
+ label = cls[cidx]
881
+ tmp_dict = {}
882
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
883
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
884
+ tmp_dict["label"] = dataset.label_name[label]
885
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
886
+ proposal_anc_dict.append(tmp_dict)
887
+
888
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
889
+
890
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
891
+ conf_queue[-1, :] = 0
892
+ for proposal in proposal_anc_dict:
893
+ cls_idx = dataset.label_name.index(proposal['label'])
894
+ conf_queue[-1, cls_idx] = proposal["score"]
895
+
896
+ minput = conf_queue.unsqueeze(0)
897
+ suppress_conf = model(minput.cuda())
898
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
899
+
900
+ for cls in range(0, num_class - 1):
901
+ if suppress_conf[cls] > opt['sup_threshold']:
902
+ for proposal in proposal_anc_dict:
903
+ if proposal['label'] == dataset.label_name[cls]:
904
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
905
+ proposal_dict.append(proposal)
906
+
907
+ result_dict[video_name] = proposal_dict
908
+ proposal_dict = []
909
+
910
+ return result_dict
911
+
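Note: a compact sketch of the sliding confidence queue maintained in eval_map_supnet above (one row per recent frame, one column per action class). The sizes here are illustrative; the real ones come from opt['segment_size'] and opt['num_of_class']:

    import torch

    unit_size, num_class = 16, 22                     # illustrative, not the config values
    conf_queue = torch.zeros(unit_size, num_class - 1)

    def push_frame_scores(conf_queue, frame_scores):
        # shift history up by one frame and append the newest per-class scores
        conf_queue[:-1, :] = conf_queue[1:, :].clone()
        conf_queue[-1, :] = frame_scores              # 0 for classes with no surviving proposal
        return conf_queue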
912
+ def test_frame(opt, video_name=None):
913
+ model = MYNET(opt).cuda()
914
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
915
+ base_dict = checkpoint['state_dict']
916
+ model.load_state_dict(base_dict)
917
+ model.eval()
918
+
919
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
920
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
921
+
922
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
923
+
924
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
925
+
926
+ for video_name in dataset.video_list:
927
+ o_cls = output_cls[video_name]
928
+ o_reg = output_reg[video_name]
929
+ l_cls = labels_cls[video_name]
930
+ l_reg = labels_reg[video_name]
931
+
932
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
933
+ dset_predcls[:, :] = o_cls[:, :]
934
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
935
+ dset_predreg[:, :] = o_reg[:, :]
936
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
937
+ dset_labelcls[:, :] = l_cls[:, :]
938
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
939
+ dset_labelreg[:, :] = l_reg[:, :]
940
+ outfile.close()
941
+
942
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
943
+ return cls_loss, reg_loss, tot_loss
944
+
945
+ def patch_attention(m):
946
+ forward_orig = m.forward
947
+
948
+ def wrap(*args, **kwargs):
949
+ kwargs["need_weights"] = True
950
+ kwargs["average_attn_weights"] = False
951
+ return forward_orig(*args, **kwargs)
952
+
953
+ m.forward = wrap
954
+
955
+ class SaveOutput:
956
+ def __init__(self):
957
+ self.outputs = []
958
+
959
+ def __call__(self, module, module_in, module_out):
960
+ self.outputs.append(module_out[1])
961
+
962
+ def clear(self):
963
+ self.outputs = []
964
+
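Note: a usage sketch for patch_attention and SaveOutput above, assuming the model contains torch.nn.MultiheadAttention submodules (the attribute layout is hypothetical); after a forward pass, saver.outputs holds one attention-weight tensor per hooked layer:

    import torch

    def hook_attention_maps(model):
        saver = SaveOutput()
        for module in model.modules():
            if isinstance(module, torch.nn.MultiheadAttention):
                patch_attention(module)
                module.register_forward_hook(saver)
        return saver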
965
+ def test(opt, video_name=None):
966
+ model = MYNET(opt).cuda()
967
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
968
+ base_dict = checkpoint['state_dict']
969
+ model.load_state_dict(base_dict)
970
+ model.eval()
971
+
972
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
973
+
974
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
975
+
976
+ if opt["pptype"] == "nms":
977
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
978
+ if opt["pptype"] == "net":
979
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
980
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
981
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
982
+ json.dump(output_dict, outfile, indent=2)
983
+ outfile.close()
984
+
985
+ mAP = evaluation_detection(opt)
986
+
987
+ # Compare predicted and ground truth action lengths
988
+ if video_name:
989
+ print("\nComparing Predicted and Ground Truth Action Lengths for Video:", video_name)
990
+ with open(opt["video_anno"].format(opt["split"]), 'r') as f:
991
+ anno_data = json.load(f)
992
+ gt_annotations = anno_data['database'][video_name]['annotations']
993
+ duration = anno_data['database'][video_name]['duration']
994
+
995
+ gt_segments = []
996
+ for anno in gt_annotations:
997
+ start, end = anno['segment']
998
+ label = anno['label']
999
+ duration_seg = end - start
1000
+ gt_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg})
1001
+
1002
+ pred_segments = []
1003
+ for pred in result_dict[video_name]:
1004
+ start, end = pred['segment']
1005
+ label = pred['label']
1006
+ score = pred['score']
1007
+ duration_seg = end - start
1008
+ pred_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg, 'score': score})
1009
+
1010
+ # Print comparison table
1011
+ matches = []
1012
+ iou_threshold = VIS_CONFIG['iou_threshold']
1013
+ used_gt_indices = set()
1014
+ for pred in pred_segments:
1015
+ best_iou = 0
1016
+ best_gt_idx = None
1017
+ for gt_idx, gt in enumerate(gt_segments):
1018
+ if gt_idx in used_gt_indices:
1019
+ continue
1020
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
1021
+ if iou > best_iou and iou >= iou_threshold:
1022
+ best_iou = iou
1023
+ best_gt_idx = gt_idx
1024
+ if best_gt_idx is not None:
1025
+ matches.append({
1026
+ 'pred': pred,
1027
+ 'gt': gt_segments[best_gt_idx],
1028
+ 'iou': best_iou
1029
+ })
1030
+ used_gt_indices.add(best_gt_idx)
1031
+ else:
1032
+ matches.append({'pred': pred, 'gt': None, 'iou': 0})
1033
+
1034
+ for gt_idx, gt in enumerate(gt_segments):
1035
+ if gt_idx not in used_gt_indices:
1036
+ matches.append({'pred': None, 'gt': gt, 'iou': 0})
1037
+
1038
+ print("\n{:<20} {:<30} {:<30} {:<15} {:<10}".format(
1039
+ "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"))
1040
+ print("-" * 105)
1041
+ for match in matches:
1042
+ pred = match['pred']
1043
+ gt = match['gt']
1044
+ iou = match['iou']
1045
+ if pred and gt:
1046
+ label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
1047
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
1048
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
1049
+ duration_diff = pred['duration'] - gt['duration']
1050
+ print("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}".format(
1051
+ label, pred_str, gt_str, duration_diff, iou))
1052
+ elif pred:
1053
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
1054
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
1055
+ pred['label'], pred_str, "None", "N/A", iou))
1056
+ elif gt:
1057
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
1058
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
1059
+ gt['label'], "None", gt_str, "N/A", iou))
1060
+
1061
+ # Summarize
1062
+ matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
1063
+ avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
1064
+ avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
1065
+ print(f"\nSummary:")
1066
+ print(f"- Total Predictions: {len(pred_segments)}")
1067
+ print(f"- Total Ground Truth: {len(gt_segments)}")
1068
+ print(f"- Matched Segments: {matched_count}")
1069
+ print(f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s")
1070
+ print(f"- Average IoU (Matched): {avg_iou:.2f}")
1071
+
1072
+ # Generate static visualization
1073
+ video_path = opt.get('video_path', '')
1074
+ if os.path.exists(video_path):
1075
+ visualize_action_lengths(
1076
+ video_id=video_name,
1077
+ pred_segments=pred_segments,
1078
+ gt_segments=gt_segments,
1079
+ video_path=video_path,
1080
+ duration=duration
1081
+ )
1082
+ # Generate annotated video
1083
+ annotate_video_with_actions(
1084
+ video_id=video_name,
1085
+ pred_segments=pred_segments,
1086
+ gt_segments=gt_segments,
1087
+ video_path=video_path
1088
+ )
1089
+ else:
1090
+ print(f"Warning: Video path {video_path} not found. Skipping visualization and video annotation.")
1091
+
1092
+ return mAP
1093
+
1094
+ def test_online(opt, video_name=None):
1095
+ model = MYNET(opt).cuda()
1096
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
1097
+ base_dict = checkpoint['state_dict']
1098
+ model.load_state_dict(base_dict)
1099
+ model.eval()
1100
+
1101
+ sup_model = SuppressNet(opt).cuda()
1102
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
1103
+ base_dict = checkpoint['state_dict']
1104
+ sup_model.load_state_dict(base_dict)
1105
+ sup_model.eval()
1106
+
1107
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
1108
+ test_loader = torch.utils.data.DataLoader(dataset,
1109
+ batch_size=1, shuffle=False,
1110
+ num_workers=0, pin_memory=True, drop_last=False)
1111
+
1112
+ result_dict = {}
1113
+ proposal_dict = []
1114
+
1115
+ num_class = opt["num_of_class"]
1116
+ unit_size = opt['segment_size']
1117
+ threshold = opt['threshold']
1118
+ anchors = opt['anchors']
1119
+
1120
+ start_time = time.time()
1121
+ total_frames = 0
1122
+
1123
+ for video_name in dataset.video_list:
1124
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
1125
+ sup_queue = torch.zeros((unit_size, num_class - 1))
1126
+
1127
+ duration = dataset.video_len[video_name]
1128
+ video_time = float(dataset.video_dict[video_name]["duration"])
1129
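+ # `duration` is the number of feature frames; frame_to_time / 100 converts a frame index
+ # to seconds, i.e. t = idx * video_time / duration (the factor of 100 cancels below).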
+ frame_to_time = 100.0 * video_time / duration
1130
+
1131
+ for idx in range(0, duration):
1132
+ total_frames += 1
1133
+ input_queue[:-1, :] = input_queue[1:, :].clone()
1134
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
1135
+
1136
+ minput = input_queue.unsqueeze(0)
1137
+ act_cls, act_reg, _ = model(minput.cuda())
1138
+ act_cls = torch.softmax(act_cls, dim=-1)
1139
+
1140
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
1141
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
1142
+
1143
+ proposal_anc_dict = []
1144
+ for anc_idx in range(0, len(anchors)):
1145
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
1146
+
1147
+ if len(cls) == 0:
1148
+ continue
1149
+
1150
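+ # Decode the anchor regression output: reg[0] is the end offset in anchor units and
+ # reg[1] is a log-scale length, so end = idx + anchor * reg[0],
+ # length = anchor * exp(reg[1]), and the start follows as end - length.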
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
1151
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
1152
+ st = ed - length
1153
+
1154
+ for cidx in range(0, len(cls)):
1155
+ label = cls[cidx]
1156
+ tmp_dict = {}
1157
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
1158
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
1159
+ tmp_dict["label"] = dataset.label_name[label]
1160
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
1161
+ proposal_anc_dict.append(tmp_dict)
1162
+
1163
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
1164
+
1165
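+ # Sliding per-class confidence queue over the last `segment_size` frames: shift one step,
+ # write the newest proposal scores into the last row, and let SuppressNet decide which
+ # classes should emit a final proposal at this frame.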
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
1166
+ sup_queue[-1, :] = 0
1167
+ for proposal in proposal_anc_dict:
1168
+ cls_idx = dataset.label_name.index(proposal['label'])
1169
+ sup_queue[-1, cls_idx] = proposal["score"]
1170
+
1171
+ minput = sup_queue.unsqueeze(0)
1172
+ suppress_conf = sup_model(minput.cuda())
1173
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
1174
+
1175
+ for cls in range(0, num_class - 1):
1176
+ if suppress_conf[cls] > opt['sup_threshold']:
1177
+ for proposal in proposal_anc_dict:
1178
+ if proposal['label'] == dataset.label_name[cls]:
1179
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
1180
+ proposal_dict.append(proposal)
1181
+
1182
+ result_dict[video_name] = proposal_dict
1183
+ proposal_dict = []
1184
+
1185
+ end_time = time.time()
1186
+ working_time = end_time - start_time
1187
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
1188
+
1189
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
1190
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
1191
+ json.dump(output_dict, outfile, indent=2)
1192
+ outfile.close()
1193
+
1194
+ mAP = evaluation_detection(opt)
1195
+ return mAP
1196
+
1197
+ def main(opt, video_name=None):
1198
+ max_perf = 0
1199
+ if not video_name and 'video_name' in opt:
1200
+ video_name = opt['video_name']
1201
+
1202
+ if opt['mode'] == 'train':
1203
+ max_perf = train(opt)
1204
+ if opt['mode'] == 'test':
1205
+ max_perf = test(opt, video_name=video_name)
1206
+ if opt['mode'] == 'test_frame':
1207
+ max_perf = test_frame(opt, video_name=video_name)
1208
+ if opt['mode'] == 'test_online':
1209
+ max_perf = test_online(opt, video_name=video_name)
1210
+ if opt['mode'] == 'eval':
1211
+ max_perf = evaluation_detection(opt)
1212
+
1213
+ return max_perf
1214
+
1215
+ if __name__ == '__main__':
1216
+ opt = opts.parse_opt()
1217
+ opt = vars(opt)
1218
+ if not os.path.exists(opt["checkpoint_path"]):
1219
+ os.makedirs(opt["checkpoint_path"])
1220
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
1221
+ json.dump(opt, opt_file)
1222
+ opt_file.close()
1223
+
1224
+ if opt['seed'] >= 0:
1225
+ seed = opt['seed']
1226
+ torch.manual_seed(seed)
1227
+ np.random.seed(seed)
1228
+
1229
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
1230
+
1231
+ video_name = opt.get('video_name', None)
1232
+ main(opt, video_name=video_name)
1233
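+ # If --wterm is set, busy-wait after finishing; presumably this keeps the terminal/window
+ # open until the process is killed manually.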
+ while opt['wterm']:
1234
+ pass
iou_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+
3
+ def non_max_suppression(proposals, overlapThresh=0.3):
4
+ # if there are no intervals, return an empty list
5
+ if len(proposals) == 0:
6
+ return []
7
+
8
+ # initialize the list of picked indexes
9
+ pick = []
10
+
11
+ sorted_proposal = sorted(proposals, key=lambda proposal: proposal['score'], reverse=True)
12
+ idx = 0
13
+ total_proposal = len(sorted_proposal)
14
+ while idx < total_proposal:
15
+ proposal = sorted_proposal[idx]
16
+ st = proposal['segment'][0]
17
+ ed = proposal['segment'][1]
18
+ label = proposal['label']
19
+
20
+ delete_item = []
21
+ for j in range(idx+1, total_proposal):
22
+ target_proposal = sorted_proposal[j]
23
+ target_st = target_proposal['segment'][0]
24
+ target_ed = target_proposal['segment'][1]
25
+ target_label = target_proposal['label']
26
+
27
+ if(label == target_label):
28
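+ # [sst, led] is the union span (min start, max end) and [lst, sed] the intersection span
+ # (max start, min end); temporal IoU = intersection / union, with the union length
+ # clamped to at least 1 to avoid division by zero.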
+ sst = np.minimum(st, target_st)
29
+ led = np.maximum(ed, target_ed)
30
+ lst = np.maximum(st, target_st)
31
+ sed = np.minimum(ed, target_ed)
32
+
33
+ iou = (sed-lst) / max(led-sst,1)
34
+ if(iou > overlapThresh):
35
+ delete_item.append(target_proposal)
36
+
37
+ for item in delete_item:
38
+ sorted_proposal.remove(item)
39
+ total_proposal = len(sorted_proposal)
40
+ idx += 1
41
+
42
+ return sorted_proposal
43
+
44
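+ # Example of the expected proposal format (a minimal sketch; the label and scores are made up):
+ #   proposals = [
+ #       {'segment': [1.0, 3.5], 'score': 0.9, 'label': 'open_fridge'},
+ #       {'segment': [1.2, 3.8], 'score': 0.4, 'label': 'open_fridge'},
+ #   ]
+ #   kept = non_max_suppression(proposals, overlapThresh=0.3)
+ #   # only the higher-scoring of the two overlapping same-label proposals survives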
+
45
+ def check_overlap_proposal(proposal_list, new_proposal, overlapThresh=0.3):
46
+ for proposal in proposal_list:
47
+ st = proposal['segment'][0]
48
+ ed = proposal['segment'][1]
49
+ label = proposal['label']
50
+
51
+ new_st = new_proposal['segment'][0]
52
+ new_ed = new_proposal['segment'][1]
53
+ new_label = new_proposal['label']
54
+
55
+ if(label == new_label):
56
+ sst = np.minimum(st, new_st)
57
+ led = np.maximum(ed, new_ed)
58
+ lst = np.maximum(st, new_st)
59
+ sed = np.minimum(ed, new_ed)
60
+
61
+ iou = (sed-lst) / max(led-sst,1)
62
+ if(iou > overlapThresh):
63
+ return proposal
64
+
65
+ return None
loss_func.py ADDED
@@ -0,0 +1,374 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+ from functools import partial
7
+
8
+ class MultiCrossEntropyLoss(nn.Module):
9
+ def __init__(self, focal=False, weight=None, reduce=True):
10
+ super(MultiCrossEntropyLoss, self).__init__()
11
+ self.num_classes = 23
12
+ self.focal = focal
13
+ self.weight= weight
14
+ self.reduce = reduce
15
+ self.gamma_ = torch.zeros(self.num_classes).cuda() + 0.025
16
+ self.gamma_f = 0.05
17
+
18
+ self.register_buffer('pos_grad', torch.zeros(self.num_classes-1).cuda())
19
+ self.register_buffer('neg_grad', torch.zeros(self.num_classes-1).cuda())
20
+ self.register_buffer('pos_neg', torch.ones(self.num_classes-1).cuda())
21
+
22
+ def forward(self, input, target):
23
+ target_sum = torch.sum(target, dim=1)
24
+ target_div = torch.where(target_sum != 0, target_sum, torch.ones_like(target_sum)).unsqueeze(1)
25
+ target = target/target_div
26
+ logsoftmax = nn.LogSoftmax(dim=1).to(input.device)
27
+ gamma = self.gamma_.clone()
28
+ gamma[:-1] = gamma[:-1] + self.gamma_f * (1 - self.pos_neg)
29
+
30
+ if not self.focal:
31
+ if self.weight is None:
32
+ output = torch.sum(-target * logsoftmax(input), 1)
33
+ else:
34
+ output = torch.sum(-target * logsoftmax(input) /self.weight, 1)
35
+ else:
36
+ softmax = nn.Softmax(dim=1).to(input.device)
37
+ p = softmax(input)
38
+
39
+ output = torch.sum(-target * (1 - p)**gamma * logsoftmax(input), 1)
40
+
41
+
42
+ if self.reduce:
43
+ return torch.mean(output)
44
+ else:
45
+ return output
46
+
47
+
48
+ def map_func(self, x, s):
49
+ min_val = torch.min(x)
50
+ max_val = torch.max(x)
51
+ mu = torch.mean(x)
52
+ x = (x - min_val) / (max_val - min_val)
53
+ return 1 / (1 + torch.exp(-s * (x - mu)))
54
+
55
+ def collect_grad(self, target, grad):
56
+ grad = torch.abs(grad.reshape(-1, grad.shape[-1])).cuda()
57
+ target = target.reshape(-1, target.shape[-1]).cuda()
58
+ pos_grad = torch.sum(grad * target, dim=0)[:-1]
59
+ neg_grad = torch.sum(grad * (1 - target), dim=0)[:-1]
60
+ self.pos_grad += pos_grad
61
+ self.neg_grad += neg_grad
62
+ self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)
63
+ self.pos_neg = self.map_func(self.pos_neg, 1)
64
+
65
+
66
+ def cls_loss_func(y,output, use_focal=False, weight=None, reduce=True):
67
+ input_size=y.size()
68
+ y = y.float().cuda()
69
+ if weight is not None:
70
+ weight = weight.cuda()
71
+ loss_func = MultiCrossEntropyLoss(focal=True, weight=weight, reduce=reduce)
72
+
73
+ y=y.reshape(-1,y.size(-1))
74
+ output=output.reshape(-1,output.size(-1))
75
+ loss = loss_func(output,y)
76
+
77
+ if not reduce:
78
+ loss = loss.reshape(input_size[:-1])
79
+
80
+ return loss
81
+
82
+
83
+ def cls_loss_func_(loss_func, y,output, use_focal=False, weight=None, reduce=True):
84
+ input_size=y.size()
85
+ y = y.float().cuda()
86
+ if weight is not None:
87
+ weight = weight.cuda()
88
+
89
+ y=y.reshape(-1,y.size(-1))
90
+ output=output.reshape(-1,output.size(-1))
91
+ loss = loss_func(output,y)
92
+
93
+ if not reduce:
94
+ loss = loss.reshape(input_size[:-1])
95
+
96
+ return loss
97
+
98
+ def regress_loss_func(y,output):
99
+ y = y.float().cuda()
100
+ y=y.reshape(-1,y.size(-1))
101
+ output=output.reshape(-1,output.size(-1))
102
+
103
+ bgmask= y[:,1] < -1e2
104
+
105
+ fg_logits = output[~bgmask]
106
+ bg_logits = output[bgmask]
107
+
108
+ fg_target = y[~bgmask]
109
+ bg_target = y[bgmask]
110
+
111
+ loss = nn.functional.l1_loss(fg_logits,fg_target)
112
+
113
+ if(loss.isnan()):
114
+ return torch.tensor([0.0], requires_grad=True).cuda()
115
+ return loss
116
+
117
+
118
+ def suppress_loss_func(y,output):
119
+ y = y.float().cuda()
120
+ y=y.reshape(-1,y.size(-1))
121
+ output=output.reshape(-1,output.size(-1))
122
+
123
+ loss = nn.functional.binary_cross_entropy(output,y)
124
+
125
+ return loss
126
+
127
+
128
+ # import torch
129
+ # import numpy as np
130
+ # import torch.nn as nn
131
+ # import torch.nn.functional as F
132
+ # import torch.distributed as dist
133
+ # from functools import partial
134
+
135
+ # class MultiCrossEntropyLoss(nn.Module):
136
+ # def __init__(self, focal=False, weight=None, reduce=True):
137
+ # super(MultiCrossEntropyLoss, self).__init__()
138
+ # self.num_classes = 23
139
+ # self.focal = focal
140
+ # self.weight= weight
141
+ # self.reduce = reduce
142
+ # self.gamma_ = torch.zeros(self.num_classes).cuda() + 0.025
143
+ # self.gamma_f = 0.05
144
+
145
+ # self.register_buffer('pos_grad', torch.zeros(self.num_classes-1).cuda())
146
+ # self.register_buffer('neg_grad', torch.zeros(self.num_classes-1).cuda())
147
+ # self.register_buffer('pos_neg', torch.ones(self.num_classes-1).cuda())
148
+
149
+ # def forward(self, input, target):
150
+ # target_sum = torch.sum(target, dim=1)
151
+ # target_div = torch.where(target_sum != 0, target_sum, torch.ones_like(target_sum)).unsqueeze(1)
152
+ # target = target/target_div
153
+ # logsoftmax = nn.LogSoftmax(dim=1).to(input.device)
154
+ # gamma = self.gamma_.clone()
155
+ # gamma[:-1] = gamma[:-1] + self.gamma_f * (1 - self.pos_neg)
156
+
157
+ # if not self.focal:
158
+ # if self.weight is None:
159
+ # output = torch.sum(-target * logsoftmax(input), 1)
160
+ # else:
161
+ # output = torch.sum(-target * logsoftmax(input) /self.weight, 1)
162
+ # else:
163
+ # softmax = nn.Softmax(dim=1).to(input.device)
164
+ # p = softmax(input)
165
+
166
+ # output = torch.sum(-target * (1 - p)**gamma * logsoftmax(input), 1)
167
+
168
+
169
+ # if self.reduce:
170
+ # return torch.mean(output)
171
+ # else:
172
+ # return output
173
+
174
+
175
+ # def map_func(self, x, s):
176
+ # min_val = torch.min(x)
177
+ # max_val = torch.max(x)
178
+ # mu = torch.mean(x)
179
+ # x = (x - min_val) / (max_val - min_val)
180
+ # return 1 / (1 + torch.exp(-s * (x - mu)))
181
+
182
+ # def collect_grad(self, target, grad):
183
+ # grad = torch.abs(grad.reshape(-1, grad.shape[-1])).cuda()
184
+ # target = target.reshape(-1, target.shape[-1]).cuda()
185
+ # pos_grad = torch.sum(grad * target, dim=0)[:-1]
186
+ # neg_grad = torch.sum(grad * (1 - target), dim=0)[:-1]
187
+ # self.pos_grad += pos_grad
188
+ # self.neg_grad += neg_grad
189
+ # self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)
190
+ # self.pos_neg = self.map_func(self.pos_neg, 1)
191
+
192
+
193
+ # def cls_loss_func(y,output, use_focal=False, weight=None, reduce=True):
194
+ # input_size=y.size()
195
+ # y = y.float().cuda()
196
+ # if weight is not None:
197
+ # weight = weight.cuda()
198
+ # loss_func = MultiCrossEntropyLoss(focal=True, weight=weight, reduce=reduce)
199
+
200
+ # y=y.reshape(-1,y.size(-1))
201
+ # output=output.reshape(-1,output.size(-1))
202
+ # loss = loss_func(output,y)
203
+
204
+ # if not reduce:
205
+ # loss = loss.reshape(input_size[:-1])
206
+
207
+ # return loss
208
+
209
+
210
+ # def cls_loss_func_(loss_func, y,output, use_focal=False, weight=None, reduce=True):
211
+ # input_size=y.size()
212
+ # y = y.float().cuda()
213
+ # if weight is not None:
214
+ # weight = weight.cuda()
215
+
216
+ # y=y.reshape(-1,y.size(-1))
217
+ # output=output.reshape(-1,output.size(-1))
218
+ # loss = loss_func(output,y)
219
+
220
+ # if not reduce:
221
+ # loss = loss.reshape(input_size[:-1])
222
+
223
+ # return loss
224
+
225
+ # def regress_loss_func(y,output):
226
+ # y = y.float().cuda()
227
+ # y=y.reshape(-1,y.size(-1))
228
+ # output=output.reshape(-1,output.size(-1))
229
+
230
+ # bgmask= y[:,1] < -1e2
231
+
232
+ # fg_logits = output[~bgmask]
233
+ # bg_logits = output[bgmask]
234
+
235
+ # fg_target = y[~bgmask]
236
+ # bg_target = y[bgmask]
237
+
238
+ # loss = nn.functional.l1_loss(fg_logits,fg_target)
239
+
240
+ # if(loss.isnan()):
241
+ # return torch.tensor([0.0], requires_grad=True).cuda()
242
+ # return loss
243
+
244
+
245
+ # def suppress_loss_func(y,output):
246
+ # y = y.float().cuda()
247
+ # y=y.reshape(-1,y.size(-1))
248
+ # output=output.reshape(-1,output.size(-1))
249
+
250
+ # loss = nn.functional.binary_cross_entropy(output,y)
251
+
252
+ # return loss
253
+
254
+
255
+
256
+ # import torch
257
+ # import numpy as np
258
+ # import torch.nn as nn
259
+ # import torch.nn.functional as F
260
+ # import torch.distributed as dist
261
+ # from functools import partial
262
+
263
+ # class MultiCrossEntropyLoss(nn.Module):
264
+ # def __init__(self, num_classes, focal=False, weight=None, reduce=True):
265
+ # super(MultiCrossEntropyLoss, self).__init__()
266
+ # self.num_classes = num_classes # Use the provided num_classes
267
+ # self.focal = focal
268
+ # self.weight = weight
269
+ # self.reduce = reduce
270
+ # self.gamma_ = torch.zeros(self.num_classes).cuda() + 0.025
271
+ # self.gamma_f = 0.05
272
+
273
+ # self.register_buffer('pos_grad', torch.zeros(self.num_classes-1).cuda())
274
+ # self.register_buffer('neg_grad', torch.zeros(self.num_classes-1).cuda())
275
+ # self.register_buffer('pos_neg', torch.ones(self.num_classes-1).cuda())
276
+
277
+ # def forward(self, input, target):
278
+ # target_sum = torch.sum(target, dim=1)
279
+ # target_div = torch.where(target_sum != 0, target_sum, torch.ones_like(target_sum)).unsqueeze(1)
280
+ # target = target / target_div
281
+ # logsoftmax = nn.LogSoftmax(dim=1).to(input.device)
282
+ # gamma = self.gamma_.clone()
283
+ # gamma[:-1] = gamma[:-1] + self.gamma_f * (1 - self.pos_neg)
284
+
285
+ # if not self.focal:
286
+ # if self.weight is None:
287
+ # output = torch.sum(-target * logsoftmax(input), 1)
288
+ # else:
289
+ # output = torch.sum(-target * logsoftmax(input) / self.weight, 1)
290
+ # else:
291
+ # softmax = nn.Softmax(dim=1).to(input.device)
292
+ # p = softmax(input)
293
+ # output = torch.sum(-target * (1 - p)**gamma * logsoftmax(input), 1)
294
+
295
+ # if self.reduce:
296
+ # return torch.mean(output)
297
+ # else:
298
+ # return output
299
+
300
+ # def map_func(self, x, s):
301
+ # min_val = torch.min(x)
302
+ # max_val = torch.max(x)
303
+ # mu = torch.mean(x)
304
+ # x = (x - min_val) / (max_val - min_val)
305
+ # return 1 / (1 + torch.exp(-s * (x - mu)))
306
+
307
+ # def collect_grad(self, target, grad):
308
+ # grad = torch.abs(grad.reshape(-1, grad.shape[-1])).cuda()
309
+ # target = target.reshape(-1, target.shape[-1]).cuda()
310
+ # pos_grad = torch.sum(grad * target, dim=0)[:-1]
311
+ # neg_grad = torch.sum(grad * (1 - target), dim=0)[:-1]
312
+ # self.pos_grad += pos_grad
313
+ # self.neg_grad += neg_grad
314
+ # self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)
315
+ # self.pos_neg = self.map_func(self.pos_neg, 1)
316
+
317
+ # def cls_loss_func(y, output, use_focal=False, weight=None, reduce=True):
318
+ # input_size = y.size()
319
+ # y = y.float().cuda()
320
+ # if weight is not None:
321
+ # weight = weight.cuda()
322
+ # loss_func = MultiCrossEntropyLoss(num_classes=y.size(-1), focal=use_focal, weight=weight, reduce=reduce)
323
+
324
+ # y = y.reshape(-1, y.size(-1))
325
+ # output = output.reshape(-1, output.size(-1))
326
+ # loss = loss_func(output, y)
327
+
328
+ # if not reduce:
329
+ # loss = loss.reshape(input_size[:-1])
330
+
331
+ # return loss
332
+
333
+ # def cls_loss_func_(loss_func, y, output, use_focal=False, weight=None, reduce=True):
334
+ # input_size = y.size()
335
+ # y = y.float().cuda()
336
+ # if weight is not None:
337
+ # weight = weight.cuda()
338
+
339
+ # y = y.reshape(-1, y.size(-1))
340
+ # output = output.reshape(-1, output.size(-1))
341
+ # loss = loss_func(output, y)
342
+
343
+ # if not reduce:
344
+ # loss = loss.reshape(input_size[:-1])
345
+
346
+ # return loss
347
+
348
+ # def regress_loss_func(y, output):
349
+ # y = y.float().cuda()
350
+ # y = y.reshape(-1, y.size(-1))
351
+ # output = output.reshape(-1, output.size(-1))
352
+
353
+ # bgmask = y[:, 1] < -1e2
354
+
355
+ # fg_logits = output[~bgmask]
356
+ # bg_logits = output[bgmask]
357
+
358
+ # fg_target = y[~bgmask]
359
+ # bg_target = y[bgmask]
360
+
361
+ # loss = nn.functional.l1_loss(fg_logits, fg_target)
362
+
363
+ # if loss.isnan():
364
+ # return torch.tensor([0.0], requires_grad=True).cuda()
365
+ # return loss
366
+
367
+ # def suppress_loss_func(y, output):
368
+ # y = y.float().cuda()
369
+ # y = y.reshape(-1, y.size(-1))
370
+ # output = output.reshape(-1, output.size(-1))
371
+
372
+ # loss = nn.functional.binary_cross_entropy(output, y)
373
+
374
+ # return loss
main.py ADDED
@@ -0,0 +1,1144 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet, calc_iou
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.patches as patches
25
+ import cv2
26
+ from typing import List, Dict, Optional
27
+
28
+ from PIL import Image, ImageDraw, ImageFont
29
+ import warnings
30
+
31
+ # Visualization Configuration (Updated)
32
+ VIS_CONFIG = {
33
+ 'frame_interval': 1.0,
34
+ 'max_frames': 20,
35
+ 'save_dir': './output/visualizations',
36
+ 'video_save_dir': './output/videos',
37
+ 'gt_color': '#1f77b4', # Blue for ground truth (RGB: 31, 119, 180)
38
+ 'pred_color': '#ff7f0e', # Orange for predictions (RGB: 255, 127, 14)
39
+ 'fontsize_label': 10,
40
+ 'fontsize_title': 14,
41
+ 'frame_highlight_both': 'green',
42
+ 'frame_highlight_gt': 'red',
43
+ 'frame_highlight_pred': 'black',
44
+ 'iou_threshold': 0.3,
45
+ 'frame_scale_factor': 0.8,
46
+ 'video_text_scale': 0.5,
47
+ 'video_gt_text_color': (180, 119, 31), # BGR for OpenCV
48
+ 'video_pred_text_color': (14, 127, 255), # BGR for OpenCV
49
+ 'video_text_thickness': 1,
50
+ 'video_font_path': "./data/Poppins ExtraBold Italic 800.ttf",
51
+ 'video_font_fallback': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
52
+ 'video_pred_text_y': 0.45,
53
+ 'video_gt_text_y': 0.55,
54
+ 'video_footer_height': 150, # Increased to accommodate labels
55
+ 'video_gt_bar_y': 0.5,
56
+ 'video_pred_bar_y': 0.8,
57
+ 'video_bar_height': 0.15,
58
+ 'video_bar_text_scale': 0.7,
59
+ 'min_segment_duration': 1.0,
60
+ 'video_frame_text_y': 0.05, # Position for frame number and FPS
61
+ 'video_bar_label_x': 10, # X-position for GT/Pred labels
62
+ 'video_bar_label_scale': 0.5,
63
+ 'scroll_window_duration': 30.0, # Duration of the visible time window (seconds)
64
+ 'scroll_speed': 0.5, # Seconds to advance the window per second of video
65
+ }
66
+
67
+
68
+ def annotate_video_with_actions(
69
+ video_id: str,
70
+ pred_segments: List[Dict],
71
+ gt_segments: List[Dict],
72
+ video_path: str,
73
+ save_dir: str = VIS_CONFIG['video_save_dir'],
74
+ text_scale: float = VIS_CONFIG['video_text_scale'] * 1.5, # Increased text size by 50%
75
+ gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
76
+ pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
77
+ text_thickness: int = VIS_CONFIG['video_text_thickness']
78
+ ) -> None:
79
+ """
80
+ Annotate a video with predicted and ground truth action labels, cumulative bars, frame number, and FPS.
81
+ Timeline bars are drawn over fixed 20-second windows and grow with the playhead, resetting at each window boundary.
82
+ Each action class is drawn in its own color; the bars themselves carry no labels or timestamps.
83
+ The "GT" and "Pred" labels sit on the left of the footer, and the bars start roughly 0.5 inches (48 pixels at 96 DPI) to their right.
84
+
85
+ Args:
86
+ video_id: Video identifier (e.g., 'my_video').
87
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
88
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
89
+ video_path: Path to the input video file.
90
+ save_dir: Directory to save the annotated video.
91
+ text_scale: Scale factor for text size in video (increased).
92
+ gt_text_color: BGR color tuple for ground truth text.
93
+ pred_text_color: BGR color tuple for predicted text.
94
+ text_thickness: Thickness of text strokes.
95
+ """
96
+ os.makedirs(save_dir, exist_ok=True)
97
+
98
+ # Open input video
99
+ cap = cv2.VideoCapture(video_path)
100
+ if not cap.isOpened():
101
+ print(f"Error: Could not open video {video_path}. Skipping video annotation.")
102
+ return
103
+
104
+ # Get video properties
105
+ fps = cap.get(cv2.CAP_PROP_FPS)
106
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
107
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
108
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
109
+ duration = total_frames / fps
110
+ print(f"Input Video: FPS={fps:.2f}, Resolution={frame_width}x{frame_height}, Total Frames={total_frames}, Duration={duration:.2f}s")
111
+
112
+ # Define output video with extended height for footer
113
+ footer_height = VIS_CONFIG['video_footer_height']
114
+ output_height = frame_height + footer_height
115
+ output_path = os.path.join(save_dir, f"annotated_{video_id}_{opt['exp']}.avi")
116
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
117
+ out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
118
+
119
+ if not out.isOpened():
120
+ print(f"Error: Could not initialize video writer for {output_path}. Check codec availability.")
121
+ cap.release()
122
+ return
123
+
124
+ # Filter short segments
125
+ min_duration = VIS_CONFIG['min_segment_duration']
126
+ gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
127
+ pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]
128
+ print(f"Filtered Segments: GT={len(gt_segments)}, Pred={len(pred_segments)} (min_duration={min_duration}s)")
129
+
130
+ # Define color palette (BGR)
131
+ color_palette = [
132
+ (128, 0, 0), # Navy Blue
133
+ (60, 20, 220), # Crimson Red
134
+ (0, 128, 0), # Emerald Green
135
+ (128, 0, 128), # Royal Purple
136
+ (79, 69, 54), # Charcoal Gray
137
+ (128, 128, 0), # Teal
138
+ (0, 0, 128), # Maroon
139
+ (130, 0, 75), # Indigo
140
+ (34, 139, 34), # Forest Green
141
+ (0, 85, 204), # Burnt Orange
142
+ (149, 146, 209), # Dusty Rose
143
+ (235, 206, 135), # Sky Blue
144
+ (250, 230, 230), # Lavender
145
+ (191, 226, 159), # Seafoam Green
146
+ (185, 218, 255), # Peach
147
+ (255, 204, 204), # Periwinkle
148
+ (193, 182, 255), # Blush Pink
149
+ (201, 252, 189), # Mint Green
150
+ (144, 128, 112), # Slate Gray
151
+ (112, 25, 25), # Midnight Blue
152
+ (102, 51, 102), # Deep Plum
153
+ (0, 128, 128), # Olive Green
154
+ (171, 71, 0) # Cobalt Blue
155
+ ]
156
+
157
+ # Create color mapping for actions
158
+ action_labels = set(seg['label'] for seg in gt_segments).union(seg['label'] for seg in pred_segments)
159
+ action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
160
+ print(f"Action Color Mapping: {action_color_map}")
161
+
162
+ # Convert fallback colors to RGB for PIL
163
+ gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0]) # BGR to RGB
164
+ pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0]) # BGR to RGB
165
+
166
+ # Load font
167
+ font_path = VIS_CONFIG['video_font_path']
168
+ font_fallback = VIS_CONFIG['video_font_fallback']
169
+ font_size = int(20 * text_scale)
170
+ bar_font_size = int(20 * VIS_CONFIG['video_bar_text_scale'])
171
+ font = None
172
+ bar_font = None
173
+ if font_path:
174
+ try:
175
+ font = ImageFont.truetype(font_path, font_size)
176
+ bar_font = ImageFont.truetype(font_path, bar_font_size)
177
+ print(f"Using font: {font_path}")
178
+ except IOError:
179
+ print(f"Warning: Font {font_path} not found. Trying fallback font.")
180
+ if not font:
181
+ try:
182
+ font = ImageFont.truetype(font_fallback, font_size)
183
+ bar_font = ImageFont.truetype(font_fallback, bar_font_size)
184
+ print(f"Using fallback font: {font_fallback}")
185
+ except IOError:
186
+ print(f"Warning: Fallback font {font_fallback} not found. Using OpenCV default font.")
187
+ font = None
188
+ bar_font = None
189
+
190
+ # Fixed window configuration
191
+ window_size = 20.0 # 20-second windows
192
+ num_windows = int(np.ceil(duration / window_size))
193
+
194
+ # Define horizontal gap (0.5 inch = 48 pixels at 96 DPI)
195
+ text_bar_gap = 48 # Pixels
196
+ text_x = 10 # Fixed x-position for GT and Pred labels
197
+
198
+ frame_idx = 0
199
+ written_frames = 0
200
+ while cap.isOpened():
201
+ ret, frame = cap.read()
202
+ if not ret:
203
+ break
204
+
205
+ # Create extended frame with footer
206
+ extended_frame = np.zeros((output_height, frame_width, 3), dtype=np.uint8)
207
+ extended_frame[:frame_height, :, :] = frame
208
+ extended_frame[frame_height:, :, :] = 255 # White footer
209
+
210
+ # Calculate current timestamp
211
+ timestamp = frame_idx / fps
212
+
213
+ # Determine current window
214
+ window_idx = int(timestamp // window_size)
215
+ window_start = window_idx * window_size
216
+ window_end = min(window_start + window_size, duration)
217
+ window_duration = window_end - window_start
218
+ window_timestamp = timestamp - window_start # Relative timestamp within window
219
+
220
+ # Find active GT actions (for text overlay)
221
+ gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
222
+ gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else ""
223
+
224
+ # Find active predicted actions (for text overlay)
225
+ pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
226
+ pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else ""
227
+
228
+ # Draw GT and prediction bars in footer (within current window, using original animation)
229
+ footer_y = frame_height
230
+ gt_bar_y = footer_y + int(0.2 * footer_height) # GT bar position
231
+ pred_bar_y = footer_y + int(0.5 * footer_height) # Pred bar position
232
+ bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
233
+
234
+ # Calculate text width for GT and Pred labels to determine bar start
235
+ if font:
236
+ gt_text_bbox = bar_font.getbbox("GT")
237
+ pred_text_bbox = bar_font.getbbox("Pred")
238
+ gt_text_width = gt_text_bbox[2] - gt_text_bbox[0]
239
+ pred_text_width = pred_text_bbox[2] - pred_text_bbox[0]
240
+ else:
241
+ gt_text_size, _ = cv2.getTextSize("GT", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
242
+ pred_text_size, _ = cv2.getTextSize("Pred", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
243
+ gt_text_width = gt_text_size[0]
244
+ pred_text_width = pred_text_size[0]
245
+ max_text_width = max(gt_text_width, pred_text_width)
246
+ bar_start_x = text_x + max_text_width + text_bar_gap # Bars start after text + 0.5-inch gap
247
+ bar_width = frame_width - bar_start_x # Adjust bar width to fit remaining space
248
+
249
+ # Draw bars with action-specific colors
250
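+ # Segment times map to footer pixels as x = bar_start_x + ((t - window_start) / window_duration) * bar_width;
+ # the end time is clamped to the current playhead (window_start + window_timestamp) so each bar grows as the video plays.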
+ for seg in gt_segments:
251
+ if seg['start'] <= window_end and seg['end'] >= window_start:
252
+ start_t = max(seg['start'], window_start)
253
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
254
+ start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
255
+ end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
256
+ if end_x > start_x:
257
+ cv2.rectangle(
258
+ extended_frame,
259
+ (start_x, gt_bar_y),
260
+ (end_x, gt_bar_y + bar_height),
261
+ action_color_map[seg['label']], # Action-specific color
262
+ -1
263
+ )
264
+
265
+ for seg in pred_segments:
266
+ if seg['start'] <= window_end and seg['end'] >= window_start:
267
+ start_t = max(seg['start'], window_start)
268
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
269
+ start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
270
+ end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
271
+ if end_x > start_x:
272
+ cv2.rectangle(
273
+ extended_frame,
274
+ (start_x, pred_bar_y),
275
+ (end_x, pred_bar_y + bar_height),
276
+ action_color_map[seg['label']], # Action-specific color
277
+ -1
278
+ )
279
+
280
+ if font:
281
+ # Convert frame to PIL image
282
+ frame_rgb = cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB)
283
+ pil_image = Image.fromarray(frame_rgb)
284
+ draw = ImageDraw.Draw(pil_image)
285
+
286
+ # Draw frame number and FPS at top center
287
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
288
+ frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
289
+ frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
290
+ frame_text_x = (frame_width - frame_text_width) // 2
291
+ draw.text((frame_text_x, 10), frame_info, font=font, fill=(0, 0, 0))
292
+
293
+ # Draw window timestamp range at top of footer
294
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
295
+ window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
296
+ window_text_width = window_text_bbox[2] - window_text_bbox[0]
297
+ window_text_x = (frame_width - window_text_width) // 2
298
+ draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
299
+
300
+ # Draw GT text in video only if there are actions
301
+ if gt_text:
302
+ gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
303
+ draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
304
+
305
+ # Draw predicted text in video only if there are actions
306
+ if pred_text:
307
+ pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
308
+ draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
309
+
310
+ # Draw GT and Pred labels in footer
311
+ draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
312
+ draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
313
+
314
+ # Convert back to OpenCV frame
315
+ extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
316
+ else:
317
+ # Fallback to OpenCV font
318
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
319
+ text_size, _ = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_DUPLEX, text_scale, text_thickness)
320
+ frame_text_x = (frame_width - text_size[0]) // 2
321
+ cv2.putText(
322
+ extended_frame,
323
+ frame_info,
324
+ (frame_text_x, 30),
325
+ cv2.FONT_HERSHEY_DUPLEX,
326
+ text_scale,
327
+ (0, 0, 0),
328
+ text_thickness,
329
+ cv2.LINE_AA
330
+ )
331
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
332
+ window_text_size, _ = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
333
+ window_text_x = (frame_width - window_text_size[0]) // 2
334
+ cv2.putText(
335
+ extended_frame,
336
+ window_info,
337
+ (window_text_x, footer_y + 20),
338
+ cv2.FONT_HERSHEY_DUPLEX,
339
+ VIS_CONFIG['video_bar_text_scale'],
340
+ (0, 0, 0),
341
+ 1,
342
+ cv2.LINE_AA
343
+ )
344
+ if gt_text:
345
+ cv2.putText(
346
+ extended_frame,
347
+ gt_text,
348
+ (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
349
+ cv2.FONT_HERSHEY_DUPLEX,
350
+ text_scale,
351
+ gt_text_color,
352
+ text_thickness,
353
+ cv2.LINE_AA
354
+ )
355
+ if pred_text:
356
+ cv2.putText(
357
+ extended_frame,
358
+ pred_text,
359
+ (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
360
+ cv2.FONT_HERSHEY_DUPLEX,
361
+ text_scale,
362
+ pred_text_color,
363
+ text_thickness,
364
+ cv2.LINE_AA
365
+ )
366
+ cv2.putText(
367
+ extended_frame,
368
+ "GT",
369
+ (text_x, gt_bar_y + bar_height // 2 + 5),
370
+ cv2.FONT_HERSHEY_DUPLEX,
371
+ VIS_CONFIG['video_bar_text_scale'],
372
+ gt_text_color,
373
+ 1,
374
+ cv2.LINE_AA
375
+ )
376
+ cv2.putText(
377
+ extended_frame,
378
+ "Pred",
379
+ (text_x, pred_bar_y + bar_height // 2 + 5),
380
+ cv2.FONT_HERSHEY_DUPLEX,
381
+ VIS_CONFIG['video_bar_text_scale'],
382
+ pred_text_color,
383
+ 1,
384
+ cv2.LINE_AA
385
+ )
386
+
387
+ # Write frame to output video
388
+ out.write(extended_frame)
389
+ written_frames += 1
390
+ frame_idx += 1
391
+
392
+ # Release resources
393
+ cap.release()
394
+ out.release()
395
+ print(f"[✅ Saved Annotated Video]: {output_path}, Written Frames={written_frames}")
396
+ print("Note: If .avi is not playable, convert to .mp4 using FFmpeg:")
397
+ print(f"ffmpeg -i {output_path} -vcodec libx264 -acodec aac {output_path.replace('.avi', '.mp4')}")
398
+
399
+
400
+
401
+
402
+
403
+
404
+
405
+
406
+ def visualize_action_lengths(
407
+ video_id: str,
408
+ pred_segments: List[Dict],
409
+ gt_segments: List[Dict],
410
+ video_path: str,
411
+ duration: float,
412
+ save_dir: str = VIS_CONFIG['save_dir'],
413
+ frame_interval: float = VIS_CONFIG['frame_interval']
414
+ ) -> None:
415
+ """
416
+ Generate a visualization plot comparing ground truth and predicted action lengths with video frames.
417
+
418
+ Args:
419
+ video_id: Video identifier (e.g., 'my_video').
420
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
421
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
422
+ video_path: Path to the input video file.
423
+ duration: Total duration of the video in seconds.
424
+ save_dir: Directory to save the output image.
425
+ frame_interval: Time interval between sampled frames (seconds).
426
+ """
427
+ os.makedirs(save_dir, exist_ok=True)
428
+
429
+ # Calculate frame sampling times
430
+ num_frames = int(duration / frame_interval) + 1
431
+ if num_frames > VIS_CONFIG['max_frames']:
432
+ frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
433
+ num_frames = VIS_CONFIG['max_frames']
434
+ print(f"Warning: Video duration ({duration:.1f}s) requires {num_frames} frames. Adjusted frame_interval to {frame_interval:.2f}s.")
435
+
436
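+ # Sample frame timestamps uniformly over [0, duration); endpoint=False keeps the last sample just before the end.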
+ frame_times = np.linspace(0, duration, num_frames, endpoint=False)
437
+
438
+ # Load video frames
439
+ frames = []
440
+ cap = cv2.VideoCapture(video_path)
441
+ if not cap.isOpened():
442
+ print(f"Warning: Could not open video {video_path}. Using placeholder frames.")
443
+ frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
444
+ else:
445
+ for t in frame_times:
446
+ cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
447
+ ret, frame = cap.read()
448
+ if ret:
449
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
450
+ # Resize frame to reduce memory usage
451
+ frame = cv2.resize(frame, (int(frame.shape[1] * 0.5), int(frame.shape[0] * 0.5)))
452
+ frames.append(frame)
453
+ else:
454
+ frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
455
+ cap.release()
456
+
457
+ # Initialize figure
458
+ fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
459
+ gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
460
+
461
+ # Plot frames
462
+ for i, (t, frame) in enumerate(zip(frame_times, frames)):
463
+ ax = fig.add_subplot(gs[0, i])
464
+
465
+ # Check if frame falls within GT or predicted segments
466
+ gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
467
+ pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
468
+
469
+ # Set border color
470
+ border_color = None
471
+ if gt_hit and pred_hit:
472
+ border_color = VIS_CONFIG['frame_highlight_both']
473
+ elif gt_hit:
474
+ border_color = VIS_CONFIG['frame_highlight_gt']
475
+ elif pred_hit:
476
+ border_color = VIS_CONFIG['frame_highlight_pred']
477
+
478
+ ax.imshow(frame)
479
+ # axis('off') would also hide the spines, so remove only the ticks and keep the spines for the highlight border
+ ax.set_xticks([]); ax.set_yticks([])
480
+ if border_color:
481
+ for spine in ax.spines.values():
482
+ spine.set_edgecolor(border_color)
483
+ spine.set_linewidth(2)
+ else:
+ for spine in ax.spines.values():
+ spine.set_visible(False)
484
+
485
+ ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
486
+ color=border_color if border_color else 'black')
487
+
488
+ # Plot ground truth bar
489
+ ax_gt = fig.add_subplot(gs[1, :])
490
+ ax_gt.set_xlim(0, duration)
491
+ ax_gt.set_ylim(0, 1)
492
+ ax_gt.axis('off')
493
+ ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'],
494
+ va='center', ha='right', weight='bold')
495
+
496
+ for seg in gt_segments:
497
+ start, end = seg['start'], seg['end']
498
+ width = end - start
499
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
500
+ ax_gt.add_patch(patches.Rectangle(
501
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'],
502
+ edgecolor='black', alpha=0.8
503
+ ))
504
+ ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center',
505
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
506
+ ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
507
+ ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')
508
+
509
+ # Plot prediction bar
510
+ ax_pred = fig.add_subplot(gs[2, :])
511
+ ax_pred.set_xlim(0, duration)
512
+ ax_pred.set_ylim(0, 1)
513
+ ax_pred.axis('off')
514
+ ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'],
515
+ va='center', ha='right', weight='bold')
516
+
517
+ for seg in pred_segments:
518
+ start, end = seg['start'], seg['end']
519
+ width = end - start
520
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
521
+ ax_pred.add_patch(patches.Rectangle(
522
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'],
523
+ edgecolor='black', alpha=0.8
524
+ ))
525
+ ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center',
526
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
527
+ ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
528
+ ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')
529
+
530
+ # Save plot
531
+ jpg_path = os.path.join(save_dir, f"viz_{video_id}_{opt['exp']}.png") # Use PNG
532
+ plt.savefig(jpg_path, dpi=100, bbox_inches='tight') # Lower DPI
533
+ print(f"[✅ Saved Visualization]: {jpg_path}")
534
+ plt.close()
535
+
536
+
537
+
538
+ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
539
+ train_loader = torch.utils.data.DataLoader(train_dataset,
540
+ batch_size=opt['batch_size'], shuffle=True,
541
+ num_workers=0, pin_memory=True, drop_last=False)
542
+ epoch_cost = 0
543
+ epoch_cost_cls = 0
544
+ epoch_cost_reg = 0
545
+ epoch_cost_snip = 0
546
+
547
+ total_iter = len(train_dataset) // opt['batch_size']
548
+ cls_loss = MultiCrossEntropyLoss(focal=True)
549
+ snip_loss = MultiCrossEntropyLoss(focal=True)
550
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
551
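+ # When warm-up is enabled, the learning rate ramps linearly from 0 to opt['lr'] over the epoch's iterations.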
+ if warmup:
552
+ for g in optimizer.param_groups:
553
+ g['lr'] = n_iter * (opt['lr']) / total_iter
554
+
555
+ act_cls, act_reg, snip_cls = model(input_data.float().cuda())
556
+
557
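+ # Backward hooks route the classification/snippet gradients into collect_grad(), which updates
+ # the per-class positive/negative gradient ratio that modulates the focal gamma.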
+ act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
558
+ snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))
559
+
560
+ cost_reg = 0
561
+ cost_cls = 0
562
+
563
+ loss = cls_loss_func_(cls_loss, cls_label, act_cls)
564
+ cost_cls = loss
565
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
566
+
567
+ loss = regress_loss_func(reg_label, act_reg)
568
+ cost_reg = loss
569
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
570
+
571
+ loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
572
+ cost_snip = loss
573
+ epoch_cost_snip += cost_snip.detach().cpu().numpy()
574
+
575
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
576
+ epoch_cost += cost.detach().cpu().numpy()
577
+
578
+ optimizer.zero_grad()
579
+ cost.backward()
580
+ optimizer.step()
581
+
582
+ return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
583
+
584
+ def eval_one_epoch(opt, model, test_dataset):
585
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, test_dataset)
586
+
587
+ result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
588
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
589
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
590
+ json.dump(output_dict, outfile, indent=2)
591
+ outfile.close()
592
+
593
+ IoUmAP = evaluation_detection(opt, verbose=False)
594
+ IoUmAP_5 = sum(IoUmAP) / len(IoUmAP)  # mean mAP over all evaluated IoU thresholds
595
+
596
+ return cls_loss, reg_loss, tot_loss, IoUmAP_5
597
+
598
+ def train(opt):
599
+ writer = SummaryWriter()
600
+ model = MYNET(opt).cuda()
601
+
602
+ rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
603
+ optimizer = optim.Adam([{'params': model.history_unit.parameters(), 'lr': 1e-6}, {'params': rest_of_model_params}], lr=opt["lr"], weight_decay=opt["weight_decay"])
604
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["lr_step"])
605
+
606
+ train_dataset = VideoDataSet(opt, subset="train")
607
+ test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
608
+
609
+ warmup = False
610
+
611
+ for n_epoch in range(opt['epoch']):
612
+ if n_epoch >= 1:
613
+ warmup = False
614
+
615
+ n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
616
+
617
+ writer.add_scalars('data/cost', {'train': epoch_cost / (n_iter + 1)}, n_epoch)
618
+ print("training loss(epoch %d): %.03f, cls - %f, reg - %f, snip - %f, lr - %f" % (n_epoch,
619
+ epoch_cost / (n_iter + 1),
620
+ epoch_cost_cls / (n_iter + 1),
621
+ epoch_cost_reg / (n_iter + 1),
622
+ epoch_cost_snip / (n_iter + 1),
623
+ optimizer.param_groups[-1]["lr"]))
624
+
625
+ scheduler.step()
626
+ model.eval()
627
+
628
+ cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
629
+
630
+ writer.add_scalars('data/mAP', {'test': IoUmAP_5}, n_epoch)
631
+ print("testing loss(epoch %d): %.03f, cls - %f, reg - %f, mAP Avg - %f" % (n_epoch, tot_loss, cls_loss, reg_loss, IoUmAP_5))
632
+
633
+ state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
634
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_checkpoint_" + str(n_epoch + 1) + ".pth.tar")
635
+ if IoUmAP_5 > model.best_map:
636
+ model.best_map = IoUmAP_5
637
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_ckp_best.pth.tar")
638
+
639
+ model.train()
640
+
641
+ writer.close()
642
+ return model.best_map
643
+
644
+ def eval_frame(opt, model, dataset):
645
+ test_loader = torch.utils.data.DataLoader(dataset,
646
+ batch_size=opt['batch_size'], shuffle=False,
647
+ num_workers=0, pin_memory=True, drop_last=False)
648
+
649
+ labels_cls = {}
650
+ labels_reg = {}
651
+ output_cls = {}
652
+ output_reg = {}
653
+ for video_name in dataset.video_list:
654
+ labels_cls[video_name] = []
655
+ labels_reg[video_name] = []
656
+ output_cls[video_name] = []
657
+ output_reg[video_name] = []
658
+
659
+ start_time = time.time()
660
+ total_frames = 0
661
+ epoch_cost = 0
662
+ epoch_cost_cls = 0
663
+ epoch_cost_reg = 0
664
+
665
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
666
+ act_cls, act_reg, _ = model(input_data.float().cuda())
667
+ cost_reg = 0
668
+ cost_cls = 0
669
+
670
+ loss = cls_loss_func(cls_label, act_cls)
671
+ cost_cls = loss
672
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
673
+
674
+ loss = regress_loss_func(reg_label, act_reg)
675
+ cost_reg = loss
676
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
677
+
678
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
679
+ epoch_cost += cost.detach().cpu().numpy()
680
+
681
+ act_cls = torch.softmax(act_cls, dim=-1)
682
+
683
+ total_frames += input_data.size(0)
684
+
685
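+ # dataset.inputs[i] holds (video_name, start, end, index) for the i-th sample, so each batch
+ # element can be routed back to its video before the per-video arrays are stacked below.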
+ for b in range(0, input_data.size(0)):
686
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
687
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
688
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
689
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
690
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
691
+
692
+ end_time = time.time()
693
+ working_time = end_time - start_time
694
+
695
+ for video_name in dataset.video_list:
696
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
697
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
698
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
699
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
700
+
701
+ cls_loss = epoch_cost_cls / (n_iter + 1)  # average over all batches (n_iter is the last 0-based index)
702
+ reg_loss = epoch_cost_reg / (n_iter + 1)
703
+ tot_loss = epoch_cost / (n_iter + 1)
704
+
705
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
706
+
707
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
708
+ result_dict = {}
709
+ proposal_dict = []
710
+
711
+ num_class = opt["num_of_class"]
712
+ unit_size = opt['segment_size']
713
+ threshold = opt['threshold']
714
+ anchors = opt['anchors']
715
+
716
+ for video_name in dataset.video_list:
717
+ duration = dataset.video_len[video_name]
718
+ video_time = float(dataset.video_dict[video_name]["duration"])
719
+ frame_to_time = 100.0 * video_time / duration
720
+
721
+ for idx in range(0, duration):
722
+ cls_anc = output_cls[video_name][idx]
723
+ reg_anc = output_reg[video_name][idx]
724
+
725
+ proposal_anc_dict = []
726
+ for anc_idx in range(0, len(anchors)):
727
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
728
+
729
+ if len(cls) == 0:
730
+ continue
731
+
732
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
733
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
734
+ st = ed - length
735
+
736
+ for cidx in range(0, len(cls)):
737
+ label = cls[cidx]
738
+ tmp_dict = {}
739
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
740
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
741
+ tmp_dict["label"] = dataset.label_name[label]
742
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
743
+ proposal_anc_dict.append(tmp_dict)
744
+
745
+ proposal_dict += proposal_anc_dict
746
+
747
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
748
+ result_dict[video_name] = proposal_dict
749
+ proposal_dict = []
750
+
751
+ return result_dict
752
+
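# --- Illustrative sketch (not part of the uploaded files) --------------------
# How eval_map_nms above decodes a single anchor prediction into a segment: the
# regressor emits an end offset and a log-length scale, both relative to the
# anchor size. The numbers below are made up purely to show the arithmetic.
import numpy as np

anchor = 8                           # one entry of opt['anchors']
idx = 100                            # current frame index
reg = [0.25, 0.1]                    # reg_anc[anc_idx] = (end offset, log length scale)

ed = idx + anchor * reg[0]           # 100 + 8 * 0.25 = 102.0 (predicted end frame)
length = anchor * np.exp(reg[1])     # 8 * e^0.1 ~= 8.84 (predicted length in frames)
st = ed - length                     # ~= 93.16 (predicted start frame)

frame_to_time = 100.0 * 60.0 / 1200  # e.g. a 60 s video represented by 1200 feature frames
print(st * frame_to_time / 100.0, ed * frame_to_time / 100.0)   # segment in seconds
# ------------------------------------------------------------------------------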
753
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
754
+ model = SuppressNet(opt).cuda()
755
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
756
+ base_dict = checkpoint['state_dict']
757
+ model.load_state_dict(base_dict)
758
+ model.eval()
759
+
760
+ result_dict = {}
761
+ proposal_dict = []
762
+
763
+ num_class = opt["num_of_class"]
764
+ unit_size = opt['segment_size']
765
+ threshold = opt['threshold']
766
+ anchors = opt['anchors']
767
+
768
+ for video_name in dataset.video_list:
769
+ duration = dataset.video_len[video_name]
770
+ video_time = float(dataset.video_dict[video_name]["duration"])
771
+ frame_to_time = 100.0 * video_time / duration
772
+ conf_queue = torch.zeros((unit_size, num_class - 1))
773
+
774
+ for idx in range(0, duration):
775
+ cls_anc = output_cls[video_name][idx]
776
+ reg_anc = output_reg[video_name][idx]
777
+
778
+ proposal_anc_dict = []
779
+ for anc_idx in range(0, len(anchors)):
780
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
781
+
782
+ if len(cls) == 0:
783
+ continue
784
+
785
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
786
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
787
+ st = ed - length
788
+
789
+ for cidx in range(0, len(cls)):
790
+ label = cls[cidx]
791
+ tmp_dict = {}
792
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
793
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
794
+ tmp_dict["label"] = dataset.label_name[label]
795
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
796
+ proposal_anc_dict.append(tmp_dict)
797
+
798
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
799
+
800
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
801
+ conf_queue[-1, :] = 0
802
+ for proposal in proposal_anc_dict:
803
+ cls_idx = dataset.label_name.index(proposal['label'])
804
+ conf_queue[-1, cls_idx] = proposal["score"]
805
+
806
+ minput = conf_queue.unsqueeze(0)
807
+ suppress_conf = model(minput.cuda())
808
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
809
+
810
+ for cls in range(0, num_class - 1):
811
+ if suppress_conf[cls] > opt['sup_threshold']:
812
+ for proposal in proposal_anc_dict:
813
+ if proposal['label'] == dataset.label_name[cls]:
814
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
815
+ proposal_dict.append(proposal)
816
+
817
+ result_dict[video_name] = proposal_dict
818
+ proposal_dict = []
819
+
820
+ return result_dict
821
+
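# --- Illustrative sketch (not part of the uploaded files) --------------------
# eval_map_supnet above keeps a rolling window of per-class proposal scores of
# shape (segment_size, num_of_class - 1) and lets SuppressNet decide which
# classes are confident enough to emit. A minimal stand-in for that queue update:
import torch

unit_size, n_fg_classes = 64, 22     # segment_size and num_of_class - 1 from opts_egtea.py
conf_queue = torch.zeros((unit_size, n_fg_classes))

def push_scores(queue, frame_scores):
    # Shift the window one step and write the newest per-class scores.
    queue[:-1, :] = queue[1:, :].clone()
    queue[-1, :] = 0
    for cls_idx, score in frame_scores.items():      # e.g. {3: 0.7} from this frame's proposals
        queue[-1, cls_idx] = score
    return queue

conf_queue = push_scores(conf_queue, {3: 0.7, 11: 0.2})
print(conf_queue.shape)              # torch.Size([64, 22]); unsqueeze(0) before feeding SuppressNet
# ------------------------------------------------------------------------------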
822
+ def test_frame(opt, video_name=None):
823
+ model = MYNET(opt).cuda()
824
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
825
+ base_dict = checkpoint['state_dict']
826
+ model.load_state_dict(base_dict)
827
+ model.eval()
828
+
829
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
830
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
831
+
832
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
833
+
834
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
835
+
836
+ for video_name in dataset.video_list:
837
+ o_cls = output_cls[video_name]
838
+ o_reg = output_reg[video_name]
839
+ l_cls = labels_cls[video_name]
840
+ l_reg = labels_reg[video_name]
841
+
842
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
843
+ dset_predcls[:, :] = o_cls[:, :]
844
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
845
+ dset_predreg[:, :] = o_reg[:, :]
846
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
847
+ dset_labelcls[:, :] = l_cls[:, :]
848
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
849
+ dset_labelreg[:, :] = l_reg[:, :]
850
+ outfile.close()
851
+
852
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
853
+ return cls_loss, reg_loss, tot_loss
854
+
855
+ def patch_attention(m):
856
+ forward_orig = m.forward
857
+
858
+ def wrap(*args, **kwargs):
859
+ kwargs["need_weights"] = True
860
+ kwargs["average_attn_weights"] = False
861
+ return forward_orig(*args, **kwargs)
862
+
863
+ m.forward = wrap
864
+
865
+ class SaveOutput:
866
+ def __init__(self):
867
+ self.outputs = []
868
+
869
+ def __call__(self, module, module_in, module_out):
870
+ self.outputs.append(module_out[1])
871
+
872
+ def clear(self):
873
+ self.outputs = []
874
+
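# --- Illustrative sketch (not part of the uploaded files) --------------------
# patch_attention and SaveOutput above are attention-inspection helpers that the
# script never calls. One plausible wiring (an assumption, not taken from the
# repository) is to hook every MultiheadAttention layer; a tiny encoder stands in
# for MYNET here. Requires PyTorch >= 1.11 for average_attn_weights.
import torch
import torch.nn as nn

demo = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=32, nhead=4), num_layers=2)
save_output = SaveOutput()
for module in demo.modules():
    if isinstance(module, nn.MultiheadAttention):
        patch_attention(module)                      # force per-head attention weights to be returned
        module.register_forward_hook(save_output)    # SaveOutput stores module_out[1] (the weights)

_ = demo(torch.randn(10, 2, 32))                     # seq_len x batch x d_model
print(len(save_output.outputs), save_output.outputs[0].shape)   # 2 layers, each (2, 4, 10, 10)
# ------------------------------------------------------------------------------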
875
+ def test(opt, video_name=None):
876
+ model = MYNET(opt).cuda()
877
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
878
+ base_dict = checkpoint['state_dict']
879
+ model.load_state_dict(base_dict)
880
+ model.eval()
881
+
882
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
883
+
884
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
885
+
886
+ if opt["pptype"] == "nms":
887
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
888
+ if opt["pptype"] == "net":
889
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
890
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
891
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
892
+ json.dump(output_dict, outfile, indent=2)
893
+ outfile.close()
894
+
895
+ mAP = evaluation_detection(opt)
896
+
897
+ # Compare predicted and ground truth action lengths
898
+ if video_name:
899
+ print("\nComparing Predicted and Ground Truth Action Lengths for Video:", video_name)
900
+ with open(opt["video_anno"].format(opt["split"]), 'r') as f:
901
+ anno_data = json.load(f)
902
+ gt_annotations = anno_data['database'][video_name]['annotations']
903
+ duration = anno_data['database'][video_name]['duration']
904
+
905
+ gt_segments = []
906
+ for anno in gt_annotations:
907
+ start, end = anno['segment']
908
+ label = anno['label']
909
+ duration_seg = end - start
910
+ gt_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg})
911
+
912
+ pred_segments = []
913
+ for pred in result_dict[video_name]:
914
+ start, end = pred['segment']
915
+ label = pred['label']
916
+ score = pred['score']
917
+ duration_seg = end - start
918
+ pred_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg, 'score': score})
919
+
920
+ # Print comparison table
921
+ matches = []
922
+ iou_threshold = VIS_CONFIG['iou_threshold']
923
+ used_gt_indices = set()
924
+ for pred in pred_segments:
925
+ best_iou = 0
926
+ best_gt_idx = None
927
+ for gt_idx, gt in enumerate(gt_segments):
928
+ if gt_idx in used_gt_indices:
929
+ continue
930
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
931
+ if iou > best_iou and iou >= iou_threshold:
932
+ best_iou = iou
933
+ best_gt_idx = gt_idx
934
+ if best_gt_idx is not None:
935
+ matches.append({
936
+ 'pred': pred,
937
+ 'gt': gt_segments[best_gt_idx],
938
+ 'iou': best_iou
939
+ })
940
+ used_gt_indices.add(best_gt_idx)
941
+ else:
942
+ matches.append({'pred': pred, 'gt': None, 'iou': 0})
943
+
944
+ for gt_idx, gt in enumerate(gt_segments):
945
+ if gt_idx not in used_gt_indices:
946
+ matches.append({'pred': None, 'gt': gt, 'iou': 0})
947
+
948
+ print("\n{:<20} {:<30} {:<30} {:<15} {:<10}".format(
949
+ "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"))
950
+ print("-" * 105)
951
+ for match in matches:
952
+ pred = match['pred']
953
+ gt = match['gt']
954
+ iou = match['iou']
955
+ if pred and gt:
956
+ label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
957
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
958
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
959
+ duration_diff = pred['duration'] - gt['duration']
960
+ print("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}".format(
961
+ label, pred_str, gt_str, duration_diff, iou))
962
+ elif pred:
963
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
964
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
965
+ pred['label'], pred_str, "None", "N/A", iou))
966
+ elif gt:
967
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
968
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
969
+ gt['label'], "None", gt_str, "N/A", iou))
970
+
971
+ # Summarize
972
+ matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
973
+ avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
974
+ avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
975
+ print(f"\nSummary:")
976
+ print(f"- Total Predictions: {len(pred_segments)}")
977
+ print(f"- Total Ground Truth: {len(gt_segments)}")
978
+ print(f"- Matched Segments: {matched_count}")
979
+ print(f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s")
980
+ print(f"- Average IoU (Matched): {avg_iou:.2f}")
981
+
982
+ # Generate static visualization
983
+ video_path = opt.get('video_path', '')
984
+ if os.path.exists(video_path):
985
+ visualize_action_lengths(
986
+ video_id=video_name,
987
+ pred_segments=pred_segments,
988
+ gt_segments=gt_segments,
989
+ video_path=video_path,
990
+ duration=duration
991
+ )
992
+ # Generate annotated video
993
+ annotate_video_with_actions(
994
+ video_id=video_name,
995
+ pred_segments=pred_segments,
996
+ gt_segments=gt_segments,
997
+ video_path=video_path
998
+ )
999
+ else:
1000
+ print(f"Warning: Video path {video_path} not found. Skipping visualization and video annotation.")
1001
+
1002
+ return mAP
1003
+
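# --- Illustrative sketch (not part of the uploaded files) --------------------
# The comparison block in test() above matches each prediction greedily to the
# unused ground truth with the highest IoU (>= VIS_CONFIG['iou_threshold']).
# A simplified stand-in using a plain [start, end] interval IoU; note the
# repository's calc_iou is instead called with (end, duration) pairs.
def interval_iou(a, b):
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0

pred = [3.0, 7.1]                    # made-up predicted segment, in seconds
gts = [[3.2, 6.8], [10.0, 12.0]]     # made-up ground-truth segments
ious = [interval_iou(pred, gt) for gt in gts]
best = max(range(len(gts)), key=lambda i: ious[i])
print(best, round(ious[best], 2))    # 0 0.88
# ------------------------------------------------------------------------------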
1004
+ def test_online(opt, video_name=None):
1005
+ model = MYNET(opt).cuda()
1006
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
1007
+ base_dict = checkpoint['state_dict']
1008
+ model.load_state_dict(base_dict)
1009
+ model.eval()
1010
+
1011
+ sup_model = SuppressNet(opt).cuda()
1012
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
1013
+ base_dict = checkpoint['state_dict']
1014
+ sup_model.load_state_dict(base_dict)
1015
+ sup_model.eval()
1016
+
1017
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
1018
+ test_loader = torch.utils.data.DataLoader(dataset,
1019
+ batch_size=1, shuffle=False,
1020
+ num_workers=0, pin_memory=True, drop_last=False)
1021
+
1022
+ result_dict = {}
1023
+ proposal_dict = []
1024
+
1025
+ num_class = opt["num_of_class"]
1026
+ unit_size = opt['segment_size']
1027
+ threshold = opt['threshold']
1028
+ anchors = opt['anchors']
1029
+
1030
+ start_time = time.time()
1031
+ total_frames = 0
1032
+
1033
+ for video_name in dataset.video_list:
1034
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
1035
+ sup_queue = torch.zeros(((unit_size, num_class - 1)))
1036
+
1037
+ duration = dataset.video_len[video_name]
1038
+ video_time = float(dataset.video_dict[video_name]["duration"])
1039
+ frame_to_time = 100.0 * video_time / duration
1040
+
1041
+ for idx in range(0, duration):
1042
+ total_frames += 1
1043
+ input_queue[:-1, :] = input_queue[1:, :].clone()
1044
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
1045
+
1046
+ minput = input_queue.unsqueeze(0)
1047
+ act_cls, act_reg, _ = model(minput.cuda())
1048
+ act_cls = torch.softmax(act_cls, dim=-1)
1049
+
1050
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
1051
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
1052
+
1053
+ proposal_anc_dict = []
1054
+ for anc_idx in range(0, len(anchors)):
1055
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
1056
+
1057
+ if len(cls) == 0:
1058
+ continue
1059
+
1060
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
1061
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
1062
+ st = ed - length
1063
+
1064
+ for cidx in range(0, len(cls)):
1065
+ label = cls[cidx]
1066
+ tmp_dict = {}
1067
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
1068
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
1069
+ tmp_dict["label"] = dataset.label_name[label]
1070
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
1071
+ proposal_anc_dict.append(tmp_dict)
1072
+
1073
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
1074
+
1075
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
1076
+ sup_queue[-1, :] = 0
1077
+ for proposal in proposal_anc_dict:
1078
+ cls_idx = dataset.label_name.index(proposal['label'])
1079
+ sup_queue[-1, cls_idx] = proposal["score"]
1080
+
1081
+ minput = sup_queue.unsqueeze(0)
1082
+ suppress_conf = sup_model(minput.cuda())
1083
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
1084
+
1085
+ for cls in range(0, num_class - 1):
1086
+ if suppress_conf[cls] > opt['sup_threshold']:
1087
+ for proposal in proposal_anc_dict:
1088
+ if proposal['label'] == dataset.label_name[cls]:
1089
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
1090
+ proposal_dict.append(proposal)
1091
+
1092
+ result_dict[video_name] = proposal_dict
1093
+ proposal_dict = []
1094
+
1095
+ end_time = time.time()
1096
+ working_time = end_time - start_time
1097
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
1098
+
1099
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
1100
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
1101
+ json.dump(output_dict, outfile, indent=2)
1102
+ outfile.close()
1103
+
1104
+ mAP = evaluation_detection(opt)
1105
+ return mAP
1106
+
1107
+ def main(opt, video_name=None):
1108
+ max_perf = 0
1109
+ if not video_name and 'video_name' in opt:
1110
+ video_name = opt['video_name']
1111
+
1112
+ if opt['mode'] == 'train':
1113
+ max_perf = train(opt)
1114
+ if opt['mode'] == 'test':
1115
+ max_perf = test(opt, video_name=video_name)
1116
+ if opt['mode'] == 'test_frame':
1117
+ max_perf = test_frame(opt, video_name=video_name)
1118
+ if opt['mode'] == 'test_online':
1119
+ max_perf = test_online(opt, video_name=video_name)
1120
+ if opt['mode'] == 'eval':
1121
+ max_perf = evaluation_detection(opt)
1122
+
1123
+ return max_perf
1124
+
1125
+ if __name__ == '__main__':
1126
+ opt = opts.parse_opt()
1127
+ opt = vars(opt)
1128
+ if not os.path.exists(opt["checkpoint_path"]):
1129
+ os.makedirs(opt["checkpoint_path"])
1130
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
1131
+ json.dump(opt, opt_file)
1132
+ opt_file.close()
1133
+
1134
+ if opt['seed'] >= 0:
1135
+ seed = opt['seed']
1136
+ torch.manual_seed(seed)
1137
+ np.random.seed(seed)
1138
+
1139
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
1140
+
1141
+ video_name = opt.get('video_name', None)
1142
+ main(opt, video_name=video_name)
1143
+ while(opt['wterm']):
1144
+ pass
models.py ADDED
@@ -0,0 +1,232 @@
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+ from torch.autograd import Variable
5
+ import torch.nn.functional as F
6
+ import torch.nn as nn
7
+ from torch.nn import init
8
+ from torch.nn.functional import normalize
9
+
10
+
11
+ class PositionalEncoding(nn.Module):
12
+ def __init__(self,
13
+ emb_size: int,
14
+ dropout: float = 0.1,
15
+ maxlen: int = 750):
16
+ super(PositionalEncoding, self).__init__()
17
+ den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
18
+ pos = torch.arange(0, maxlen).reshape(maxlen, 1)
19
+ pos_embedding = torch.zeros((maxlen, emb_size))
20
+ pos_embedding[:, 0::2] = torch.sin(pos * den)
21
+ pos_embedding[:, 1::2] = torch.cos(pos * den)
22
+ pos_embedding = pos_embedding.unsqueeze(-2)
23
+ self.dropout = nn.Dropout(dropout)
24
+ self.register_buffer('pos_embedding', pos_embedding)
25
+
26
+ def forward(self, token_embedding: torch.Tensor):
27
+ return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
28
+
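# --- Illustrative sketch (not part of the uploaded files) --------------------
# PositionalEncoding above registers a (maxlen, 1, emb_size) buffer, so it expects
# sequence-first input of shape (seq_len, batch, emb_size) and broadcasts the
# encoding over the batch dimension. Quick shape check with made-up sizes:
import torch

pe = PositionalEncoding(emb_size=1024, dropout=0.1, maxlen=400)
tokens = torch.randn(64, 2, 1024)    # seq_len x batch x emb_size, as used by MYNET below
print(pe.pos_embedding.shape)        # torch.Size([400, 1, 1024])
print(pe(tokens).shape)              # torch.Size([64, 2, 1024])
# ------------------------------------------------------------------------------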
29
+ class HistoryUnit(torch.nn.Module):
30
+ def __init__(self, opt):
31
+ super(HistoryUnit, self).__init__()
32
+ self.n_feature=opt["feat_dim"]
33
+ n_class=opt["num_of_class"]
34
+ n_embedding_dim=opt["hidden_dim"]
35
+ n_hist_dec_head = 4
36
+ n_hist_dec_layer = 5
37
+ n_hist_dec_head_2 = 4
38
+ n_hist_dec_layer_2 = 2
39
+ self.anchors=opt["anchors"]
40
+ self.history_tokens = 16
41
+ self.short_window_size = 16
42
+ self.anchors_stride=[]
43
+ dropout=0.3
44
+ self.best_loss=1000000
45
+ self.best_map=0
46
+
47
+
48
+ self.history_positional_encoding = PositionalEncoding(n_embedding_dim, dropout, maxlen=400)
49
+
50
+ self.history_encoder_block1 = nn.TransformerDecoder(
51
+ nn.TransformerDecoderLayer(d_model=n_embedding_dim,
52
+ nhead=n_hist_dec_head,
53
+ dropout=dropout,
54
+ activation='gelu'),
55
+ n_hist_dec_layer,
56
+ nn.LayerNorm(n_embedding_dim))
57
+
58
+
59
+ self.history_encoder_block2 = nn.TransformerDecoder(
60
+ nn.TransformerDecoderLayer(d_model=n_embedding_dim,
61
+ nhead=n_hist_dec_head_2,
62
+ dropout=dropout,
63
+ activation='gelu'),
64
+ n_hist_dec_layer_2,
65
+ nn.LayerNorm(n_embedding_dim))
66
+
67
+
68
+
69
+ self.snip_head = nn.Sequential(nn.Linear(n_embedding_dim,n_embedding_dim//4), nn.ReLU())
70
+ self.snip_classifier = nn.Sequential(nn.Linear(self.history_tokens*n_embedding_dim//4, (self.history_tokens*n_embedding_dim//4)//4), nn.ReLU(), nn.Linear((self.history_tokens*n_embedding_dim//4)//4,n_class))
71
+
72
+
73
+ self.history_token = nn.Parameter(torch.zeros(self.history_tokens, 1, n_embedding_dim))
74
+ # self.history_token_extra = nn.Parameter(torch.zeros(self.history_tokens*2, 1, n_embedding_dim))
75
+
76
+ self.norm2 = nn.LayerNorm(n_embedding_dim)
77
+ self.dropout2 = nn.Dropout(0.1)
78
+
79
+
80
+ def forward(self, long_x, encoded_x):
81
+
82
+
83
+ ## History Encoder
84
+ hist_pe_x = self.history_positional_encoding(long_x)
85
+ history_token = self.history_token.expand(-1, hist_pe_x.shape[1], -1)
86
+ hist_encoded_x_1 = self.history_encoder_block1(history_token, hist_pe_x)
87
+ hist_encoded_x_2 = self.history_encoder_block2(hist_encoded_x_1, encoded_x)
88
+ hist_encoded_x_2 = hist_encoded_x_2 + self.dropout2(hist_encoded_x_1)
89
+ hist_encoded_x = self.norm2(hist_encoded_x_2)
90
+
91
+ ## Snippet Classification Head
92
+ snippet_feat = self.snip_head(hist_encoded_x_1)
93
+ snippet_feat = torch.flatten(snippet_feat.permute(1, 0, 2), start_dim=1)
94
+
95
+ snip_cls = self.snip_classifier(snippet_feat)
96
+
97
+ return hist_encoded_x, snip_cls
98
+
99
+
100
+
101
+ class MYNET(torch.nn.Module):
102
+ def __init__(self, opt):
103
+ super(MYNET, self).__init__()
104
+ self.n_feature=opt["feat_dim"]
105
+ n_class=opt["num_of_class"]
106
+ n_embedding_dim=opt["hidden_dim"]
107
+ n_enc_layer=opt["enc_layer"]
108
+ n_enc_head=opt["enc_head"]
109
+ n_dec_layer=opt["dec_layer"]
110
+ n_dec_head=opt["dec_head"]
111
+ n_comb_dec_head = 4
112
+ n_comb_dec_layer = 5
113
+ n_seglen=opt["segment_size"]
114
+ self.anchors=opt["anchors"]
115
+ self.history_tokens = 16
116
+ self.short_window_size = 16
117
+ self.anchors_stride=[]
118
+ dropout=0.3
119
+ self.best_loss=1000000
120
+ self.best_map=0
121
+
122
+ self.feature_reduction_rgb = nn.Linear(self.n_feature//2, n_embedding_dim//2)
123
+ self.feature_reduction_flow = nn.Linear(self.n_feature//2, n_embedding_dim//2)
124
+
125
+ self.positional_encoding = PositionalEncoding(n_embedding_dim, dropout, maxlen=400)
126
+
127
+ self.encoder = nn.TransformerEncoder(
128
+ nn.TransformerEncoderLayer(d_model=n_embedding_dim,
129
+ nhead=n_enc_head,
130
+ dropout=dropout,
131
+ activation='gelu'),
132
+ n_enc_layer,
133
+ nn.LayerNorm(n_embedding_dim))
134
+
135
+ self.decoder = nn.TransformerDecoder(
136
+ nn.TransformerDecoderLayer(d_model=n_embedding_dim,
137
+ nhead=n_dec_head,
138
+ dropout=dropout,
139
+ activation='gelu'),
140
+ n_dec_layer,
141
+ nn.LayerNorm(n_embedding_dim))
142
+
143
+ self.history_unit = HistoryUnit(opt)
144
+
145
+
146
+ self.history_anchor_decoder_block1 = nn.TransformerDecoder(
147
+ nn.TransformerDecoderLayer(d_model=n_embedding_dim,
148
+ nhead=n_comb_dec_head,
149
+ dropout=dropout,
150
+ activation='gelu'),
151
+ n_comb_dec_layer,
152
+ nn.LayerNorm(n_embedding_dim))
153
+
154
+
155
+ self.classifier = nn.Sequential(nn.Linear(n_embedding_dim,n_embedding_dim), nn.ReLU(), nn.Linear(n_embedding_dim,n_class))
156
+ self.regressor = nn.Sequential(nn.Linear(n_embedding_dim,n_embedding_dim), nn.ReLU(), nn.Linear(n_embedding_dim,2))
157
+
158
+
159
+ self.decoder_token = nn.Parameter(torch.zeros(len(self.anchors), 1, n_embedding_dim))
160
+
161
+
162
+ self.norm1 = nn.LayerNorm(n_embedding_dim)
163
+ self.dropout1 = nn.Dropout(0.1)
164
+
165
+ self.relu = nn.ReLU(True)
166
+ self.softmaxd1 = nn.Softmax(dim=-1)
167
+
168
+ def forward(self, inputs):
169
+ # base_x_rgb = self.feature_reduction_rgb(inputs[:,:,:self.n_feature//2])
170
+ # base_x_flow = self.feature_reduction_flow(inputs[:,:,self.n_feature//2:])
171
+ base_x_rgb = self.feature_reduction_rgb(inputs[:,:,:self.n_feature//2].float())
172
+ base_x_flow = self.feature_reduction_flow(inputs[:,:,self.n_feature//2:].float())
173
+ base_x = torch.cat([base_x_rgb,base_x_flow],dim=-1)
174
+
175
+ base_x = base_x.permute([1,0,2])  # seq_len x batch x featsize
176
+
177
+ short_x = base_x[-self.short_window_size:]
178
+
179
+ long_x = base_x[:-self.short_window_size]
180
+
181
+ ## Anchor Feature Generator
182
+ pe_x = self.positional_encoding(short_x)
183
+ encoded_x = self.encoder(pe_x)
184
+ decoder_token = self.decoder_token.expand(-1, encoded_x.shape[1], -1)
185
+ decoded_x = self.decoder(decoder_token, encoded_x)
186
187
+
188
+ ## Future-Supervised History Module
189
+ hist_encoded_x, snip_cls = self.history_unit(long_x, encoded_x)
190
+
191
+
192
+ ## History Driven Anchor Refinement
193
+ decoded_anchor_feat = self.history_anchor_decoder_block1(decoded_x, hist_encoded_x)
194
+ decoded_anchor_feat = decoded_anchor_feat + self.dropout1(decoded_x)
195
+ decoded_anchor_feat = self.norm1(decoded_anchor_feat)
196
+ decoded_anchor_feat = decoded_anchor_feat.permute([1, 0, 2])
197
+
198
+ # Prediction Module
199
+ anc_cls = self.classifier(decoded_anchor_feat)
200
+ anc_reg = self.regressor(decoded_anchor_feat)
201
+
202
+ return anc_cls, anc_reg, snip_cls
203
+
204
+
205
+ class SuppressNet(torch.nn.Module):
206
+ def __init__(self, opt):
207
+ super(SuppressNet, self).__init__()
208
+ n_class=opt["num_of_class"]-1
209
+ n_seglen=opt["segment_size"]
210
+ n_embedding_dim=2*n_seglen
211
+ dropout=0.3
212
+ self.best_loss=1000000
213
+ self.best_map=0
214
+ # FC layers for the 2 streams
215
+
216
+ self.mlp1 = nn.Linear(n_seglen, n_embedding_dim)
217
+ self.mlp2 = nn.Linear(n_embedding_dim, 1)
218
+ self.norm = nn.InstanceNorm1d(n_class)
219
+ self.relu = nn.ReLU(True)
220
+ self.sigmoid = nn.Sigmoid()
221
+
222
+ def forward(self, inputs):
223
+ #inputs - batch x seq_len x class
224
+
225
+ base_x = inputs.permute([0,2,1])
226
+ base_x = self.norm(base_x)
227
+ x = self.relu(self.mlp1(base_x))
228
+ x = self.sigmoid(self.mlp2(x))
229
+ x = x.squeeze(-1)
230
+
231
+ return x
232
+
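# --- Illustrative sketch (not part of the uploaded files) --------------------
# Minimal CPU smoke test of the two networks above. The opt dict is an assumption
# built from the defaults in opts_egtea.py (feat_dim=2048, hidden_dim=1024,
# num_of_class=23, segment_size=64, anchors='2,4,6,8,12,16').
import torch

opt = {"feat_dim": 2048, "hidden_dim": 1024, "num_of_class": 23, "segment_size": 64,
       "enc_layer": 3, "enc_head": 8, "dec_layer": 5, "dec_head": 4,
       "anchors": [2, 4, 6, 8, 12, 16]}

net = MYNET(opt).eval()
x = torch.randn(2, opt["segment_size"], opt["feat_dim"])            # batch x seq_len x feat_dim
with torch.no_grad():
    anc_cls, anc_reg, snip_cls = net(x)
print(anc_cls.shape, anc_reg.shape, snip_cls.shape)
# -> torch.Size([2, 6, 23]) torch.Size([2, 6, 2]) torch.Size([2, 23])

sup = SuppressNet(opt).eval()
conf = torch.rand(1, opt["segment_size"], opt["num_of_class"] - 1)  # batch x seq_len x classes
with torch.no_grad():
    print(sup(conf).shape)                                          # torch.Size([1, 22])
# ------------------------------------------------------------------------------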
opts_egtea.py ADDED
@@ -0,0 +1,62 @@
1
+ import argparse
2
+
3
+ def parse_opt():
4
+ parser = argparse.ArgumentParser()
5
+ # Overall settings
6
+ parser.add_argument('--mode', type=str, default='train')
7
+ parser.add_argument('--video_name', type=str, default=None, help='Name of the single video to evaluate')
8
+ parser.add_argument('--video_path', type=str, default='', help='Path to the input video file for visualization')
9
+ parser.add_argument('--checkpoint_path', type=str, default='./checkpoint')
10
+ parser.add_argument('--segment_size', type=int, default=64)
11
+ parser.add_argument('--anchors', type=str, default='2,4,6,8,12,16')
12
+ parser.add_argument('--seed', default=7, type=int, help='random seed for reproducibility')
13
+
14
+ # Overall Dataset settings
15
+ parser.add_argument('--num_of_class', type=int, default=23)
16
+ parser.add_argument('--data_format', type=str, default="npz_i3d")
17
+ parser.add_argument('--data_rescale', default=False, action='store_true')
18
+ parser.add_argument('--predefined_fps', default=None, type=float)
19
+ parser.add_argument('--rgb_only', default=False, action='store_true')
20
+ parser.add_argument('--video_anno', type=str, default="./data/egtea_annotations_split{}.json")
21
+ parser.add_argument('--video_feature_all_train', type=str, default="./data/I3D/")
22
+ parser.add_argument('--video_feature_all_test', type=str, default="./data/I3D/")
23
+ parser.add_argument('--setup', type=str, default="")
24
+ parser.add_argument('--exp', type=str, default="01")
25
+ parser.add_argument('--split', type=str, default="1")
26
+
27
+ # Network
28
+ parser.add_argument('--feat_dim', type=int, default=2048)
29
+ parser.add_argument('--hidden_dim', type=int, default=1024)
30
+ parser.add_argument('--out_dim', type=int, default=23)
31
+ parser.add_argument('--enc_layer', type=int, default=3)
32
+ parser.add_argument('--enc_head', type=int, default=8)
33
+ parser.add_argument('--dec_layer', type=int, default=5)
34
+ parser.add_argument('--dec_head', type=int, default=4)
35
+
36
+ # Training settings
37
+ parser.add_argument('--batch_size', type=int, default=128)
38
+ parser.add_argument('--lr', type=float, default=1e-4)
39
+ parser.add_argument('--weight_decay', type=float, default=1e-4)
40
+ parser.add_argument('--epoch', type=int, default=5)
41
+ parser.add_argument('--lr_step', type=int, default=3)
42
+
43
+ # Post processing
44
+ parser.add_argument('--alpha', type=float, default=1)
45
+ parser.add_argument('--beta', type=float, default=1)
46
+ parser.add_argument('--gamma', type=float, default=0.2)
47
+ parser.add_argument('--pptype', type=str, default="net")
48
+ parser.add_argument('--pos_threshold', type=float, default=0.5)
49
+ parser.add_argument('--sup_threshold', type=float, default=0.1)
50
+ parser.add_argument('--threshold', type=float, default=0.1)
51
+ parser.add_argument('--inference_subset', type=str, default="test")
52
+ parser.add_argument('--soft_nms', type=float, default=0.3)
53
+ parser.add_argument('--video_len_file', type=str, default="./output/video_len_{}.json")
54
+ parser.add_argument('--proposal_label_file', type=str, default="./output/proposal_label_{}.h5")
55
+ parser.add_argument('--suppress_label_file', type=str, default="./output/suppress_label_{}.h5")
56
+ parser.add_argument('--suppress_result_file', type=str, default="./output/suppress_result{}.h5")
57
+ parser.add_argument('--frame_result_file', type=str, default="./output/frame_result{}.h5")
58
+ parser.add_argument('--result_file', type=str, default="./output/result_proposal{}.json")
59
+ parser.add_argument('--wterm', default=False, action='store_true')  # with type=bool, any non-empty string (even "False") would parse as True
60
+
61
+ args = parser.parse_args()
62
+ return args
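# --- Illustrative sketch (not part of the uploaded files) --------------------
# parse_opt() reads sys.argv directly, so one way to smoke-test the options (and
# the anchor parsing done in main.py's __main__ block) is to override argv.
# Assumes this file is importable as opts_egtea from the working directory.
import sys
import opts_egtea as opts

sys.argv = ["main.py", "--mode", "test", "--exp", "01", "--split", "1"]
opt = vars(opts.parse_opt())
opt["anchors"] = [int(a) for a in opt["anchors"].split(",")]   # mirrors main.py
print(opt["mode"], opt["anchors"])                             # test [2, 4, 6, 8, 12, 16]
# ------------------------------------------------------------------------------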
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ h5py
2
+ ipdb
3
+ scikit-learn  # the standalone 'sklearn' PyPI package is deprecated
4
+ matplotlib
5
+ tensorboardX
result image
main.py ADDED
@@ -0,0 +1,779 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet, calc_iou
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.patches as patches
25
+ import cv2
26
+ from typing import List, Dict, Optional
27
+
28
+ # Visualization Configuration
29
30
+ VIS_CONFIG = {
31
+ 'frame_interval': 1.0, # Sample frames every 1 second
32
+ 'max_frames': 20, # Maximum number of frames to display
33
+ 'save_dir': './output/visualizations',
34
+ 'gt_color': '#1f77b4', # Blue for ground truth
35
+ 'pred_color': '#ff7f0e', # Orange for predictions
36
+ 'fontsize_label': 10, # Reduced for better fit
37
+ 'fontsize_title': 14,
38
+ 'frame_highlight_both': 'green',
39
+ 'frame_highlight_gt': 'red',
40
+ 'frame_highlight_pred': 'black',
41
+ 'iou_threshold': 0.3,
42
+ 'frame_scale_factor': 0.8, # Reduced scaling for smaller figure
43
+ }
44
+
45
+ def visualize_action_lengths(
46
+ video_id: str,
47
+ pred_segments: List[Dict],
48
+ gt_segments: List[Dict],
49
+ video_path: str,
50
+ duration: float,
51
+ save_dir: str = VIS_CONFIG['save_dir'],
52
+ frame_interval: float = VIS_CONFIG['frame_interval']
53
+ ) -> None:
54
+ """
55
+ Generate a visualization plot comparing ground truth and predicted action lengths with video frames.
56
+
57
+ Args:
58
+ video_id: Video identifier (e.g., 'my_video').
59
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
60
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
61
+ video_path: Path to the input video file.
62
+ duration: Total duration of the video in seconds.
63
+ save_dir: Directory to save the output image.
64
+ frame_interval: Time interval between sampled frames (seconds).
65
+ """
66
+ os.makedirs(save_dir, exist_ok=True)
67
+
68
+ # Calculate frame sampling times
69
+ num_frames = int(duration / frame_interval) + 1
70
+ if num_frames > VIS_CONFIG['max_frames']:
71
+ frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
72
+ num_frames = VIS_CONFIG['max_frames']
73
+ print(f"Warning: Video duration ({duration:.1f}s) requires {num_frames} frames. Adjusted frame_interval to {frame_interval:.2f}s.")
74
+
75
+ frame_times = np.linspace(0, duration, num_frames, endpoint=False)
76
+
77
+ # Load video frames
78
+ frames = []
79
+ cap = cv2.VideoCapture(video_path)
80
+ if not cap.isOpened():
81
+ print(f"Warning: Could not open video {video_path}. Using placeholder frames.")
82
+ frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
83
+ else:
84
+ for t in frame_times:
85
+ cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
86
+ ret, frame = cap.read()
87
+ if ret:
88
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
89
+ # Resize frame to reduce memory usage
90
+ frame = cv2.resize(frame, (int(frame.shape[1] * 0.5), int(frame.shape[0] * 0.5)))
91
+ frames.append(frame)
92
+ else:
93
+ frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
94
+ cap.release()
95
+
96
+ # Initialize figure
97
+ fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
98
+ gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
99
+
100
+ # Plot frames
101
+ for i, (t, frame) in enumerate(zip(frame_times, frames)):
102
+ ax = fig.add_subplot(gs[0, i])
103
+
104
+ # Check if frame falls within GT or predicted segments
105
+ gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
106
+ pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
107
+
108
+ # Set border color
109
+ border_color = None
110
+ if gt_hit and pred_hit:
111
+ border_color = VIS_CONFIG['frame_highlight_both']
112
+ elif gt_hit:
113
+ border_color = VIS_CONFIG['frame_highlight_gt']
114
+ elif pred_hit:
115
+ border_color = VIS_CONFIG['frame_highlight_pred']
116
+
117
+ ax.imshow(frame)
118
+ ax.set_xticks([]); ax.set_yticks([])  # keep the spines so the highlight border below stays visible (ax.axis('off') would hide it)
119
+ if border_color:
120
+ for spine in ax.spines.values():
121
+ spine.set_edgecolor(border_color)
122
+ spine.set_linewidth(2)
+ else:
+ for spine in ax.spines.values():
+ spine.set_visible(False)
123
+
124
+ ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
125
+ color=border_color if border_color else 'black')
126
+
127
+ # Plot ground truth bar
128
+ ax_gt = fig.add_subplot(gs[1, :])
129
+ ax_gt.set_xlim(0, duration)
130
+ ax_gt.set_ylim(0, 1)
131
+ ax_gt.axis('off')
132
+ ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'],
133
+ va='center', ha='right', weight='bold')
134
+
135
+ for seg in gt_segments:
136
+ start, end = seg['start'], seg['end']
137
+ width = end - start
138
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
139
+ ax_gt.add_patch(patches.Rectangle(
140
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'],
141
+ edgecolor='black', alpha=0.8
142
+ ))
143
+ ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center',
144
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
145
+ ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
146
+ ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')
147
+
148
+ # Plot prediction bar
149
+ ax_pred = fig.add_subplot(gs[2, :])
150
+ ax_pred.set_xlim(0, duration)
151
+ ax_pred.set_ylim(0, 1)
152
+ ax_pred.axis('off')
153
+ ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'],
154
+ va='center', ha='right', weight='bold')
155
+
156
+ for seg in pred_segments:
157
+ start, end = seg['start'], seg['end']
158
+ width = end - start
159
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
160
+ ax_pred.add_patch(patches.Rectangle(
161
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'],
162
+ edgecolor='black', alpha=0.8
163
+ ))
164
+ ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center',
165
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
166
+ ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
167
+ ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')
168
+
169
+ # Save plot
170
+ out_path = os.path.join(save_dir, f"viz_{video_id}_{opt['exp']}.png")  # PNG output; opt is the module-level options dict
171
+ plt.savefig(out_path, dpi=100, bbox_inches='tight')  # moderate DPI keeps the file small
172
+ print(f"[✅ Saved Visualization]: {out_path}")
173
+ plt.close()
174
+
175
+
176
+
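# --- Illustrative sketch (not part of the uploaded files) --------------------
# The segment lists consumed by visualize_action_lengths above (and by the
# comparison table in test() further down) look like this; the labels and times
# are made up for illustration only:
gt_segments = [
    {"label": "Take bowl", "start": 3.2, "end": 6.8, "duration": 3.6},
]
pred_segments = [
    {"label": "Take bowl", "start": 3.0, "end": 7.1, "duration": 4.1, "score": 0.62},
]
# visualize_action_lengths(video_id, pred_segments, gt_segments, video_path, duration)
# also reads opt['exp'] from the module-level options dict when naming the output image.
# ------------------------------------------------------------------------------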
177
+ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
178
+ train_loader = torch.utils.data.DataLoader(train_dataset,
179
+ batch_size=opt['batch_size'], shuffle=True,
180
+ num_workers=0, pin_memory=True, drop_last=False)
181
+ epoch_cost = 0
182
+ epoch_cost_cls = 0
183
+ epoch_cost_reg = 0
184
+ epoch_cost_snip = 0
185
+
186
+ total_iter = len(train_dataset) // opt['batch_size']
187
+ cls_loss = MultiCrossEntropyLoss(focal=True)
188
+ snip_loss = MultiCrossEntropyLoss(focal=True)
189
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
190
+ if warmup:
191
+ for g in optimizer.param_groups:
192
+ g['lr'] = n_iter * (opt['lr']) / total_iter
193
+
194
+ act_cls, act_reg, snip_cls = model(input_data.float().cuda())
195
+
196
+ act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
197
+ snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))
198
+
199
+ cost_reg = 0
200
+ cost_cls = 0
201
+
202
+ loss = cls_loss_func_(cls_loss, cls_label, act_cls)
203
+ cost_cls = loss
204
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
205
+
206
+ loss = regress_loss_func(reg_label, act_reg)
207
+ cost_reg = loss
208
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
209
+
210
+ loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
211
+ cost_snip = loss
212
+ epoch_cost_snip += cost_snip.detach().cpu().numpy()
213
+
214
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
215
+ epoch_cost += cost.detach().cpu().numpy()
216
+
217
+ optimizer.zero_grad()
218
+ cost.backward()
219
+ optimizer.step()
220
+
221
+ return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
222
+
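# --- Illustrative sketch (not part of the uploaded files) --------------------
# During the warmup pass above, the learning rate is ramped linearly from 0 to
# opt['lr'] over one epoch. With made-up numbers (lr=1e-4, 10 iterations/epoch):
base_lr, total_iter = 1e-4, 10
for n_iter in range(total_iter):
    lr = n_iter * base_lr / total_iter
    print(n_iter, f"{lr:.1e}")       # 0.0e+00, 1.0e-05, ..., 9.0e-05
# ------------------------------------------------------------------------------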
223
+ def eval_one_epoch(opt, model, test_dataset):
224
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, test_dataset)
225
+
226
+ result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
227
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
228
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
229
+ json.dump(output_dict, outfile, indent=2)
230
+ outfile.close()
231
+
232
+ IoUmAP = evaluation_detection(opt, verbose=False)
233
+ IoUmAP_5 = sum(IoUmAP) / len(IoUmAP)
234
+
235
+ return cls_loss, reg_loss, tot_loss, IoUmAP_5
236
+
237
+ def train(opt):
238
+ writer = SummaryWriter()
239
+ model = MYNET(opt).cuda()
240
+
241
+ rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
242
+ optimizer = optim.Adam([{'params': model.history_unit.parameters(), 'lr': 1e-6}, {'params': rest_of_model_params}], lr=opt["lr"], weight_decay=opt["weight_decay"])
243
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["lr_step"])
244
+
245
+ train_dataset = VideoDataSet(opt, subset="train")
246
+ test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
247
+
248
+ warmup = False
249
+
250
+ for n_epoch in range(opt['epoch']):
251
+ if n_epoch >= 1:
252
+ warmup = False
253
+
254
+ n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
255
+
256
+ writer.add_scalars('data/cost', {'train': epoch_cost / (n_iter + 1)}, n_epoch)
257
+ print("training loss(epoch %d): %.03f, cls - %f, reg - %f, snip - %f, lr - %f" % (n_epoch,
258
+ epoch_cost / (n_iter + 1),
259
+ epoch_cost_cls / (n_iter + 1),
260
+ epoch_cost_reg / (n_iter + 1),
261
+ epoch_cost_snip / (n_iter + 1),
262
+ optimizer.param_groups[-1]["lr"]))
263
+
264
+ scheduler.step()
265
+ model.eval()
266
+
267
+ cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
268
+
269
+ writer.add_scalars('data/mAP', {'test': IoUmAP_5}, n_epoch)
270
+ print("testing loss(epoch %d): %.03f, cls - %f, reg - %f, mAP Avg - %f" % (n_epoch, tot_loss, cls_loss, reg_loss, IoUmAP_5))
271
+
272
+ state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
273
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_checkpoint_" + str(n_epoch + 1) + ".pth.tar")
274
+ if IoUmAP_5 > model.best_map:
275
+ model.best_map = IoUmAP_5
276
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_ckp_best.pth.tar")
277
+
278
+ model.train()
279
+
280
+ writer.close()
281
+ return model.best_map
282
+
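# --- Illustrative sketch (not part of the uploaded files) --------------------
# train() above gives the history unit its own fixed learning rate (1e-6) while
# the remaining parameter group inherits the top-level lr. The same pattern on a
# toy module (names below are placeholders, not from the repository):
import torch.nn as nn
import torch.optim as optim

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.history_unit = nn.Linear(8, 8)   # slow-moving part
        self.head = nn.Linear(8, 2)           # everything else

toy = Toy()
rest = [p for n, p in toy.named_parameters() if "history_unit" not in n]
optimizer = optim.Adam([{"params": toy.history_unit.parameters(), "lr": 1e-6},
                        {"params": rest}], lr=1e-4, weight_decay=1e-4)
print([g["lr"] for g in optimizer.param_groups])   # [1e-06, 0.0001]
# ------------------------------------------------------------------------------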
283
+ def eval_frame(opt, model, dataset):
284
+ test_loader = torch.utils.data.DataLoader(dataset,
285
+ batch_size=opt['batch_size'], shuffle=False,
286
+ num_workers=0, pin_memory=True, drop_last=False)
287
+
288
+ labels_cls = {}
289
+ labels_reg = {}
290
+ output_cls = {}
291
+ output_reg = {}
292
+ for video_name in dataset.video_list:
293
+ labels_cls[video_name] = []
294
+ labels_reg[video_name] = []
295
+ output_cls[video_name] = []
296
+ output_reg[video_name] = []
297
+
298
+ start_time = time.time()
299
+ total_frames = 0
300
+ epoch_cost = 0
301
+ epoch_cost_cls = 0
302
+ epoch_cost_reg = 0
303
+
304
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
305
+ act_cls, act_reg, _ = model(input_data.float().cuda())
306
+ cost_reg = 0
307
+ cost_cls = 0
308
+
309
+ loss = cls_loss_func(cls_label, act_cls)
310
+ cost_cls = loss
311
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
312
+
313
+ loss = regress_loss_func(reg_label, act_reg)
314
+ cost_reg = loss
315
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
316
+
317
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
318
+ epoch_cost += cost.detach().cpu().numpy()
319
+
320
+ act_cls = torch.softmax(act_cls, dim=-1)
321
+
322
+ total_frames += input_data.size(0)
323
+
324
+ for b in range(0, input_data.size(0)):
325
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
326
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
327
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
328
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
329
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
330
+
331
+ end_time = time.time()
332
+ working_time = end_time - start_time
333
+
334
+ for video_name in dataset.video_list:
335
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
336
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
337
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
338
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
339
+
340
+ cls_loss = epoch_cost_cls / n_iter
341
+ reg_loss = epoch_cost_reg / n_iter
342
+ tot_loss = epoch_cost / n_iter
343
+
344
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
345
+
346
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
347
+ result_dict = {}
348
+ proposal_dict = []
349
+
350
+ num_class = opt["num_of_class"]
351
+ unit_size = opt['segment_size']
352
+ threshold = opt['threshold']
353
+ anchors = opt['anchors']
354
+
355
+ for video_name in dataset.video_list:
356
+ duration = dataset.video_len[video_name]
357
+ video_time = float(dataset.video_dict[video_name]["duration"])
358
+ frame_to_time = 100.0 * video_time / duration
359
+
360
+ for idx in range(0, duration):
361
+ cls_anc = output_cls[video_name][idx]
362
+ reg_anc = output_reg[video_name][idx]
363
+
364
+ proposal_anc_dict = []
365
+ for anc_idx in range(0, len(anchors)):
366
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
367
+
368
+ if len(cls) == 0:
369
+ continue
370
+
371
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
372
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
373
+ st = ed - length
374
+
375
+ for cidx in range(0, len(cls)):
376
+ label = cls[cidx]
377
+ tmp_dict = {}
378
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
379
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
380
+ tmp_dict["label"] = dataset.label_name[label]
381
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
382
+ proposal_anc_dict.append(tmp_dict)
383
+
384
+ proposal_dict += proposal_anc_dict
385
+
386
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
387
+ result_dict[video_name] = proposal_dict
388
+ proposal_dict = []
389
+
390
+ return result_dict
391
+
392
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
393
+ model = SuppressNet(opt).cuda()
394
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
395
+ base_dict = checkpoint['state_dict']
396
+ model.load_state_dict(base_dict)
397
+ model.eval()
398
+
399
+ result_dict = {}
400
+ proposal_dict = []
401
+
402
+ num_class = opt["num_of_class"]
403
+ unit_size = opt['segment_size']
404
+ threshold = opt['threshold']
405
+ anchors = opt['anchors']
406
+
407
+ for video_name in dataset.video_list:
408
+ duration = dataset.video_len[video_name]
409
+ video_time = float(dataset.video_dict[video_name]["duration"])
410
+ frame_to_time = 100.0 * video_time / duration
411
+ conf_queue = torch.zeros((unit_size, num_class - 1))
412
+
413
+ for idx in range(0, duration):
414
+ cls_anc = output_cls[video_name][idx]
415
+ reg_anc = output_reg[video_name][idx]
416
+
417
+ proposal_anc_dict = []
418
+ for anc_idx in range(0, len(anchors)):
419
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
420
+
421
+ if len(cls) == 0:
422
+ continue
423
+
424
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
425
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
426
+ st = ed - length
427
+
428
+ for cidx in range(0, len(cls)):
429
+ label = cls[cidx]
430
+ tmp_dict = {}
431
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
432
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
433
+ tmp_dict["label"] = dataset.label_name[label]
434
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
435
+ proposal_anc_dict.append(tmp_dict)
436
+
437
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
438
+
439
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
440
+ conf_queue[-1, :] = 0
441
+ for proposal in proposal_anc_dict:
442
+ cls_idx = dataset.label_name.index(proposal['label'])
443
+ conf_queue[-1, cls_idx] = proposal["score"]
444
+
445
+ minput = conf_queue.unsqueeze(0)
446
+ suppress_conf = model(minput.cuda())
447
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
448
+
449
+ for cls in range(0, num_class - 1):
450
+ if suppress_conf[cls] > opt['sup_threshold']:
451
+ for proposal in proposal_anc_dict:
452
+ if proposal['label'] == dataset.label_name[cls]:
453
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
454
+ proposal_dict.append(proposal)
455
+
456
+ result_dict[video_name] = proposal_dict
457
+ proposal_dict = []
458
+
459
+ return result_dict
460
+
461
+ def test_frame(opt, video_name=None):
462
+ model = MYNET(opt).cuda()
463
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
464
+ base_dict = checkpoint['state_dict']
465
+ model.load_state_dict(base_dict)
466
+ model.eval()
467
+
468
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
469
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
470
+
471
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
472
+
473
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
474
+
475
+ for video_name in dataset.video_list:
476
+ o_cls = output_cls[video_name]
477
+ o_reg = output_reg[video_name]
478
+ l_cls = labels_cls[video_name]
479
+ l_reg = labels_reg[video_name]
480
+
481
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
482
+ dset_predcls[:, :] = o_cls[:, :]
483
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
484
+ dset_predreg[:, :] = o_reg[:, :]
485
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
486
+ dset_labelcls[:, :] = l_cls[:, :]
487
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
488
+ dset_labelreg[:, :] = l_reg[:, :]
489
+ outfile.close()
490
+
491
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
492
+ return cls_loss, reg_loss, tot_loss
493
+
494
+ def patch_attention(m):
495
+ forward_orig = m.forward
496
+
497
+ def wrap(*args, **kwargs):
498
+ kwargs["need_weights"] = True
499
+ kwargs["average_attn_weights"] = False
500
+ return forward_orig(*args, **kwargs)
501
+
502
+ m.forward = wrap
503
+
504
+ class SaveOutput:
505
+ def __init__(self):
506
+ self.outputs = []
507
+
508
+ def __call__(self, module, module_in, module_out):
509
+ self.outputs.append(module_out[1])
510
+
511
+ def clear(self):
512
+ self.outputs = []
513
+
514
+ def test(opt, video_name=None):
515
+ model = MYNET(opt).cuda()
516
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
517
+ base_dict = checkpoint['state_dict']
518
+ model.load_state_dict(base_dict)
519
+ model.eval()
520
+
521
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
522
+
523
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
524
+
525
+ if opt["pptype"] == "nms":
526
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
527
+ if opt["pptype"] == "net":
528
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
529
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
530
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
531
+ json.dump(output_dict, outfile, indent=2)
532
+ outfile.close()
533
+
534
+ mAP = evaluation_detection(opt)
535
+
536
+ # Compare predicted and ground truth action lengths
537
+ if video_name:
538
+ print("\nComparing Predicted and Ground Truth Action Lengths for Video:", video_name)
539
+ # Load ground truth annotations
540
+ with open(opt["video_anno"].format(opt["split"]), 'r') as f:
541
+ anno_data = json.load(f)
542
+ gt_annotations = anno_data['database'][video_name]['annotations']
543
+ duration = anno_data['database'][video_name]['duration']
544
+
545
+ # Extract ground truth segments
546
+ gt_segments = []
547
+ for anno in gt_annotations:
548
+ start, end = anno['segment']
549
+ label = anno['label']
550
+ duration_seg = end - start
551
+ gt_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg})
552
+
553
+ # Extract predicted segments
554
+ pred_segments = []
555
+ for pred in result_dict[video_name]:
556
+ start, end = pred['segment']
557
+ label = pred['label']
558
+ score = pred['score']
559
+ duration_seg = end - start
560
+ pred_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg, 'score': score})
561
+
562
+ # Print comparison table
563
+ matches = []
564
+ iou_threshold = VIS_CONFIG['iou_threshold']
565
+ used_gt_indices = set()
566
+ for pred in pred_segments:
567
+ best_iou = 0
568
+ best_gt_idx = None
569
+ for gt_idx, gt in enumerate(gt_segments):
570
+ if gt_idx in used_gt_indices:
571
+ continue
572
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
573
+ if iou > best_iou and iou >= iou_threshold:
574
+ best_iou = iou
575
+ best_gt_idx = gt_idx
576
+ if best_gt_idx is not None:
577
+ matches.append({
578
+ 'pred': pred,
579
+ 'gt': gt_segments[best_gt_idx],
580
+ 'iou': best_iou
581
+ })
582
+ used_gt_indices.add(best_gt_idx)
583
+ else:
584
+ matches.append({'pred': pred, 'gt': None, 'iou': 0})
585
+
586
+ for gt_idx, gt in enumerate(gt_segments):
587
+ if gt_idx not in used_gt_indices:
588
+ matches.append({'pred': None, 'gt': gt, 'iou': 0})
589
+
590
+ print("\n{:<20} {:<30} {:<30} {:<15} {:<10}".format(
591
+ "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"))
592
+ print("-" * 105)
593
+ for match in matches:
594
+ pred = match['pred']
595
+ gt = match['gt']
596
+ iou = match['iou']
597
+ if pred and gt:
598
+ label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
599
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
600
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
601
+ duration_diff = pred['duration'] - gt['duration']
602
+ print("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}".format(
603
+ label, pred_str, gt_str, duration_diff, iou))
604
+ elif pred:
605
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
606
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
607
+ pred['label'], pred_str, "None", "N/A", iou))
608
+ elif gt:
609
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
610
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
611
+ gt['label'], "None", gt_str, "N/A", iou))
612
+
613
+ # Summarize
614
+ matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
615
+ avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
616
+ avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
617
+ print(f"\nSummary:")
618
+ print(f"- Total Predictions: {len(pred_segments)}")
619
+ print(f"- Total Ground Truth: {len(gt_segments)}")
620
+ print(f"- Matched Segments: {matched_count}")
621
+ print(f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s")
622
+ print(f"- Average IoU (Matched): {avg_iou:.2f}")
623
+
624
+ # Generate visualization
625
+ video_path = opt.get('video_path', '')  # --video_path is defined in opts_egtea.py
626
+ if os.path.exists(video_path):
627
+ visualize_action_lengths(
628
+ video_id=video_name,
629
+ pred_segments=pred_segments,
630
+ gt_segments=gt_segments,
631
+ video_path=video_path,
632
+ duration=duration
633
+ )
634
+ else:
635
+ print(f"Warning: Video path {video_path} not found. Skipping visualization.")
636
+
637
+ return mAP
638
+
639
+ def test_online(opt, video_name=None):
640
+ model = MYNET(opt).cuda()
641
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
642
+ base_dict = checkpoint['state_dict']
643
+ model.load_state_dict(base_dict)
644
+ model.eval()
645
+
646
+ sup_model = SuppressNet(opt).cuda()
647
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
648
+ base_dict = checkpoint['state_dict']
649
+ sup_model.load_state_dict(base_dict)
650
+ sup_model.eval()
651
+
652
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
653
+ test_loader = torch.utils.data.DataLoader(dataset,
654
+ batch_size=1, shuffle=False,
655
+ num_workers=0, pin_memory=True, drop_last=False)
656
+
657
+ result_dict = {}
658
+ proposal_dict = []
659
+
660
+ num_class = opt["num_of_class"]
661
+ unit_size = opt['segment_size']
662
+ threshold = opt['threshold']
663
+ anchors = opt['anchors']
664
+
665
+ start_time = time.time()
666
+ total_frames = 0
667
+
668
+ for video_name in dataset.video_list:
669
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
670
+ sup_queue = torch.zeros(((unit_size, num_class - 1)))
671
+
672
+ duration = dataset.video_len[video_name]
673
+ video_time = float(dataset.video_dict[video_name]["duration"])
674
+ frame_to_time = 100.0 * video_time / duration
675
+
676
+ for idx in range(0, duration):
677
+ total_frames += 1
678
+ input_queue[:-1, :] = input_queue[1:, :].clone()
679
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
680
+
681
+ minput = input_queue.unsqueeze(0)
682
+ act_cls, act_reg, _ = model(minput.cuda())
683
+ act_cls = torch.softmax(act_cls, dim=-1)
684
+
685
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
686
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
687
+
688
+ proposal_anc_dict = []
689
+ for anc_idx in range(0, len(anchors)):
690
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
691
+
692
+ if len(cls) == 0:
693
+ continue
694
+
695
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
696
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
697
+ st = ed - length
698
+
699
+ for cidx in range(0, len(cls)):
700
+ label = cls[cidx]
701
+ tmp_dict = {}
702
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
703
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
704
+ tmp_dict["label"] = dataset.label_name[label]
705
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
706
+ proposal_anc_dict.append(tmp_dict)
707
+
708
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
709
+
710
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
711
+ sup_queue[-1, :] = 0
712
+ for proposal in proposal_anc_dict:
713
+ cls_idx = dataset.label_name.index(proposal['label'])
714
+ sup_queue[-1, cls_idx] = proposal["score"]
715
+
716
+ minput = sup_queue.unsqueeze(0)
717
+ suppress_conf = sup_model(minput.cuda())
718
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
719
+
720
+ for cls in range(0, num_class - 1):
721
+ if suppress_conf[cls] > opt['sup_threshold']:
722
+ for proposal in proposal_anc_dict:
723
+ if proposal['label'] == dataset.label_name[cls]:
724
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
725
+ proposal_dict.append(proposal)
726
+
727
+ result_dict[video_name] = proposal_dict
728
+ proposal_dict = []
729
+
730
+ end_time = time.time()
731
+ working_time = end_time - start_time
732
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
733
+
734
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
735
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
736
+ json.dump(output_dict, outfile, indent=2)
737
+ outfile.close()
738
+
739
+ mAP = evaluation_detection(opt)
740
+ return mAP
741
+
742
+ def main(opt, video_name=None):
743
+ max_perf = 0
744
+ if not video_name and 'video_name' in opt:
745
+ video_name = opt['video_name']
746
+
747
+ if opt['mode'] == 'train':
748
+ max_perf = train(opt)
749
+ if opt['mode'] == 'test':
750
+ max_perf = test(opt, video_name=video_name)
751
+ if opt['mode'] == 'test_frame':
752
+ max_perf = test_frame(opt, video_name=video_name)
753
+ if opt['mode'] == 'test_online':
754
+ max_perf = test_online(opt, video_name=video_name)
755
+ if opt['mode'] == 'eval':
756
+ max_perf = evaluation_detection(opt)
757
+
758
+ return max_perf
759
+
760
+ if __name__ == '__main__':
761
+ opt = opts.parse_opt()
762
+ opt = vars(opt)
763
+ if not os.path.exists(opt["checkpoint_path"]):
764
+ os.makedirs(opt["checkpoint_path"])
765
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
766
+ json.dump(opt, opt_file)
767
+ opt_file.close()
768
+
769
+ if opt['seed'] >= 0:
770
+ seed = opt['seed']
771
+ torch.manual_seed(seed)
772
+ np.random.seed(seed)
773
+
774
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
775
+
776
+ video_name = opt.get('video_name', None)
777
+ main(opt, video_name=video_name)
778
+ while(opt['wterm']):
779
+ pass
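# Hedged usage sketch (assumption, not part of the uploaded files): running
# single-video evaluation with visualization, using flags defined in opts_egtea.py.
# The video id and path below are hypothetical placeholders.
#
#   python main.py --mode test --video_name <video_id> \
#       --video_path ./data/videos/<video_id>.mp4 --exp 01 --split 1
#
# --video_name restricts VideoDataSet to one clip; --video_path enables the
# frame-strip visualization saved under ./output/visualizations.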
result image
opts_egtea.py ADDED
@@ -0,0 +1,62 @@
1
+ import argparse
2
+
3
+ def parse_opt():
4
+ parser = argparse.ArgumentParser()
5
+ # Overall settings
6
+ parser.add_argument('--mode', type=str, default='train')
7
+ parser.add_argument('--video_name', type=str, default=None, help='Name of the single video to evaluate')
8
+ parser.add_argument('--video_path', type=str, default='', help='Path to the input video file for visualization')
9
+ parser.add_argument('--checkpoint_path', type=str, default='./checkpoint')
10
+ parser.add_argument('--segment_size', type=int, default=64)
11
+ parser.add_argument('--anchors', type=str, default='2,4,6,8,12,16')
12
+ parser.add_argument('--seed', default=7, type=int, help='random seed for reproducibility')
13
+
14
+ # Overall Dataset settings
15
+ parser.add_argument('--num_of_class', type=int, default=23)
16
+ parser.add_argument('--data_format', type=str, default="npz_i3d")
17
+ parser.add_argument('--data_rescale', default=False, action='store_true')
18
+ parser.add_argument('--predefined_fps', default=None, type=float)
19
+ parser.add_argument('--rgb_only', default=False, action='store_true')
20
+ parser.add_argument('--video_anno', type=str, default="./data/egtea_annotations_split{}.json")
21
+ parser.add_argument('--video_feature_all_train', type=str, default="./data/I3D/")
22
+ parser.add_argument('--video_feature_all_test', type=str, default="./data/I3D/")
23
+ parser.add_argument('--setup', type=str, default="")
24
+ parser.add_argument('--exp', type=str, default="01")
25
+ parser.add_argument('--split', type=str, default="1")
26
+
27
+ # Network
28
+ parser.add_argument('--feat_dim', type=int, default=2048)
29
+ parser.add_argument('--hidden_dim', type=int, default=1024)
30
+ parser.add_argument('--out_dim', type=int, default=23)
31
+ parser.add_argument('--enc_layer', type=int, default=3)
32
+ parser.add_argument('--enc_head', type=int, default=8)
33
+ parser.add_argument('--dec_layer', type=int, default=5)
34
+ parser.add_argument('--dec_head', type=int, default=4)
35
+
36
+ # Training settings
37
+ parser.add_argument('--batch_size', type=int, default=128)
38
+ parser.add_argument('--lr', type=float, default=1e-4)
39
+ parser.add_argument('--weight_decay', type=float, default=1e-4)
40
+ parser.add_argument('--epoch', type=int, default=5)
41
+ parser.add_argument('--lr_step', type=int, default=3)
42
+
43
+ # Post processing
44
+ parser.add_argument('--alpha', type=float, default=1)
45
+ parser.add_argument('--beta', type=float, default=1)
46
+ parser.add_argument('--gamma', type=float, default=0.2)
47
+ parser.add_argument('--pptype', type=str, default="net")
48
+ parser.add_argument('--pos_threshold', type=float, default=0.5)
49
+ parser.add_argument('--sup_threshold', type=float, default=0.1)
50
+ parser.add_argument('--threshold', type=float, default=0.1)
51
+ parser.add_argument('--inference_subset', type=str, default="test")
52
+ parser.add_argument('--soft_nms', type=float, default=0.3)
53
+ parser.add_argument('--video_len_file', type=str, default="./output/video_len_{}.json")
54
+ parser.add_argument('--proposal_label_file', type=str, default="./output/proposal_label_{}.h5")
55
+ parser.add_argument('--suppress_label_file', type=str, default="./output/suppress_label_{}.h5")
56
+ parser.add_argument('--suppress_result_file', type=str, default="./output/suppress_result{}.h5")
57
+ parser.add_argument('--frame_result_file', type=str, default="./output/frame_result{}.h5")
58
+ parser.add_argument('--result_file', type=str, default="./output/result_proposal{}.json")
59
+ parser.add_argument('--wterm', type=bool, default=False)
60
+
61
+ args = parser.parse_args()
62
+ return args
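# Hedged sketch (assumption, not part of the uploaded files): building the options
# dict programmatically instead of from the command line, mirroring what main.py's
# __main__ block does. The video name below is a hypothetical placeholder.
import sys
sys.argv = ['main.py', '--mode', 'test', '--video_name', '<video_id>', '--split', '2']
opt = vars(parse_opt())
opt['anchors'] = [int(a) for a in opt['anchors'].split(',')]  # e.g. [2, 4, 6, 8, 12, 16]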
rgb bar main.py ADDED
@@ -0,0 +1,1144 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet, calc_iou
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.patches as patches
25
+ import cv2
26
+ from typing import List, Dict, Optional
27
+
28
+ from PIL import Image, ImageDraw, ImageFont
29
+ import warnings
30
+
31
+ # Visualization Configuration (Updated)
32
+ VIS_CONFIG = {
33
+ 'frame_interval': 1.0,
34
+ 'max_frames': 20,
35
+ 'save_dir': './output/visualizations',
36
+ 'video_save_dir': './output/videos',
37
+ 'gt_color': '#1f77b4', # Blue for ground truth (RGB: 31, 119, 180)
38
+ 'pred_color': '#ff7f0e', # Orange for predictions (RGB: 255, 127, 14)
39
+ 'fontsize_label': 10,
40
+ 'fontsize_title': 14,
41
+ 'frame_highlight_both': 'green',
42
+ 'frame_highlight_gt': 'red',
43
+ 'frame_highlight_pred': 'black',
44
+ 'iou_threshold': 0.3,
45
+ 'frame_scale_factor': 0.8,
46
+ 'video_text_scale': 0.5,
47
+ 'video_gt_text_color': (180, 119, 31), # BGR for OpenCV
48
+ 'video_pred_text_color': (14, 127, 255), # BGR for OpenCV
49
+ 'video_text_thickness': 1,
50
+ 'video_font_path': "./data/Poppins ExtraBold Italic 800.ttf",
51
+ 'video_font_fallback': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
52
+ 'video_pred_text_y': 0.45,
53
+ 'video_gt_text_y': 0.55,
54
+ 'video_footer_height': 150, # Increased to accommodate labels
55
+ 'video_gt_bar_y': 0.5,
56
+ 'video_pred_bar_y': 0.8,
57
+ 'video_bar_height': 0.15,
58
+ 'video_bar_text_scale': 0.7,
59
+ 'min_segment_duration': 1.0,
60
+ 'video_frame_text_y': 0.05, # Position for frame number and FPS
61
+ 'video_bar_label_x': 10, # X-position for GT/Pred labels
62
+ 'video_bar_label_scale': 0.5,
63
+ 'scroll_window_duration': 30.0, # Duration of the visible time window (seconds)
64
+ 'scroll_speed': 0.5, # Seconds to advance the window per second of video
65
+ }
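# Hedged helper (assumption, not part of the uploaded files): VIS_CONFIG stores the
# same colours twice, as matplotlib hex strings and as OpenCV BGR tuples. A small
# converter keeps the two in sync, e.g. hex_to_bgr('#1f77b4') == (180, 119, 31).
def hex_to_bgr(hex_color: str) -> tuple:
    hex_color = hex_color.lstrip('#')
    r, g, b = (int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    return (b, g, r)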
66
+
67
+
68
+ def annotate_video_with_actions(
69
+ video_id: str,
70
+ pred_segments: List[Dict],
71
+ gt_segments: List[Dict],
72
+ video_path: str,
73
+ save_dir: str = VIS_CONFIG['video_save_dir'],
74
+ text_scale: float = VIS_CONFIG['video_text_scale'] * 1.5, # Increased text size by 50%
75
+ gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
76
+ pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
77
+ text_thickness: int = VIS_CONFIG['video_text_thickness']
78
+ ) -> None:
79
+ """
80
+ Annotate a video with predicted and ground truth action labels, cumulative bars, frame number, and FPS.
81
+ Use fixed 20-second windows with original bar animation, resetting bars at each window boundary.
82
+ Different colors for different action classes, no labels or timestamps on bars, increased text size.
83
+ GT and Pred text labels are on the left, with bars starting 0.5 inches (48 pixels) to the right.
84
+
85
+ Args:
86
+ video_id: Video identifier (e.g., 'my_video').
87
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
88
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
89
+ video_path: Path to the input video file.
90
+ save_dir: Directory to save the annotated video.
91
+ text_scale: Scale factor for text size in video (increased).
92
+ gt_text_color: BGR color tuple for ground truth text.
93
+ pred_text_color: BGR color tuple for predicted text.
94
+ text_thickness: Thickness of text strokes.
95
+ """
96
+ os.makedirs(save_dir, exist_ok=True)
97
+
98
+ # Open input video
99
+ cap = cv2.VideoCapture(video_path)
100
+ if not cap.isOpened():
101
+ print(f"Error: Could not open video {video_path}. Skipping video annotation.")
102
+ return
103
+
104
+ # Get video properties
105
+ fps = cap.get(cv2.CAP_PROP_FPS)
106
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
107
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
108
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
109
+ duration = total_frames / fps
110
+ print(f"Input Video: FPS={fps:.2f}, Resolution={frame_width}x{frame_height}, Total Frames={total_frames}, Duration={duration:.2f}s")
111
+
112
+ # Define output video with extended height for footer
113
+ footer_height = VIS_CONFIG['video_footer_height']
114
+ output_height = frame_height + footer_height
115
+ output_path = os.path.join(save_dir, f"annotated_{video_id}_{opt['exp']}.avi")
116
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
117
+ out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
118
+
119
+ if not out.isOpened():
120
+ print(f"Error: Could not initialize video writer for {output_path}. Check codec availability.")
121
+ cap.release()
122
+ return
123
+
124
+ # Filter short segments
125
+ min_duration = VIS_CONFIG['min_segment_duration']
126
+ gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
127
+ pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]
128
+ print(f"Filtered Segments: GT={len(gt_segments)}, Pred={len(pred_segments)} (min_duration={min_duration}s)")
129
+
130
+ # Define color palette (BGR)
131
+ color_palette = [
132
+ (128, 0, 0), # Navy Blue
133
+ (60, 20, 220), # Crimson Red
134
+ (0, 128, 0), # Emerald Green
135
+ (128, 0, 128), # Royal Purple
136
+ (79, 69, 54), # Charcoal Gray
137
+ (128, 128, 0), # Teal
138
+ (0, 0, 128), # Maroon
139
+ (130, 0, 75), # Indigo
140
+ (34, 139, 34), # Forest Green
141
+ (0, 85, 204), # Burnt Orange
142
+ (149, 146, 209), # Dusty Rose
143
+ (235, 206, 135), # Sky Blue
144
+ (250, 230, 230), # Lavender
145
+ (191, 226, 159), # Seafoam Green
146
+ (185, 218, 255), # Peach
147
+ (255, 204, 204), # Periwinkle
148
+ (193, 182, 255), # Blush Pink
149
+ (201, 252, 189), # Mint Green
150
+ (144, 128, 112), # Slate Gray
151
+ (112, 25, 25), # Midnight Blue
152
+ (102, 51, 102), # Deep Plum
153
+ (0, 128, 128), # Olive Green
154
+ (171, 71, 0) # Cobalt Blue
155
+ ]
156
+
157
+ # Create color mapping for actions
158
+ action_labels = set(seg['label'] for seg in gt_segments).union(seg['label'] for seg in pred_segments)
159
+ action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
160
+ print(f"Action Color Mapping: {action_color_map}")
161
+
162
+ # Convert fallback colors to RGB for PIL
163
+ gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0]) # BGR to RGB
164
+ pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0]) # BGR to RGB
165
+
166
+ # Load font
167
+ font_path = VIS_CONFIG['video_font_path']
168
+ font_fallback = VIS_CONFIG['video_font_fallback']
169
+ font_size = int(20 * text_scale)
170
+ bar_font_size = int(20 * VIS_CONFIG['video_bar_text_scale'])
171
+ font = None
172
+ bar_font = None
173
+ if font_path:
174
+ try:
175
+ font = ImageFont.truetype(font_path, font_size)
176
+ bar_font = ImageFont.truetype(font_path, bar_font_size)
177
+ print(f"Using font: {font_path}")
178
+ except IOError:
179
+ print(f"Warning: Font {font_path} not found. Trying fallback font.")
180
+ if not font:
181
+ try:
182
+ font = ImageFont.truetype(font_fallback, font_size)
183
+ bar_font = ImageFont.truetype(font_fallback, bar_font_size)
184
+ print(f"Using fallback font: {font_fallback}")
185
+ except IOError:
186
+ print(f"Warning: Fallback font {font_fallback} not found. Using OpenCV default font.")
187
+ font = None
188
+ bar_font = None
189
+
190
+ # Fixed window configuration
191
+ window_size = 20.0 # 20-second windows
192
+ num_windows = int(np.ceil(duration / window_size))
193
+
194
+ # Define horizontal gap (0.5 inch = 48 pixels at 96 DPI)
195
+ text_bar_gap = 48 # Pixels
196
+ text_x = 10 # Fixed x-position for GT and Pred labels
197
+
198
+ frame_idx = 0
199
+ written_frames = 0
200
+ while cap.isOpened():
201
+ ret, frame = cap.read()
202
+ if not ret:
203
+ break
204
+
205
+ # Create extended frame with footer
206
+ extended_frame = np.zeros((output_height, frame_width, 3), dtype=np.uint8)
207
+ extended_frame[:frame_height, :, :] = frame
208
+ extended_frame[frame_height:, :, :] = 255 # White footer
209
+
210
+ # Calculate current timestamp
211
+ timestamp = frame_idx / fps
212
+
213
+ # Determine current window
214
+ window_idx = int(timestamp // window_size)
215
+ window_start = window_idx * window_size
216
+ window_end = min(window_start + window_size, duration)
217
+ window_duration = window_end - window_start
218
+ window_timestamp = timestamp - window_start # Relative timestamp within window
219
+
220
+ # Find active GT actions (for text overlay)
221
+ gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
222
+ gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else ""
223
+
224
+ # Find active predicted actions (for text overlay)
225
+ pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
226
+ pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else ""
227
+
228
+ # Draw GT and prediction bars in footer (within current window, using original animation)
229
+ footer_y = frame_height
230
+ gt_bar_y = footer_y + int(0.2 * footer_height) # GT bar position
231
+ pred_bar_y = footer_y + int(0.5 * footer_height) # Pred bar position
232
+ bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
233
+
234
+ # Calculate text width for GT and Pred labels to determine bar start
235
+ if font:
236
+ gt_text_bbox = bar_font.getbbox("GT")
237
+ pred_text_bbox = bar_font.getbbox("Pred")
238
+ gt_text_width = gt_text_bbox[2] - gt_text_bbox[0]
239
+ pred_text_width = pred_text_bbox[2] - pred_text_bbox[0]
240
+ else:
241
+ gt_text_size, _ = cv2.getTextSize("GT", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
242
+ pred_text_size, _ = cv2.getTextSize("Pred", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
243
+ gt_text_width = gt_text_size[0]
244
+ pred_text_width = pred_text_size[0]
245
+ max_text_width = max(gt_text_width, pred_text_width)
246
+ bar_start_x = text_x + max_text_width + text_bar_gap # Bars start after text + 0.5-inch gap
247
+ bar_width = frame_width - bar_start_x # Adjust bar width to fit remaining space
248
+
249
+ # Draw bars with action-specific colors
250
+ for seg in gt_segments:
251
+ if seg['start'] <= window_end and seg['end'] >= window_start:
252
+ start_t = max(seg['start'], window_start)
253
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
254
+ start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
255
+ end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
256
+ if end_x > start_x:
257
+ cv2.rectangle(
258
+ extended_frame,
259
+ (start_x, gt_bar_y),
260
+ (end_x, gt_bar_y + bar_height),
261
+ action_color_map[seg['label']], # Action-specific color
262
+ -1
263
+ )
264
+
265
+ for seg in pred_segments:
266
+ if seg['start'] <= window_end and seg['end'] >= window_start:
267
+ start_t = max(seg['start'], window_start)
268
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
269
+ start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
270
+ end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
271
+ if end_x > start_x:
272
+ cv2.rectangle(
273
+ extended_frame,
274
+ (start_x, pred_bar_y),
275
+ (end_x, pred_bar_y + bar_height),
276
+ action_color_map[seg['label']], # Action-specific color
277
+ -1
278
+ )
279
+
280
+ if font:
281
+ # Convert frame to PIL image
282
+ frame_rgb = cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB)
283
+ pil_image = Image.fromarray(frame_rgb)
284
+ draw = ImageDraw.Draw(pil_image)
285
+
286
+ # Draw frame number and FPS at top center
287
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
288
+ frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
289
+ frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
290
+ frame_text_x = (frame_width - frame_text_width) // 2
291
+ draw.text((frame_text_x, 10), frame_info, font=font, fill=(0, 0, 0))
292
+
293
+ # Draw window timestamp range at top of footer
294
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
295
+ window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
296
+ window_text_width = window_text_bbox[2] - window_text_bbox[0]
297
+ window_text_x = (frame_width - window_text_width) // 2
298
+ draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
299
+
300
+ # Draw GT text in video only if there are actions
301
+ if gt_text:
302
+ gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
303
+ draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
304
+
305
+ # Draw predicted text in video only if there are actions
306
+ if pred_text:
307
+ pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
308
+ draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
309
+
310
+ # Draw GT and Pred labels in footer
311
+ draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
312
+ draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
313
+
314
+ # Convert back to OpenCV frame
315
+ extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
316
+ else:
317
+ # Fallback to OpenCV font
318
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
319
+ text_size, _ = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_DUPLEX, text_scale, text_thickness)
320
+ frame_text_x = (frame_width - text_size[0]) // 2
321
+ cv2.putText(
322
+ extended_frame,
323
+ frame_info,
324
+ (frame_text_x, 30),
325
+ cv2.FONT_HERSHEY_DUPLEX,
326
+ text_scale,
327
+ (0, 0, 0),
328
+ text_thickness,
329
+ cv2.LINE_AA
330
+ )
331
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
332
+ window_text_size, _ = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
333
+ window_text_x = (frame_width - window_text_size[0]) // 2
334
+ cv2.putText(
335
+ extended_frame,
336
+ window_info,
337
+ (window_text_x, footer_y + 20),
338
+ cv2.FONT_HERSHEY_DUPLEX,
339
+ VIS_CONFIG['video_bar_text_scale'],
340
+ (0, 0, 0),
341
+ 1,
342
+ cv2.LINE_AA
343
+ )
344
+ if gt_text:
345
+ cv2.putText(
346
+ extended_frame,
347
+ gt_text,
348
+ (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
349
+ cv2.FONT_HERSHEY_DUPLEX,
350
+ text_scale,
351
+ gt_text_color,
352
+ text_thickness,
353
+ cv2.LINE_AA
354
+ )
355
+ if pred_text:
356
+ cv2.putText(
357
+ extended_frame,
358
+ pred_text,
359
+ (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
360
+ cv2.FONT_HERSHEY_DUPLEX,
361
+ text_scale,
362
+ pred_text_color,
363
+ text_thickness,
364
+ cv2.LINE_AA
365
+ )
366
+ cv2.putText(
367
+ extended_frame,
368
+ "GT",
369
+ (text_x, gt_bar_y + bar_height // 2 + 5),
370
+ cv2.FONT_HERSHEY_DUPLEX,
371
+ VIS_CONFIG['video_bar_text_scale'],
372
+ gt_text_color,
373
+ 1,
374
+ cv2.LINE_AA
375
+ )
376
+ cv2.putText(
377
+ extended_frame,
378
+ "Pred",
379
+ (text_x, pred_bar_y + bar_height // 2 + 5),
380
+ cv2.FONT_HERSHEY_DUPLEX,
381
+ VIS_CONFIG['video_bar_text_scale'],
382
+ pred_text_color,
383
+ 1,
384
+ cv2.LINE_AA
385
+ )
386
+
387
+ # Write frame to output video
388
+ out.write(extended_frame)
389
+ written_frames += 1
390
+ frame_idx += 1
391
+
392
+ # Release resources
393
+ cap.release()
394
+ out.release()
395
+ print(f"[✅ Saved Annotated Video]: {output_path}, Written Frames={written_frames}")
396
+ print("Note: If .avi is not playable, convert to .mp4 using FFmpeg:")
397
+ print(f"ffmpeg -i {output_path} -vcodec libx264 -acodec aac {output_path.replace('.avi', '.mp4')}")
398
+
399
+
400
+
401
+
402
+
403
+
404
+
405
+
406
+ def visualize_action_lengths(
407
+ video_id: str,
408
+ pred_segments: List[Dict],
409
+ gt_segments: List[Dict],
410
+ video_path: str,
411
+ duration: float,
412
+ save_dir: str = VIS_CONFIG['save_dir'],
413
+ frame_interval: float = VIS_CONFIG['frame_interval']
414
+ ) -> None:
415
+ """
416
+ Generate a visualization plot comparing ground truth and predicted action lengths with video frames.
417
+
418
+ Args:
419
+ video_id: Video identifier (e.g., 'my_video').
420
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
421
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
422
+ video_path: Path to the input video file.
423
+ duration: Total duration of the video in seconds.
424
+ save_dir: Directory to save the output image.
425
+ frame_interval: Time interval between sampled frames (seconds).
426
+ """
427
+ os.makedirs(save_dir, exist_ok=True)
428
+
429
+ # Calculate frame sampling times
430
+ num_frames = int(duration / frame_interval) + 1
431
+ if num_frames > VIS_CONFIG['max_frames']:
432
+ frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
433
+ num_frames = VIS_CONFIG['max_frames']
434
+ print(f"Warning: Video duration ({duration:.1f}s) requires {num_frames} frames. Adjusted frame_interval to {frame_interval:.2f}s.")
435
+
436
+ frame_times = np.linspace(0, duration, num_frames, endpoint=False)
437
+
438
+ # Load video frames
439
+ frames = []
440
+ cap = cv2.VideoCapture(video_path)
441
+ if not cap.isOpened():
442
+ print(f"Warning: Could not open video {video_path}. Using placeholder frames.")
443
+ frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
444
+ else:
445
+ for t in frame_times:
446
+ cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
447
+ ret, frame = cap.read()
448
+ if ret:
449
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
450
+ # Resize frame to reduce memory usage
451
+ frame = cv2.resize(frame, (int(frame.shape[1] * 0.5), int(frame.shape[0] * 0.5)))
452
+ frames.append(frame)
453
+ else:
454
+ frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
455
+ cap.release()
456
+
457
+ # Initialize figure
458
+ fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
459
+ gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
460
+
461
+ # Plot frames
462
+ for i, (t, frame) in enumerate(zip(frame_times, frames)):
463
+ ax = fig.add_subplot(gs[0, i])
464
+
465
+ # Check if frame falls within GT or predicted segments
466
+ gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
467
+ pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
468
+
469
+ # Set border color
470
+ border_color = None
471
+ if gt_hit and pred_hit:
472
+ border_color = VIS_CONFIG['frame_highlight_both']
473
+ elif gt_hit:
474
+ border_color = VIS_CONFIG['frame_highlight_gt']
475
+ elif pred_hit:
476
+ border_color = VIS_CONFIG['frame_highlight_pred']
477
+
478
+ ax.imshow(frame)
479
+ ax.axis('off')
480
+ if border_color:
481
+ for spine in ax.spines.values():
482
+ spine.set_edgecolor(border_color)
483
+ spine.set_linewidth(2)
484
+
485
+ ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
486
+ color=border_color if border_color else 'black')
487
+
488
+ # Plot ground truth bar
489
+ ax_gt = fig.add_subplot(gs[1, :])
490
+ ax_gt.set_xlim(0, duration)
491
+ ax_gt.set_ylim(0, 1)
492
+ ax_gt.axis('off')
493
+ ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'],
494
+ va='center', ha='right', weight='bold')
495
+
496
+ for seg in gt_segments:
497
+ start, end = seg['start'], seg['end']
498
+ width = end - start
499
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
500
+ ax_gt.add_patch(patches.Rectangle(
501
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'],
502
+ edgecolor='black', alpha=0.8
503
+ ))
504
+ ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center',
505
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
506
+ ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
507
+ ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')
508
+
509
+ # Plot prediction bar
510
+ ax_pred = fig.add_subplot(gs[2, :])
511
+ ax_pred.set_xlim(0, duration)
512
+ ax_pred.set_ylim(0, 1)
513
+ ax_pred.axis('off')
514
+ ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'],
515
+ va='center', ha='right', weight='bold')
516
+
517
+ for seg in pred_segments:
518
+ start, end = seg['start'], seg['end']
519
+ width = end - start
520
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
521
+ ax_pred.add_patch(patches.Rectangle(
522
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'],
523
+ edgecolor='black', alpha=0.8
524
+ ))
525
+ ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center',
526
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
527
+ ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
528
+ ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')
529
+
530
+ # Save plot
531
+ jpg_path = os.path.join(save_dir, f"viz_{video_id}_{opt['exp']}.png") # Use PNG
532
+ plt.savefig(jpg_path, dpi=100, bbox_inches='tight') # Lower DPI
533
+ print(f"[✅ Saved Visualization]: {jpg_path}")
534
+ plt.close()
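# Hedged usage sketch (assumption, not part of the uploaded files): calling the plot
# helper directly with toy segments. Labels, paths and times are illustrative only.
# Note that the save path uses the module-level opt['exp'], so the options must have
# been parsed (as in the __main__ block) before this runs.
visualize_action_lengths(
    video_id="demo_video",
    pred_segments=[{"label": "Take", "start": 1.0, "end": 3.5, "duration": 2.5, "score": 0.8}],
    gt_segments=[{"label": "Take", "start": 0.8, "end": 3.2, "duration": 2.4}],
    video_path="./data/videos/demo_video.mp4",
    duration=10.0,
)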
535
+
536
+
537
+
538
+ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
539
+ train_loader = torch.utils.data.DataLoader(train_dataset,
540
+ batch_size=opt['batch_size'], shuffle=True,
541
+ num_workers=0, pin_memory=True, drop_last=False)
542
+ epoch_cost = 0
543
+ epoch_cost_cls = 0
544
+ epoch_cost_reg = 0
545
+ epoch_cost_snip = 0
546
+
547
+ total_iter = len(train_dataset) // opt['batch_size']
548
+ cls_loss = MultiCrossEntropyLoss(focal=True)
549
+ snip_loss = MultiCrossEntropyLoss(focal=True)
550
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
551
+ if warmup:
552
+ for g in optimizer.param_groups:
553
+ g['lr'] = n_iter * (opt['lr']) / total_iter
554
+
555
+ act_cls, act_reg, snip_cls = model(input_data.float().cuda())
556
+
557
+ act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
558
+ snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))
559
+
560
+ cost_reg = 0
561
+ cost_cls = 0
562
+
563
+ loss = cls_loss_func_(cls_loss, cls_label, act_cls)
564
+ cost_cls = loss
565
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
566
+
567
+ loss = regress_loss_func(reg_label, act_reg)
568
+ cost_reg = loss
569
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
570
+
571
+ loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
572
+ cost_snip = loss
573
+ epoch_cost_snip += cost_snip.detach().cpu().numpy()
574
+
575
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
576
+ epoch_cost += cost.detach().cpu().numpy()
577
+
578
+ optimizer.zero_grad()
579
+ cost.backward()
580
+ optimizer.step()
581
+
582
+ return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
583
+
584
+ def eval_one_epoch(opt, model, test_dataset):
585
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, test_dataset)
586
+
587
+ result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
588
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
589
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
590
+ json.dump(output_dict, outfile, indent=2)
591
+ outfile.close()
592
+
593
+ IoUmAP = evaluation_detection(opt, verbose=False)
594
+ IoUmAP_5 = sum(IoUmAP[0:]) / len(IoUmAP[0:])
595
+
596
+ return cls_loss, reg_loss, tot_loss, IoUmAP_5
597
+
598
+ def train(opt):
599
+ writer = SummaryWriter()
600
+ model = MYNET(opt).cuda()
601
+
602
+ rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
603
+ optimizer = optim.Adam([{'params': model.history_unit.parameters(), 'lr': 1e-6}, {'params': rest_of_model_params}], lr=opt["lr"], weight_decay=opt["weight_decay"])
604
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["lr_step"])
605
+
606
+ train_dataset = VideoDataSet(opt, subset="train")
607
+ test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
608
+
609
+ warmup = False
610
+
611
+ for n_epoch in range(opt['epoch']):
612
+ if n_epoch >= 1:
613
+ warmup = False
614
+
615
+ n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
616
+
617
+ writer.add_scalars('data/cost', {'train': epoch_cost / (n_iter + 1)}, n_epoch)
618
+ print("training loss(epoch %d): %.03f, cls - %f, reg - %f, snip - %f, lr - %f" % (n_epoch,
619
+ epoch_cost / (n_iter + 1),
620
+ epoch_cost_cls / (n_iter + 1),
621
+ epoch_cost_reg / (n_iter + 1),
622
+ epoch_cost_snip / (n_iter + 1),
623
+ optimizer.param_groups[-1]["lr"]))
624
+
625
+ scheduler.step()
626
+ model.eval()
627
+
628
+ cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
629
+
630
+ writer.add_scalars('data/mAP', {'test': IoUmAP_5}, n_epoch)
631
+ print("testing loss(epoch %d): %.03f, cls - %f, reg - %f, mAP Avg - %f" % (n_epoch, tot_loss, cls_loss, reg_loss, IoUmAP_5))
632
+
633
+ state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
634
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_checkpoint_" + str(n_epoch + 1) + ".pth.tar")
635
+ if IoUmAP_5 > model.best_map:
636
+ model.best_map = IoUmAP_5
637
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_ckp_best.pth.tar")
638
+
639
+ model.train()
640
+
641
+ writer.close()
642
+ return model.best_map
643
+
644
+ def eval_frame(opt, model, dataset):
645
+ test_loader = torch.utils.data.DataLoader(dataset,
646
+ batch_size=opt['batch_size'], shuffle=False,
647
+ num_workers=0, pin_memory=True, drop_last=False)
648
+
649
+ labels_cls = {}
650
+ labels_reg = {}
651
+ output_cls = {}
652
+ output_reg = {}
653
+ for video_name in dataset.video_list:
654
+ labels_cls[video_name] = []
655
+ labels_reg[video_name] = []
656
+ output_cls[video_name] = []
657
+ output_reg[video_name] = []
658
+
659
+ start_time = time.time()
660
+ total_frames = 0
661
+ epoch_cost = 0
662
+ epoch_cost_cls = 0
663
+ epoch_cost_reg = 0
664
+
665
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
666
+ act_cls, act_reg, _ = model(input_data.float().cuda())
667
+ cost_reg = 0
668
+ cost_cls = 0
669
+
670
+ loss = cls_loss_func(cls_label, act_cls)
671
+ cost_cls = loss
672
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
673
+
674
+ loss = regress_loss_func(reg_label, act_reg)
675
+ cost_reg = loss
676
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
677
+
678
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
679
+ epoch_cost += cost.detach().cpu().numpy()
680
+
681
+ act_cls = torch.softmax(act_cls, dim=-1)
682
+
683
+ total_frames += input_data.size(0)
684
+
685
+ for b in range(0, input_data.size(0)):
686
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
687
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
688
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
689
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
690
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
691
+
692
+ end_time = time.time()
693
+ working_time = end_time - start_time
694
+
695
+ for video_name in dataset.video_list:
696
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
697
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
698
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
699
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
700
+
701
+ cls_loss = epoch_cost_cls / (n_iter + 1)  # n_iter is the last batch index, so average over n_iter + 1 batches
702
+ reg_loss = epoch_cost_reg / (n_iter + 1)
703
+ tot_loss = epoch_cost / (n_iter + 1)
704
+
705
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
706
+
707
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
708
+ result_dict = {}
709
+ proposal_dict = []
710
+
711
+ num_class = opt["num_of_class"]
712
+ unit_size = opt['segment_size']
713
+ threshold = opt['threshold']
714
+ anchors = opt['anchors']
715
+
716
+ for video_name in dataset.video_list:
717
+ duration = dataset.video_len[video_name]
718
+ video_time = float(dataset.video_dict[video_name]["duration"])
719
+ frame_to_time = 100.0 * video_time / duration
720
+
721
+ for idx in range(0, duration):
722
+ cls_anc = output_cls[video_name][idx]
723
+ reg_anc = output_reg[video_name][idx]
724
+
725
+ proposal_anc_dict = []
726
+ for anc_idx in range(0, len(anchors)):
727
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
728
+
729
+ if len(cls) == 0:
730
+ continue
731
+
732
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
733
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
734
+ st = ed - length
735
+
736
+ for cidx in range(0, len(cls)):
737
+ label = cls[cidx]
738
+ tmp_dict = {}
739
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
740
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
741
+ tmp_dict["label"] = dataset.label_name[label]
742
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
743
+ proposal_anc_dict.append(tmp_dict)
744
+
745
+ proposal_dict += proposal_anc_dict
746
+
747
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
748
+ result_dict[video_name] = proposal_dict
749
+ proposal_dict = []
750
+
751
+ return result_dict
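# Hedged helper (assumption, not part of the uploaded files): the anchor decoding used
# above (and again in eval_map_supnet / test_online), factored out for clarity. Each
# anchor of size `anchor` at frame idx predicts an end offset and a log length scale.
def decode_anchor(idx, anchor, reg):
    ed = idx + anchor * reg[0]           # predicted end, in frames
    length = anchor * np.exp(reg[1])     # predicted length, in frames
    st = ed - length                     # predicted start
    return st, ed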
752
+
753
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
754
+ model = SuppressNet(opt).cuda()
755
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
756
+ base_dict = checkpoint['state_dict']
757
+ model.load_state_dict(base_dict)
758
+ model.eval()
759
+
760
+ result_dict = {}
761
+ proposal_dict = []
762
+
763
+ num_class = opt["num_of_class"]
764
+ unit_size = opt['segment_size']
765
+ threshold = opt['threshold']
766
+ anchors = opt['anchors']
767
+
768
+ for video_name in dataset.video_list:
769
+ duration = dataset.video_len[video_name]
770
+ video_time = float(dataset.video_dict[video_name]["duration"])
771
+ frame_to_time = 100.0 * video_time / duration
772
+ conf_queue = torch.zeros((unit_size, num_class - 1))
773
+
774
+ for idx in range(0, duration):
775
+ cls_anc = output_cls[video_name][idx]
776
+ reg_anc = output_reg[video_name][idx]
777
+
778
+ proposal_anc_dict = []
779
+ for anc_idx in range(0, len(anchors)):
780
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
781
+
782
+ if len(cls) == 0:
783
+ continue
784
+
785
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
786
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
787
+ st = ed - length
788
+
789
+ for cidx in range(0, len(cls)):
790
+ label = cls[cidx]
791
+ tmp_dict = {}
792
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
793
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
794
+ tmp_dict["label"] = dataset.label_name[label]
795
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
796
+ proposal_anc_dict.append(tmp_dict)
797
+
798
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
799
+
800
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
801
+ conf_queue[-1, :] = 0
802
+ for proposal in proposal_anc_dict:
803
+ cls_idx = dataset.label_name.index(proposal['label'])
804
+ conf_queue[-1, cls_idx] = proposal["score"]
805
+
806
+ minput = conf_queue.unsqueeze(0)
807
+ suppress_conf = model(minput.cuda())
808
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
809
+
810
+ for cls in range(0, num_class - 1):
811
+ if suppress_conf[cls] > opt['sup_threshold']:
812
+ for proposal in proposal_anc_dict:
813
+ if proposal['label'] == dataset.label_name[cls]:
814
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
815
+ proposal_dict.append(proposal)
816
+
817
+ result_dict[video_name] = proposal_dict
818
+ proposal_dict = []
819
+
820
+ return result_dict
821
+
822
+ def test_frame(opt, video_name=None):
823
+ model = MYNET(opt).cuda()
824
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
825
+ base_dict = checkpoint['state_dict']
826
+ model.load_state_dict(base_dict)
827
+ model.eval()
828
+
829
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
830
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
831
+
832
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
833
+
834
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
835
+
836
+ for video_name in dataset.video_list:
837
+ o_cls = output_cls[video_name]
838
+ o_reg = output_reg[video_name]
839
+ l_cls = labels_cls[video_name]
840
+ l_reg = labels_reg[video_name]
841
+
842
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
843
+ dset_predcls[:, :] = o_cls[:, :]
844
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
845
+ dset_predreg[:, :] = o_reg[:, :]
846
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
847
+ dset_labelcls[:, :] = l_cls[:, :]
848
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
849
+ dset_labelreg[:, :] = l_reg[:, :]
850
+ outfile.close()
851
+
852
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
853
+ return cls_loss, reg_loss, tot_loss
854
+
855
+ def patch_attention(m):
856
+ forward_orig = m.forward
857
+
858
+ def wrap(*args, **kwargs):
859
+ kwargs["need_weights"] = True
860
+ kwargs["average_attn_weights"] = False
861
+ return forward_orig(*args, **kwargs)
862
+
863
+ m.forward = wrap
864
+
865
+ class SaveOutput:
866
+ def __init__(self):
867
+ self.outputs = []
868
+
869
+ def __call__(self, module, module_in, module_out):
870
+ self.outputs.append(module_out[1])
871
+
872
+ def clear(self):
873
+ self.outputs = []
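# Hedged sketch (assumption, not part of the uploaded files): how patch_attention and
# SaveOutput could be wired together to capture per-head attention maps from any
# nn.MultiheadAttention layers inside MYNET. `model` is assumed to be a constructed MYNET.
save_output = SaveOutput()
hook_handles = []
for module in model.modules():
    if isinstance(module, torch.nn.MultiheadAttention):
        patch_attention(module)  # forces need_weights=True, average_attn_weights=False
        hook_handles.append(module.register_forward_hook(save_output))
# After a forward pass, save_output.outputs holds one attention-weight tensor per layer.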
874
+
875
+ def test(opt, video_name=None):
876
+ model = MYNET(opt).cuda()
877
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
878
+ base_dict = checkpoint['state_dict']
879
+ model.load_state_dict(base_dict)
880
+ model.eval()
881
+
882
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
883
+
884
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
885
+
886
+ if opt["pptype"] == "nms":
887
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
888
+ if opt["pptype"] == "net":
889
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
890
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
891
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
892
+ json.dump(output_dict, outfile, indent=2)
893
+ outfile.close()
894
+
895
+ mAP = evaluation_detection(opt)
896
+
897
+ # Compare predicted and ground truth action lengths
898
+ if video_name:
899
+ print("\nComparing Predicted and Ground Truth Action Lengths for Video:", video_name)
900
+ with open(opt["video_anno"].format(opt["split"]), 'r') as f:
901
+ anno_data = json.load(f)
902
+ gt_annotations = anno_data['database'][video_name]['annotations']
903
+ duration = anno_data['database'][video_name]['duration']
904
+
905
+ gt_segments = []
906
+ for anno in gt_annotations:
907
+ start, end = anno['segment']
908
+ label = anno['label']
909
+ duration_seg = end - start
910
+ gt_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg})
911
+
912
+ pred_segments = []
913
+ for pred in result_dict[video_name]:
914
+ start, end = pred['segment']
915
+ label = pred['label']
916
+ score = pred['score']
917
+ duration_seg = end - start
918
+ pred_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg, 'score': score})
919
+
920
+ # Print comparison table
921
+ matches = []
922
+ iou_threshold = VIS_CONFIG['iou_threshold']
923
+ used_gt_indices = set()
924
+ for pred in pred_segments:
925
+ best_iou = 0
926
+ best_gt_idx = None
927
+ for gt_idx, gt in enumerate(gt_segments):
928
+ if gt_idx in used_gt_indices:
929
+ continue
930
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
931
+ if iou > best_iou and iou >= iou_threshold:
932
+ best_iou = iou
933
+ best_gt_idx = gt_idx
934
+ if best_gt_idx is not None:
935
+ matches.append({
936
+ 'pred': pred,
937
+ 'gt': gt_segments[best_gt_idx],
938
+ 'iou': best_iou
939
+ })
940
+ used_gt_indices.add(best_gt_idx)
941
+ else:
942
+ matches.append({'pred': pred, 'gt': None, 'iou': 0})
943
+
944
+ for gt_idx, gt in enumerate(gt_segments):
945
+ if gt_idx not in used_gt_indices:
946
+ matches.append({'pred': None, 'gt': gt, 'iou': 0})
947
+
948
+ print("\n{:<20} {:<30} {:<30} {:<15} {:<10}".format(
949
+ "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"))
950
+ print("-" * 105)
951
+ for match in matches:
952
+ pred = match['pred']
953
+ gt = match['gt']
954
+ iou = match['iou']
955
+ if pred and gt:
956
+ label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
957
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
958
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
959
+ duration_diff = pred['duration'] - gt['duration']
960
+ print("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}".format(
961
+ label, pred_str, gt_str, duration_diff, iou))
962
+ elif pred:
963
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
964
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
965
+ pred['label'], pred_str, "None", "N/A", iou))
966
+ elif gt:
967
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
968
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
969
+ gt['label'], "None", gt_str, "N/A", iou))
970
+
971
+ # Summarize
972
+ matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
973
+ avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
974
+ avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
975
+ print(f"\nSummary:")
976
+ print(f"- Total Predictions: {len(pred_segments)}")
977
+ print(f"- Total Ground Truth: {len(gt_segments)}")
978
+ print(f"- Matched Segments: {matched_count}")
979
+ print(f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s")
980
+ print(f"- Average IoU (Matched): {avg_iou:.2f}")
981
+
982
+ # Generate static visualization
983
+ video_path = opt.get('video_path', '')
984
+ if os.path.exists(video_path):
985
+ visualize_action_lengths(
986
+ video_id=video_name,
987
+ pred_segments=pred_segments,
988
+ gt_segments=gt_segments,
989
+ video_path=video_path,
990
+ duration=duration
991
+ )
992
+ # Generate annotated video
993
+ annotate_video_with_actions(
994
+ video_id=video_name,
995
+ pred_segments=pred_segments,
996
+ gt_segments=gt_segments,
997
+ video_path=video_path
998
+ )
999
+ else:
1000
+ print(f"Warning: Video path {video_path} not found. Skipping visualization and video annotation.")
1001
+
1002
+ return mAP
1003
+
1004
+ def test_online(opt, video_name=None):
1005
+ model = MYNET(opt).cuda()
1006
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
1007
+ base_dict = checkpoint['state_dict']
1008
+ model.load_state_dict(base_dict)
1009
+ model.eval()
1010
+
1011
+ sup_model = SuppressNet(opt).cuda()
1012
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
1013
+ base_dict = checkpoint['state_dict']
1014
+ sup_model.load_state_dict(base_dict)
1015
+ sup_model.eval()
1016
+
1017
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
1018
+ test_loader = torch.utils.data.DataLoader(dataset,
1019
+ batch_size=1, shuffle=False,
1020
+ num_workers=0, pin_memory=True, drop_last=False)
1021
+
1022
+ result_dict = {}
1023
+ proposal_dict = []
1024
+
1025
+ num_class = opt["num_of_class"]
1026
+ unit_size = opt['segment_size']
1027
+ threshold = opt['threshold']
1028
+ anchors = opt['anchors']
1029
+
1030
+ start_time = time.time()
1031
+ total_frames = 0
1032
+
1033
+ for video_name in dataset.video_list:
1034
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
1035
+ sup_queue = torch.zeros((unit_size, num_class - 1))
1036
+
1037
+ duration = dataset.video_len[video_name]
1038
+ video_time = float(dataset.video_dict[video_name]["duration"])
1039
+ frame_to_time = 100.0 * video_time / duration
1040
+
1041
+ for idx in range(0, duration):
1042
+ total_frames += 1
1043
+ input_queue[:-1, :] = input_queue[1:, :].clone()
1044
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
1045
+
1046
+ minput = input_queue.unsqueeze(0)
1047
+ act_cls, act_reg, _ = model(minput.cuda())
1048
+ act_cls = torch.softmax(act_cls, dim=-1)
1049
+
1050
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
1051
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
1052
+
1053
+ proposal_anc_dict = []
1054
+ for anc_idx in range(0, len(anchors)):
1055
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
1056
+
1057
+ if len(cls) == 0:
1058
+ continue
1059
+
1060
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
1061
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
1062
+ st = ed - length
1063
+
1064
+ for cidx in range(0, len(cls)):
1065
+ label = cls[cidx]
1066
+ tmp_dict = {}
1067
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
1068
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
1069
+ tmp_dict["label"] = dataset.label_name[label]
1070
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
1071
+ proposal_anc_dict.append(tmp_dict)
1072
+
1073
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
1074
+
1075
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
1076
+ sup_queue[-1, :] = 0
1077
+ for proposal in proposal_anc_dict:
1078
+ cls_idx = dataset.label_name.index(proposal['label'])
1079
+ sup_queue[-1, cls_idx] = proposal["score"]
1080
+
1081
+ minput = sup_queue.unsqueeze(0)
1082
+ suppress_conf = sup_model(minput.cuda())
1083
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
1084
+
1085
+ for cls in range(0, num_class - 1):
1086
+ if suppress_conf[cls] > opt['sup_threshold']:
1087
+ for proposal in proposal_anc_dict:
1088
+ if proposal['label'] == dataset.label_name[cls]:
1089
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
1090
+ proposal_dict.append(proposal)
1091
+
1092
+ result_dict[video_name] = proposal_dict
1093
+ proposal_dict = []
1094
+
1095
+ end_time = time.time()
1096
+ working_time = end_time - start_time
1097
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
1098
+
1099
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
1100
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
1101
+ json.dump(output_dict, outfile, indent=2)
1102
+ outfile.close()
1103
+
1104
+ mAP = evaluation_detection(opt)
1105
+ return mAP
1106
+
1107
+ def main(opt, video_name=None):
1108
+ max_perf = 0
1109
+ if not video_name and 'video_name' in opt:
1110
+ video_name = opt['video_name']
1111
+
1112
+ if opt['mode'] == 'train':
1113
+ max_perf = train(opt)
1114
+ if opt['mode'] == 'test':
1115
+ max_perf = test(opt, video_name=video_name)
1116
+ if opt['mode'] == 'test_frame':
1117
+ max_perf = test_frame(opt, video_name=video_name)
1118
+ if opt['mode'] == 'test_online':
1119
+ max_perf = test_online(opt, video_name=video_name)
1120
+ if opt['mode'] == 'eval':
1121
+ max_perf = evaluation_detection(opt)
1122
+
1123
+ return max_perf
1124
+
1125
+ if __name__ == '__main__':
1126
+ opt = opts.parse_opt()
1127
+ opt = vars(opt)
1128
+ if not os.path.exists(opt["checkpoint_path"]):
1129
+ os.makedirs(opt["checkpoint_path"])
1130
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
1131
+ json.dump(opt, opt_file)
1132
+ opt_file.close()
1133
+
1134
+ if opt['seed'] >= 0:
1135
+ seed = opt['seed']
1136
+ torch.manual_seed(seed)
1137
+ np.random.seed(seed)
1138
+
1139
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
1140
+
1141
+ video_name = opt.get('video_name', None)
1142
+ main(opt, video_name=video_name)
1143
+ while(opt['wterm']):
1144
+ pass
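# Hedged illustration (assumption, not part of the uploaded files): shape of the JSON
# written to opt['result_file'], in the ActivityNet-style "VERSION 1.3" layout that
# evaluation_detection consumes. The video id, label and numbers are made up.
example_output = {
    "version": "VERSION 1.3",
    "external_data": {},
    "results": {
        "example_video_id": [
            {"segment": [12.40, 15.85], "score": 0.91, "label": "Cut", "gentime": 15.90},
        ],
    },
}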
short main.py ADDED
@@ -0,0 +1,1040 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet, calc_iou
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.patches as patches
25
+ import cv2
26
+ from typing import List, Dict, Optional
27
+
28
+ from PIL import Image, ImageDraw, ImageFont
29
+ import warnings
30
+
31
+ # Visualization Configuration (Updated)
32
+ VIS_CONFIG = {
33
+ 'frame_interval': 1.0,
34
+ 'max_frames': 20,
35
+ 'save_dir': './output/visualizations',
36
+ 'video_save_dir': './output/videos',
37
+ 'gt_color': '#1f77b4', # Blue for ground truth (RGB: 31, 119, 180)
38
+ 'pred_color': '#ff7f0e', # Orange for predictions (RGB: 255, 127, 14)
39
+ 'fontsize_label': 10,
40
+ 'fontsize_title': 14,
41
+ 'frame_highlight_both': 'green',
42
+ 'frame_highlight_gt': 'red',
43
+ 'frame_highlight_pred': 'black',
44
+ 'iou_threshold': 0.3,
45
+ 'frame_scale_factor': 0.8,
46
+ 'video_text_scale': 0.5,
47
+ 'video_gt_text_color': (180, 119, 31), # BGR for OpenCV
48
+ 'video_pred_text_color': (14, 127, 255), # BGR for OpenCV
49
+ 'video_text_thickness': 1,
50
+ 'video_font_path': "./data/Poppins ExtraBold Italic 800.ttf",
51
+ 'video_font_fallback': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
52
+ 'video_pred_text_y': 0.45,
53
+ 'video_gt_text_y': 0.55,
54
+ 'video_footer_height': 150, # Increased to accommodate labels
55
+ 'video_gt_bar_y': 0.5,
56
+ 'video_pred_bar_y': 0.8,
57
+ 'video_bar_height': 0.15,
58
+ 'video_bar_text_scale': 0.7,
59
+ 'min_segment_duration': 1.0,
60
+ 'video_frame_text_y': 0.05, # Position for frame number and FPS
61
+ 'video_bar_label_x': 10, # X-position for GT/Pred labels
62
+ 'video_bar_label_scale': 0.5,
63
+ 'scroll_window_duration': 30.0, # Duration of the visible time window (seconds)
64
+ 'scroll_speed': 0.5, # Seconds to advance the window per second of video
65
+ }
66
+
67
+
68
+ def annotate_video_with_actions(
69
+ video_id: str,
70
+ pred_segments: List[Dict],
71
+ gt_segments: List[Dict],
72
+ video_path: str,
73
+ save_dir: str = VIS_CONFIG['video_save_dir'],
74
+ text_scale: float = VIS_CONFIG['video_text_scale'] * 1.5, # Increased text size by 50%
75
+ gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
76
+ pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
77
+ text_thickness: int = VIS_CONFIG['video_text_thickness']
78
+ ) -> None:
79
+ """
80
+ Annotate a video with predicted and ground truth action labels, cumulative bars, frame number, and FPS.
81
+ Uses fixed 20-second timeline windows whose bars fill in as the video plays and reset at each window boundary.
82
+ Each action class gets its own bar color; the bars themselves carry no labels or timestamps, and the overlay text is enlarged.
83
+ The "GT" and "Pred" row labels sit on the left, with the bars starting 0.5 inches (48 pixels) to their right.
84
+
85
+ Args:
86
+ video_id: Video identifier (e.g., 'my_video').
87
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
88
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
89
+ video_path: Path to the input video file.
90
+ save_dir: Directory to save the annotated video.
91
+ text_scale: Scale factor for text size in video (increased).
92
+ gt_text_color: BGR color tuple for ground truth text.
93
+ pred_text_color: BGR color tuple for predicted text.
94
+ text_thickness: Thickness of text strokes.
95
+ """
96
+ os.makedirs(save_dir, exist_ok=True)
97
+
98
+ # Open input video
99
+ cap = cv2.VideoCapture(video_path)
100
+ if not cap.isOpened():
101
+ print(f"Error: Could not open video {video_path}. Skipping video annotation.")
102
+ return
103
+
104
+ # Get video properties
105
+ fps = cap.get(cv2.CAP_PROP_FPS)
106
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
107
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
108
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
109
+ duration = total_frames / fps
110
+ print(f"Input Video: FPS={fps:.2f}, Resolution={frame_width}x{frame_height}, Total Frames={total_frames}, Duration={duration:.2f}s")
111
+
112
+ # Define output video with extended height for footer
113
+ footer_height = VIS_CONFIG['video_footer_height']
114
+ output_height = frame_height + footer_height
115
+ output_path = os.path.join(save_dir, f"annotated_{video_id}_{opt['exp']}.avi")
116
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
117
+ out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
118
+
119
+ if not out.isOpened():
120
+ print(f"Error: Could not initialize video writer for {output_path}. Check codec availability.")
121
+ cap.release()
122
+ return
123
+
124
+ # Filter short segments
125
+ min_duration = VIS_CONFIG['min_segment_duration']
126
+ gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
127
+ pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]
128
+ print(f"Filtered Segments: GT={len(gt_segments)}, Pred={len(pred_segments)} (min_duration={min_duration}s)")
129
+
130
+ # Define color palette (BGR)
131
+ color_palette = [
132
+ (128, 0, 0), # Navy Blue
133
+ (60, 20, 220), # Crimson Red
134
+ (0, 128, 0), # Emerald Green
135
+ (128, 0, 128), # Royal Purple
136
+ (79, 69, 54), # Charcoal Gray
137
+ (128, 128, 0), # Teal
138
+ (0, 0, 128), # Maroon
139
+ (130, 0, 75), # Indigo
140
+ (34, 139, 34), # Forest Green
141
+ (0, 85, 204), # Burnt Orange
142
+ (149, 146, 209), # Dusty Rose
143
+ (235, 206, 135), # Sky Blue
144
+ (250, 230, 230), # Lavender
145
+ (191, 226, 159), # Seafoam Green
146
+ (185, 218, 255), # Peach
147
+ (255, 204, 204), # Periwinkle
148
+ (193, 182, 255), # Blush Pink
149
+ (201, 252, 189), # Mint Green
150
+ (144, 128, 112), # Slate Gray
151
+ (112, 25, 25), # Midnight Blue
152
+ (102, 51, 102), # Deep Plum
153
+ (0, 128, 128), # Olive Green
154
+ (171, 71, 0) # Cobalt Blue
155
+ ]
156
+
157
+ # Create color mapping for actions
158
+ action_labels = set(seg['label'] for seg in gt_segments).union(seg['label'] for seg in pred_segments)
159
+ action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
160
+ print(f"Action Color Mapping: {action_color_map}")
161
+
162
+ # Convert fallback colors to RGB for PIL
163
+ gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0]) # BGR to RGB
164
+ pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0]) # BGR to RGB
165
+
166
+ # Load font
167
+ font_path = VIS_CONFIG['video_font_path']
168
+ font_fallback = VIS_CONFIG['video_font_fallback']
169
+ font_size = int(20 * text_scale)
170
+ bar_font_size = int(20 * VIS_CONFIG['video_bar_text_scale'])
171
+ font = None
172
+ bar_font = None
173
+ if font_path:
174
+ try:
175
+ font = ImageFont.truetype(font_path, font_size)
176
+ bar_font = ImageFont.truetype(font_path, bar_font_size)
177
+ print(f"Using font: {font_path}")
178
+ except IOError:
179
+ print(f"Warning: Font {font_path} not found. Trying fallback font.")
180
+ if not font:
181
+ try:
182
+ font = ImageFont.truetype(font_fallback, font_size)
183
+ bar_font = ImageFont.truetype(font_fallback, bar_font_size)
184
+ print(f"Using fallback font: {font_fallback}")
185
+ except IOError:
186
+ print(f"Warning: Fallback font {font_fallback} not found. Using OpenCV default font.")
187
+ font = None
188
+ bar_font = None
189
+
190
+ # Fixed window configuration
191
+ window_size = 20.0 # 20-second windows
192
+ num_windows = int(np.ceil(duration / window_size))
193
+
194
+ # Define horizontal gap (0.5 inch = 48 pixels at 96 DPI)
195
+ text_bar_gap = 48 # Pixels
196
+ text_x = 10 # Fixed x-position for GT and Pred labels
197
+
198
+ frame_idx = 0
199
+ written_frames = 0
200
+ while cap.isOpened():
201
+ ret, frame = cap.read()
202
+ if not ret:
203
+ break
204
+
205
+ # Create extended frame with footer
206
+ extended_frame = np.zeros((output_height, frame_width, 3), dtype=np.uint8)
207
+ extended_frame[:frame_height, :, :] = frame
208
+ extended_frame[frame_height:, :, :] = 255 # White footer
209
+
210
+ # Calculate current timestamp
211
+ timestamp = frame_idx / fps
212
+
213
+ # Determine current window
214
+ window_idx = int(timestamp // window_size)
215
+ window_start = window_idx * window_size
216
+ window_end = min(window_start + window_size, duration)
217
+ window_duration = window_end - window_start
218
+ window_timestamp = timestamp - window_start # Relative timestamp within window
219
+
220
+ # Find active GT actions (for text overlay)
221
+ gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
222
+ gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else ""
223
+
224
+ # Find active predicted actions (for text overlay)
225
+ pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
226
+ pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else ""
227
+
228
+ # Draw GT and prediction bars in footer (within current window, using original animation)
229
+ footer_y = frame_height
230
+ gt_bar_y = footer_y + int(0.2 * footer_height) # GT bar position
231
+ pred_bar_y = footer_y + int(0.5 * footer_height) # Pred bar position
232
+ bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
233
+
234
+ # Calculate text width for GT and Pred labels to determine bar start
235
+ if font:
236
+ gt_text_bbox = bar_font.getbbox("GT")
237
+ pred_text_bbox = bar_font.getbbox("Pred")
238
+ gt_text_width = gt_text_bbox[2] - gt_text_bbox[0]
239
+ pred_text_width = pred_text_bbox[2] - pred_text_bbox[0]
240
+ else:
241
+ gt_text_size, _ = cv2.getTextSize("GT", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
242
+ pred_text_size, _ = cv2.getTextSize("Pred", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
243
+ gt_text_width = gt_text_size[0]
244
+ pred_text_width = pred_text_size[0]
245
+ max_text_width = max(gt_text_width, pred_text_width)
246
+ bar_start_x = text_x + max_text_width + text_bar_gap # Bars start after text + 0.5-inch gap
247
+ bar_width = frame_width - bar_start_x # Adjust bar width to fit remaining space
248
+
249
+ # Draw bars with action-specific colors
250
+ for seg in gt_segments:
251
+ if seg['start'] <= window_end and seg['end'] >= window_start:
252
+ start_t = max(seg['start'], window_start)
253
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
254
+ start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
255
+ end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
256
+ if end_x > start_x:
257
+ cv2.rectangle(
258
+ extended_frame,
259
+ (start_x, gt_bar_y),
260
+ (end_x, gt_bar_y + bar_height),
261
+ action_color_map[seg['label']], # Action-specific color
262
+ -1
263
+ )
264
+
265
+ for seg in pred_segments:
266
+ if seg['start'] <= window_end and seg['end'] >= window_start:
267
+ start_t = max(seg['start'], window_start)
268
+ end_t = min(seg['end'], window_start + window_timestamp) # Original animation
269
+ start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
270
+ end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
271
+ if end_x > start_x:
272
+ cv2.rectangle(
273
+ extended_frame,
274
+ (start_x, pred_bar_y),
275
+ (end_x, pred_bar_y + bar_height),
276
+ action_color_map[seg['label']], # Action-specific color
277
+ -1
278
+ )
279
+
280
+ if font:
281
+ # Convert frame to PIL image
282
+ frame_rgb = cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB)
283
+ pil_image = Image.fromarray(frame_rgb)
284
+ draw = ImageDraw.Draw(pil_image)
285
+
286
+ # Draw frame number and FPS at top center
287
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
288
+ frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
289
+ frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
290
+ frame_text_x = (frame_width - frame_text_width) // 2
291
+ draw.text((frame_text_x, 10), frame_info, font=font, fill=(0, 0, 0))
292
+
293
+ # Draw window timestamp range at top of footer
294
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
295
+ window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
296
+ window_text_width = window_text_bbox[2] - window_text_bbox[0]
297
+ window_text_x = (frame_width - window_text_width) // 2
298
+ draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
299
+
300
+ # Draw GT text in video only if there are actions
301
+ if gt_text:
302
+ gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
303
+ draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
304
+
305
+ # Draw predicted text in video only if there are actions
306
+ if pred_text:
307
+ pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
308
+ draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
309
+
310
+ # Draw GT and Pred labels in footer
311
+ draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
312
+ draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
313
+
314
+ # Convert back to OpenCV frame
315
+ extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
316
+ else:
317
+ # Fallback to OpenCV font
318
+ frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
319
+ text_size, _ = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_DUPLEX, text_scale, text_thickness)
320
+ frame_text_x = (frame_width - text_size[0]) // 2
321
+ cv2.putText(
322
+ extended_frame,
323
+ frame_info,
324
+ (frame_text_x, 30),
325
+ cv2.FONT_HERSHEY_DUPLEX,
326
+ text_scale,
327
+ (0, 0, 0),
328
+ text_thickness,
329
+ cv2.LINE_AA
330
+ )
331
+ window_info = f"{window_start:.1f}s - {window_end:.1f}s"
332
+ window_text_size, _ = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
333
+ window_text_x = (frame_width - window_text_size[0]) // 2
334
+ cv2.putText(
335
+ extended_frame,
336
+ window_info,
337
+ (window_text_x, footer_y + 20),
338
+ cv2.FONT_HERSHEY_DUPLEX,
339
+ VIS_CONFIG['video_bar_text_scale'],
340
+ (0, 0, 0),
341
+ 1,
342
+ cv2.LINE_AA
343
+ )
344
+ if gt_text:
345
+ cv2.putText(
346
+ extended_frame,
347
+ gt_text,
348
+ (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
349
+ cv2.FONT_HERSHEY_DUPLEX,
350
+ text_scale,
351
+ gt_text_color,
352
+ text_thickness,
353
+ cv2.LINE_AA
354
+ )
355
+ if pred_text:
356
+ cv2.putText(
357
+ extended_frame,
358
+ pred_text,
359
+ (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
360
+ cv2.FONT_HERSHEY_DUPLEX,
361
+ text_scale,
362
+ pred_text_color,
363
+ text_thickness,
364
+ cv2.LINE_AA
365
+ )
366
+ cv2.putText(
367
+ extended_frame,
368
+ "GT",
369
+ (text_x, gt_bar_y + bar_height // 2 + 5),
370
+ cv2.FONT_HERSHEY_DUPLEX,
371
+ VIS_CONFIG['video_bar_text_scale'],
372
+ gt_text_color,
373
+ 1,
374
+ cv2.LINE_AA
375
+ )
376
+ cv2.putText(
377
+ extended_frame,
378
+ "Pred",
379
+ (text_x, pred_bar_y + bar_height // 2 + 5),
380
+ cv2.FONT_HERSHEY_DUPLEX,
381
+ VIS_CONFIG['video_bar_text_scale'],
382
+ pred_text_color,
383
+ 1,
384
+ cv2.LINE_AA
385
+ )
386
+
387
+ # Write frame to output video
388
+ out.write(extended_frame)
389
+ written_frames += 1
390
+ frame_idx += 1
391
+
392
+ # Release resources
393
+ cap.release()
394
+ out.release()
395
+ print(f"[✅ Saved Annotated Video]: {output_path}, Written Frames={written_frames}")
396
+ print("Note: If .avi is not playable, convert to .mp4 using FFmpeg:")
397
+ print(f"ffmpeg -i {output_path} -vcodec libx264 -acodec aac {output_path.replace('.avi', '.mp4')}")
398
+
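+ # Example call (video id, labels and path below are illustrative only; opt must already be
+ # populated, as in __main__ below, since the output filename uses opt['exp']):
+ # annotate_video_with_actions(
+ #     video_id='P01-R01',
+ #     pred_segments=[{'label': 'Take bowl', 'start': 3.2, 'end': 7.9, 'duration': 4.7, 'score': 0.81}],
+ #     gt_segments=[{'label': 'Take bowl', 'start': 3.0, 'end': 8.0, 'duration': 5.0}],
+ #     video_path='./data/videos/P01-R01.mp4')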
399
+
400
+
401
+
402
+
403
+
404
+
405
+
406
+ def visualize_action_lengths(
407
+ video_id: str,
408
+ pred_segments: List[Dict],
409
+ gt_segments: List[Dict],
410
+ video_path: str,
411
+ duration: float,
412
+ save_dir: str = VIS_CONFIG['save_dir'],
413
+ frame_interval: float = VIS_CONFIG['frame_interval']
414
+ ) -> None:
415
+ """
416
+ Generate a visualization plot comparing ground truth and predicted action lengths with video frames.
417
+
418
+ Args:
419
+ video_id: Video identifier (e.g., 'my_video').
420
+ pred_segments: List of predicted segments with 'label', 'start', 'end', 'duration', 'score'.
421
+ gt_segments: List of ground truth segments with 'label', 'start', 'end', 'duration'.
422
+ video_path: Path to the input video file.
423
+ duration: Total duration of the video in seconds.
424
+ save_dir: Directory to save the output image.
425
+ frame_interval: Time interval between sampled frames (seconds).
426
+ """
427
+ os.makedirs(save_dir, exist_ok=True)
428
+
429
+ # Calculate frame sampling times
430
+ num_frames = int(duration / frame_interval) + 1
431
+ if num_frames > VIS_CONFIG['max_frames']:
432
+ frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
433
+ num_frames = VIS_CONFIG['max_frames']
434
+ print(f"Warning: Video duration ({duration:.1f}s) exceeds the {num_frames}-frame sampling budget. Adjusted frame_interval to {frame_interval:.2f}s.")
435
+
436
+ frame_times = np.linspace(0, duration, num_frames, endpoint=False)
437
+
438
+ # Load video frames
439
+ frames = []
440
+ cap = cv2.VideoCapture(video_path)
441
+ if not cap.isOpened():
442
+ print(f"Warning: Could not open video {video_path}. Using placeholder frames.")
443
+ frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
444
+ else:
445
+ for t in frame_times:
446
+ cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
447
+ ret, frame = cap.read()
448
+ if ret:
449
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
450
+ # Resize frame to reduce memory usage
451
+ frame = cv2.resize(frame, (int(frame.shape[1] * 0.5), int(frame.shape[0] * 0.5)))
452
+ frames.append(frame)
453
+ else:
454
+ frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
455
+ cap.release()
456
+
457
+ # Initialize figure
458
+ fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
459
+ gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
460
+
461
+ # Plot frames
462
+ for i, (t, frame) in enumerate(zip(frame_times, frames)):
463
+ ax = fig.add_subplot(gs[0, i])
464
+
465
+ # Check if frame falls within GT or predicted segments
466
+ gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
467
+ pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
468
+
469
+ # Set border color
470
+ border_color = None
471
+ if gt_hit and pred_hit:
472
+ border_color = VIS_CONFIG['frame_highlight_both']
473
+ elif gt_hit:
474
+ border_color = VIS_CONFIG['frame_highlight_gt']
475
+ elif pred_hit:
476
+ border_color = VIS_CONFIG['frame_highlight_pred']
477
+
478
+ ax.imshow(frame)
479
+ ax.axis('off')
480
+ if border_color:
481
+ for spine in ax.spines.values():
482
+ spine.set_edgecolor(border_color)
483
+ spine.set_linewidth(2)
484
+
485
+ ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'],
486
+ color=border_color if border_color else 'black')
487
+
488
+ # Plot ground truth bar
489
+ ax_gt = fig.add_subplot(gs[1, :])
490
+ ax_gt.set_xlim(0, duration)
491
+ ax_gt.set_ylim(0, 1)
492
+ ax_gt.axis('off')
493
+ ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'],
494
+ va='center', ha='right', weight='bold')
495
+
496
+ for seg in gt_segments:
497
+ start, end = seg['start'], seg['end']
498
+ width = end - start
499
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
500
+ ax_gt.add_patch(patches.Rectangle(
501
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'],
502
+ edgecolor='black', alpha=0.8
503
+ ))
504
+ ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center',
505
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
506
+ ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
507
+ ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')
508
+
509
+ # Plot prediction bar
510
+ ax_pred = fig.add_subplot(gs[2, :])
511
+ ax_pred.set_xlim(0, duration)
512
+ ax_pred.set_ylim(0, 1)
513
+ ax_pred.axis('off')
514
+ ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'],
515
+ va='center', ha='right', weight='bold')
516
+
517
+ for seg in pred_segments:
518
+ start, end = seg['start'], seg['end']
519
+ width = end - start
520
+ label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
521
+ ax_pred.add_patch(patches.Rectangle(
522
+ (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'],
523
+ edgecolor='black', alpha=0.8
524
+ ))
525
+ ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center',
526
+ fontsize=VIS_CONFIG['fontsize_label'], color='white')
527
+ ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
528
+ ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')
529
+
530
+ # Save plot
531
+ plot_path = os.path.join(save_dir, f"viz_{video_id}_{opt['exp']}.png")  # saved as PNG
532
+ plt.savefig(plot_path, dpi=100, bbox_inches='tight')  # modest DPI keeps the file small
533
+ print(f"[✅ Saved Visualization]: {plot_path}")
534
+ plt.close()
535
+
536
+
537
+
538
+
539
+
540
+ def eval_frame(opt, model, dataset):
541
+ test_loader = torch.utils.data.DataLoader(dataset,
542
+ batch_size=opt['batch_size'], shuffle=False,
543
+ num_workers=0, pin_memory=True, drop_last=False)
544
+
545
+ labels_cls = {}
546
+ labels_reg = {}
547
+ output_cls = {}
548
+ output_reg = {}
549
+ for video_name in dataset.video_list:
550
+ labels_cls[video_name] = []
551
+ labels_reg[video_name] = []
552
+ output_cls[video_name] = []
553
+ output_reg[video_name] = []
554
+
555
+ start_time = time.time()
556
+ total_frames = 0
557
+ epoch_cost = 0
558
+ epoch_cost_cls = 0
559
+ epoch_cost_reg = 0
560
+
561
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
562
+ act_cls, act_reg, _ = model(input_data.float().cuda())
563
+ cost_reg = 0
564
+ cost_cls = 0
565
+
566
+ loss = cls_loss_func(cls_label, act_cls)
567
+ cost_cls = loss
568
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
569
+
570
+ loss = regress_loss_func(reg_label, act_reg)
571
+ cost_reg = loss
572
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
573
+
574
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
575
+ epoch_cost += cost.detach().cpu().numpy()
576
+
577
+ act_cls = torch.softmax(act_cls, dim=-1)
578
+
579
+ total_frames += input_data.size(0)
580
+
581
+ for b in range(0, input_data.size(0)):
582
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
583
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
584
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
585
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
586
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
587
+
588
+ end_time = time.time()
589
+ working_time = end_time - start_time
590
+
591
+ for video_name in dataset.video_list:
592
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
593
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
594
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
595
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
596
+
597
+ cls_loss = epoch_cost_cls / (n_iter + 1)  # n_iter is the last batch index, so average over n_iter + 1 batches
598
+ reg_loss = epoch_cost_reg / (n_iter + 1)
599
+ tot_loss = epoch_cost / (n_iter + 1)
600
+
601
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
602
+
603
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
604
+ result_dict = {}
605
+ proposal_dict = []
606
+
607
+ num_class = opt["num_of_class"]
608
+ unit_size = opt['segment_size']
609
+ threshold = opt['threshold']
610
+ anchors = opt['anchors']
611
+
612
+ for video_name in dataset.video_list:
613
+ duration = dataset.video_len[video_name]
614
+ video_time = float(dataset.video_dict[video_name]["duration"])
615
+ frame_to_time = 100.0 * video_time / duration
616
+
617
+ for idx in range(0, duration):
618
+ cls_anc = output_cls[video_name][idx]
619
+ reg_anc = output_reg[video_name][idx]
620
+
621
+ proposal_anc_dict = []
622
+ for anc_idx in range(0, len(anchors)):
623
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
624
+
625
+ if len(cls) == 0:
626
+ continue
627
+
628
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
629
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
630
+ st = ed - length
631
+
632
+ for cidx in range(0, len(cls)):
633
+ label = cls[cidx]
634
+ tmp_dict = {}
635
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
636
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
637
+ tmp_dict["label"] = dataset.label_name[label]
638
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
639
+ proposal_anc_dict.append(tmp_dict)
640
+
641
+ proposal_dict += proposal_anc_dict
642
+
643
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
644
+ result_dict[video_name] = proposal_dict
645
+ proposal_dict = []
646
+
647
+ return result_dict
648
+
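+ # Shared proposal decoding (also used in eval_map_supnet and test_online): for frame idx and
+ # anchor size A with regression output (o, l), end = idx + A*o, length = A*exp(l), and
+ # start = end - length; frame indices are converted to seconds via frame_to_time / 100.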
649
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
650
+ model = SuppressNet(opt).cuda()
651
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
652
+ base_dict = checkpoint['state_dict']
653
+ model.load_state_dict(base_dict)
654
+ model.eval()
655
+
656
+ result_dict = {}
657
+ proposal_dict = []
658
+
659
+ num_class = opt["num_of_class"]
660
+ unit_size = opt['segment_size']
661
+ threshold = opt['threshold']
662
+ anchors = opt['anchors']
663
+
664
+ for video_name in dataset.video_list:
665
+ duration = dataset.video_len[video_name]
666
+ video_time = float(dataset.video_dict[video_name]["duration"])
667
+ frame_to_time = 100.0 * video_time / duration
668
+ conf_queue = torch.zeros((unit_size, num_class - 1))
669
+
670
+ for idx in range(0, duration):
671
+ cls_anc = output_cls[video_name][idx]
672
+ reg_anc = output_reg[video_name][idx]
673
+
674
+ proposal_anc_dict = []
675
+ for anc_idx in range(0, len(anchors)):
676
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
677
+
678
+ if len(cls) == 0:
679
+ continue
680
+
681
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
682
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
683
+ st = ed - length
684
+
685
+ for cidx in range(0, len(cls)):
686
+ label = cls[cidx]
687
+ tmp_dict = {}
688
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
689
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
690
+ tmp_dict["label"] = dataset.label_name[label]
691
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
692
+ proposal_anc_dict.append(tmp_dict)
693
+
694
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
695
+
696
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
697
+ conf_queue[-1, :] = 0
698
+ for proposal in proposal_anc_dict:
699
+ cls_idx = dataset.label_name.index(proposal['label'])
700
+ conf_queue[-1, cls_idx] = proposal["score"]
701
+
702
+ minput = conf_queue.unsqueeze(0)
703
+ suppress_conf = model(minput.cuda())
704
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
705
+
706
+ for cls in range(0, num_class - 1):
707
+ if suppress_conf[cls] > opt['sup_threshold']:
708
+ for proposal in proposal_anc_dict:
709
+ if proposal['label'] == dataset.label_name[cls]:
710
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
711
+ proposal_dict.append(proposal)
712
+
713
+ result_dict[video_name] = proposal_dict
714
+ proposal_dict = []
715
+
716
+ return result_dict
717
+
718
+ def test_frame(opt, video_name=None):
719
+ model = MYNET(opt).cuda()
720
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
721
+ base_dict = checkpoint['state_dict']
722
+ model.load_state_dict(base_dict)
723
+ model.eval()
724
+
725
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
726
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
727
+
728
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
729
+
730
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
731
+
732
+ for video_name in dataset.video_list:
733
+ o_cls = output_cls[video_name]
734
+ o_reg = output_reg[video_name]
735
+ l_cls = labels_cls[video_name]
736
+ l_reg = labels_reg[video_name]
737
+
738
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
739
+ dset_predcls[:, :] = o_cls[:, :]
740
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
741
+ dset_predreg[:, :] = o_reg[:, :]
742
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
743
+ dset_labelcls[:, :] = l_cls[:, :]
744
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
745
+ dset_labelreg[:, :] = l_reg[:, :]
746
+ outfile.close()
747
+
748
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
749
+ return cls_loss, reg_loss, tot_loss
750
+
751
+ def patch_attention(m):
752
+ forward_orig = m.forward
753
+
754
+ def wrap(*args, **kwargs):
755
+ kwargs["need_weights"] = True
756
+ kwargs["average_attn_weights"] = False
757
+ return forward_orig(*args, **kwargs)
758
+
759
+ m.forward = wrap
760
+
761
+ class SaveOutput:
762
+ def __init__(self):
763
+ self.outputs = []
764
+
765
+ def __call__(self, module, module_in, module_out):
766
+ self.outputs.append(module_out[1])
767
+
768
+ def clear(self):
769
+ self.outputs = []
770
+
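+ # patch_attention / SaveOutput are attention-inspection helpers: patch_attention rewraps a
+ # torch nn.MultiheadAttention forward so it returns per-head weights (need_weights=True,
+ # average_attn_weights=False), and SaveOutput, when registered as a forward hook, collects
+ # those weights from module_out[1]. Nothing in this file attaches them; they are available
+ # for manual attention debugging.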
771
+ def test(opt, video_name=None):
772
+ model = MYNET(opt).cuda()
773
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
774
+ base_dict = checkpoint['state_dict']
775
+ model.load_state_dict(base_dict)
776
+ model.eval()
777
+
778
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
779
+
780
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
781
+
782
+ if opt["pptype"] == "nms":
783
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
784
+ if opt["pptype"] == "net":
785
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
786
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
787
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
788
+ json.dump(output_dict, outfile, indent=2)
789
+ outfile.close()
790
+
791
+ mAP = evaluation_detection(opt)
792
+
793
+ # Compare predicted and ground truth action lengths
794
+ if video_name:
795
+ print("\nComparing Predicted and Ground Truth Action Lengths for Video:", video_name)
796
+ with open(opt["video_anno"].format(opt["split"]), 'r') as f:
797
+ anno_data = json.load(f)
798
+ gt_annotations = anno_data['database'][video_name]['annotations']
799
+ duration = anno_data['database'][video_name]['duration']
800
+
801
+ gt_segments = []
802
+ for anno in gt_annotations:
803
+ start, end = anno['segment']
804
+ label = anno['label']
805
+ duration_seg = end - start
806
+ gt_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg})
807
+
808
+ pred_segments = []
809
+ for pred in result_dict[video_name]:
810
+ start, end = pred['segment']
811
+ label = pred['label']
812
+ score = pred['score']
813
+ duration_seg = end - start
814
+ pred_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration_seg, 'score': score})
815
+
816
+ # Print comparison table
817
+ matches = []
818
+ iou_threshold = VIS_CONFIG['iou_threshold']
819
+ used_gt_indices = set()
820
+ for pred in pred_segments:
821
+ best_iou = 0
822
+ best_gt_idx = None
823
+ for gt_idx, gt in enumerate(gt_segments):
824
+ if gt_idx in used_gt_indices:
825
+ continue
826
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
827
+ if iou > best_iou and iou >= iou_threshold:
828
+ best_iou = iou
829
+ best_gt_idx = gt_idx
830
+ if best_gt_idx is not None:
831
+ matches.append({
832
+ 'pred': pred,
833
+ 'gt': gt_segments[best_gt_idx],
834
+ 'iou': best_iou
835
+ })
836
+ used_gt_indices.add(best_gt_idx)
837
+ else:
838
+ matches.append({'pred': pred, 'gt': None, 'iou': 0})
839
+
840
+ for gt_idx, gt in enumerate(gt_segments):
841
+ if gt_idx not in used_gt_indices:
842
+ matches.append({'pred': None, 'gt': gt, 'iou': 0})
843
+
844
+ print("\n{:<20} {:<30} {:<30} {:<15} {:<10}".format(
845
+ "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"))
846
+ print("-" * 105)
847
+ for match in matches:
848
+ pred = match['pred']
849
+ gt = match['gt']
850
+ iou = match['iou']
851
+ if pred and gt:
852
+ label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
853
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
854
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
855
+ duration_diff = pred['duration'] - gt['duration']
856
+ print("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}".format(
857
+ label, pred_str, gt_str, duration_diff, iou))
858
+ elif pred:
859
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
860
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
861
+ pred['label'], pred_str, "None", "N/A", iou))
862
+ elif gt:
863
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
864
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
865
+ gt['label'], "None", gt_str, "N/A", iou))
866
+
867
+ # Summarize
868
+ matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
869
+ avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
870
+ avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
871
+ print(f"\nSummary:")
872
+ print(f"- Total Predictions: {len(pred_segments)}")
873
+ print(f"- Total Ground Truth: {len(gt_segments)}")
874
+ print(f"- Matched Segments: {matched_count}")
875
+ print(f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s")
876
+ print(f"- Average IoU (Matched): {avg_iou:.2f}")
877
+
878
+ # Generate static visualization
879
+ video_path = opt.get('video_path', '')
880
+ if os.path.exists(video_path):
881
+ visualize_action_lengths(
882
+ video_id=video_name,
883
+ pred_segments=pred_segments,
884
+ gt_segments=gt_segments,
885
+ video_path=video_path,
886
+ duration=duration
887
+ )
888
+ # Generate annotated video
889
+ annotate_video_with_actions(
890
+ video_id=video_name,
891
+ pred_segments=pred_segments,
892
+ gt_segments=gt_segments,
893
+ video_path=video_path
894
+ )
895
+ else:
896
+ print(f"Warning: Video path {video_path} not found. Skipping visualization and video annotation.")
897
+
898
+ return mAP
899
+
900
+ def test_online(opt, video_name=None):
901
+ model = MYNET(opt).cuda()
902
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
903
+ base_dict = checkpoint['state_dict']
904
+ model.load_state_dict(base_dict)
905
+ model.eval()
906
+
907
+ sup_model = SuppressNet(opt).cuda()
908
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
909
+ base_dict = checkpoint['state_dict']
910
+ sup_model.load_state_dict(base_dict)
911
+ sup_model.eval()
912
+
913
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
914
+ test_loader = torch.utils.data.DataLoader(dataset,
915
+ batch_size=1, shuffle=False,
916
+ num_workers=0, pin_memory=True, drop_last=False)
917
+
918
+ result_dict = {}
919
+ proposal_dict = []
920
+
921
+ num_class = opt["num_of_class"]
922
+ unit_size = opt['segment_size']
923
+ threshold = opt['threshold']
924
+ anchors = opt['anchors']
925
+
926
+ start_time = time.time()
927
+ total_frames = 0
928
+
929
+ for video_name in dataset.video_list:
930
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
931
+ sup_queue = torch.zeros((unit_size, num_class - 1))
932
+
933
+ duration = dataset.video_len[video_name]
934
+ video_time = float(dataset.video_dict[video_name]["duration"])
935
+ frame_to_time = 100.0 * video_time / duration
936
+
937
+ for idx in range(0, duration):
938
+ total_frames += 1
939
+ input_queue[:-1, :] = input_queue[1:, :].clone()
940
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
941
+
942
+ minput = input_queue.unsqueeze(0)
943
+ act_cls, act_reg, _ = model(minput.cuda())
944
+ act_cls = torch.softmax(act_cls, dim=-1)
945
+
946
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
947
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
948
+
949
+ proposal_anc_dict = []
950
+ for anc_idx in range(0, len(anchors)):
951
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
952
+
953
+ if len(cls) == 0:
954
+ continue
955
+
956
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
957
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
958
+ st = ed - length
959
+
960
+ for cidx in range(0, len(cls)):
961
+ label = cls[cidx]
962
+ tmp_dict = {}
963
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
964
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
965
+ tmp_dict["label"] = dataset.label_name[label]
966
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
967
+ proposal_anc_dict.append(tmp_dict)
968
+
969
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
970
+
971
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
972
+ sup_queue[-1, :] = 0
973
+ for proposal in proposal_anc_dict:
974
+ cls_idx = dataset.label_name.index(proposal['label'])
975
+ sup_queue[-1, cls_idx] = proposal["score"]
976
+
977
+ minput = sup_queue.unsqueeze(0)
978
+ suppress_conf = sup_model(minput.cuda())
979
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
980
+
981
+ for cls in range(0, num_class - 1):
982
+ if suppress_conf[cls] > opt['sup_threshold']:
983
+ for proposal in proposal_anc_dict:
984
+ if proposal['label'] == dataset.label_name[cls]:
985
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
986
+ proposal_dict.append(proposal)
987
+
988
+ result_dict[video_name] = proposal_dict
989
+ proposal_dict = []
990
+
991
+ end_time = time.time()
992
+ working_time = end_time - start_time
993
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
994
+
995
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
996
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
997
+ json.dump(output_dict, outfile, indent=2)
998
+ outfile.close()
999
+
1000
+ mAP = evaluation_detection(opt)
1001
+ return mAP
1002
+
1003
+ def main(opt, video_name=None):
1004
+ max_perf = 0
1005
+ if not video_name and 'video_name' in opt:
1006
+ video_name = opt['video_name']
1007
+
1008
+ if opt['mode'] == 'train':
1009
+ max_perf = train(opt)
1010
+ if opt['mode'] == 'test':
1011
+ max_perf = test(opt, video_name=video_name)
1012
+ if opt['mode'] == 'test_frame':
1013
+ max_perf = test_frame(opt, video_name=video_name)
1014
+ if opt['mode'] == 'test_online':
1015
+ max_perf = test_online(opt, video_name=video_name)
1016
+ if opt['mode'] == 'eval':
1017
+ max_perf = evaluation_detection(opt)
1018
+
1019
+ return max_perf
1020
+
1021
+ if __name__ == '__main__':
1022
+ opt = opts.parse_opt()
1023
+ opt = vars(opt)
1024
+ if not os.path.exists(opt["checkpoint_path"]):
1025
+ os.makedirs(opt["checkpoint_path"])
1026
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
1027
+ json.dump(opt, opt_file)
1028
+ opt_file.close()
1029
+
1030
+ if opt['seed'] >= 0:
1031
+ seed = opt['seed']
1032
+ torch.manual_seed(seed)
1033
+ np.random.seed(seed)
1034
+
1035
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
1036
+
1037
+ video_name = opt.get('video_name', None)
1038
+ main(opt, video_name=video_name)
1039
+ while opt['wterm']:
1040
+ pass
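Both main.py variants build the prediction-vs-ground-truth table with the same greedy matcher: each prediction claims the unused ground-truth segment with the highest IoU above the 0.3 threshold, and leftover segments on either side are reported unmatched. A minimal sketch of that matcher in isolation, with `calc_iou` imported from dataset.py and called on [end, duration] pairs exactly as in the files above:

def greedy_match(pred_segments, gt_segments, calc_iou, iou_threshold=0.3):
    """Greedy one-to-one matching of predicted to ground-truth segments by IoU."""
    matches, used = [], set()
    for pred in pred_segments:
        best_iou, best_idx = 0.0, None
        for i, gt in enumerate(gt_segments):
            if i in used:
                continue
            iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
            if iou > best_iou and iou >= iou_threshold:
                best_iou, best_idx = iou, i
        if best_idx is not None:
            matches.append({'pred': pred, 'gt': gt_segments[best_idx], 'iou': best_iou})
            used.add(best_idx)
        else:
            matches.append({'pred': pred, 'gt': None, 'iou': 0.0})
    # unmatched ground-truth segments are reported with no prediction
    matches += [{'pred': None, 'gt': gt, 'iou': 0.0}
                for i, gt in enumerate(gt_segments) if i not in used]
    return matches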
single prediction and Gt print main.py ADDED
@@ -0,0 +1,613 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet, calc_iou # Import calc_iou explicitly
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
24
+ train_loader = torch.utils.data.DataLoader(train_dataset,
25
+ batch_size=opt['batch_size'], shuffle=True,
26
+ num_workers=0, pin_memory=True, drop_last=False)
27
+ epoch_cost = 0
28
+ epoch_cost_cls = 0
29
+ epoch_cost_reg = 0
30
+ epoch_cost_snip = 0
31
+
32
+ total_iter = len(train_dataset) // opt['batch_size']
33
+ cls_loss = MultiCrossEntropyLoss(focal=True)
34
+ snip_loss = MultiCrossEntropyLoss(focal=True)
35
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
36
+ if warmup:
37
+ for g in optimizer.param_groups:
38
+ g['lr'] = n_iter * (opt['lr']) / total_iter
39
+
40
+ act_cls, act_reg, snip_cls = model(input_data.float().cuda())
41
+
42
+ act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
43
+ snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))
44
+
45
+ cost_reg = 0
46
+ cost_cls = 0
47
+
48
+ loss = cls_loss_func_(cls_loss, cls_label, act_cls)
49
+ cost_cls = loss
50
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
51
+
52
+ loss = regress_loss_func(reg_label, act_reg)
53
+ cost_reg = loss
54
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
55
+
56
+ loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
57
+ cost_snip = loss
58
+ epoch_cost_snip += cost_snip.detach().cpu().numpy()
59
+
60
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
61
+ epoch_cost += cost.detach().cpu().numpy()
62
+
63
+ optimizer.zero_grad()
64
+ cost.backward()
65
+ optimizer.step()
66
+
67
+ return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
68
+
69
+ def eval_one_epoch(opt, model, test_dataset):
70
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, test_dataset)
71
+
72
+ result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
73
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
74
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
75
+ json.dump(output_dict, outfile, indent=2)
76
+ outfile.close()
77
+
78
+ IoUmAP = evaluation_detection(opt, verbose=False)
79
+ IoUmAP_5 = sum(IoUmAP) / len(IoUmAP)  # average AP over the returned IoU thresholds
80
+
81
+ return cls_loss, reg_loss, tot_loss, IoUmAP_5
82
+
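+ # evaluation_detection returns per-IoU-threshold AP values; IoUmAP_5 is their mean and is the
+ # score train() uses to decide whether to refresh the "_ckp_best" checkpoint.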
83
+ def train(opt):
84
+ writer = SummaryWriter()
85
+ model = MYNET(opt).cuda()
86
+
87
+ rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
88
+ optimizer = optim.Adam([{'params': model.history_unit.parameters(), 'lr': 1e-6}, {'params': rest_of_model_params}], lr=opt["lr"], weight_decay=opt["weight_decay"])
89
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["lr_step"])
90
+
91
+ train_dataset = VideoDataSet(opt, subset="train")
92
+ test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
93
+
94
+ warmup = False
95
+
96
+ for n_epoch in range(opt['epoch']):
97
+ if n_epoch >= 1:
98
+ warmup = False
99
+
100
+ n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
101
+
102
+ writer.add_scalars('data/cost', {'train': epoch_cost / (n_iter + 1)}, n_epoch)
103
+ print("training loss(epoch %d): %.03f, cls - %f, reg - %f, snip - %f, lr - %f" % (n_epoch,
104
+ epoch_cost / (n_iter + 1),
105
+ epoch_cost_cls / (n_iter + 1),
106
+ epoch_cost_reg / (n_iter + 1),
107
+ epoch_cost_snip / (n_iter + 1),
108
+ optimizer.param_groups[-1]["lr"]))
109
+
110
+ scheduler.step()
111
+ model.eval()
112
+
113
+ cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
114
+
115
+ writer.add_scalars('data/mAP', {'test': IoUmAP_5}, n_epoch)
116
+ print("testing loss(epoch %d): %.03f, cls - %f, reg - %f, mAP Avg - %f" % (n_epoch, tot_loss, cls_loss, reg_loss, IoUmAP_5))
117
+
118
+ state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
119
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_checkpoint_" + str(n_epoch + 1) + ".pth.tar")
120
+ if IoUmAP_5 > model.best_map:
121
+ model.best_map = IoUmAP_5
122
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_ckp_best.pth.tar")
123
+
124
+ model.train()
125
+
126
+ writer.close()
127
+ return model.best_map
128
+
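+ # The optimizer keeps `history_unit` parameters at a fixed 1e-6 learning rate while the rest of
+ # the model follows opt['lr'] with StepLR decay. Every epoch checkpoint is saved, and the
+ # "_ckp_best" copy is refreshed whenever the averaged detection mAP improves.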
129
+ def eval_frame(opt, model, dataset):
130
+ test_loader = torch.utils.data.DataLoader(dataset,
131
+ batch_size=opt['batch_size'], shuffle=False,
132
+ num_workers=0, pin_memory=True, drop_last=False)
133
+
134
+ labels_cls = {}
135
+ labels_reg = {}
136
+ output_cls = {}
137
+ output_reg = {}
138
+ for video_name in dataset.video_list:
139
+ labels_cls[video_name] = []
140
+ labels_reg[video_name] = []
141
+ output_cls[video_name] = []
142
+ output_reg[video_name] = []
143
+
144
+ start_time = time.time()
145
+ total_frames = 0
146
+ epoch_cost = 0
147
+ epoch_cost_cls = 0
148
+ epoch_cost_reg = 0
149
+
150
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
151
+ act_cls, act_reg, _ = model(input_data.float().cuda())
152
+ cost_reg = 0
153
+ cost_cls = 0
154
+
155
+ loss = cls_loss_func(cls_label, act_cls)
156
+ cost_cls = loss
157
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
158
+
159
+ loss = regress_loss_func(reg_label, act_reg)
160
+ cost_reg = loss
161
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
162
+
163
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
164
+ epoch_cost += cost.detach().cpu().numpy()
165
+
166
+ act_cls = torch.softmax(act_cls, dim=-1)
167
+
168
+ total_frames += input_data.size(0)
169
+
170
+ for b in range(0, input_data.size(0)):
171
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
172
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
173
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
174
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
175
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
176
+
177
+ end_time = time.time()
178
+ working_time = end_time - start_time
179
+
180
+ for video_name in dataset.video_list:
181
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
182
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
183
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
184
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
185
+
186
+ cls_loss = epoch_cost_cls / (n_iter + 1)  # average over n_iter + 1 batches
187
+ reg_loss = epoch_cost_reg / (n_iter + 1)
188
+ tot_loss = epoch_cost / (n_iter + 1)
189
+
190
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
191
+
192
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
193
+ result_dict = {}
194
+ proposal_dict = []
195
+
196
+ num_class = opt["num_of_class"]
197
+ unit_size = opt['segment_size']
198
+ threshold = opt['threshold']
199
+ anchors = opt['anchors']
200
+
201
+ for video_name in dataset.video_list:
202
+ duration = dataset.video_len[video_name]
203
+ video_time = float(dataset.video_dict[video_name]["duration"])
204
+ frame_to_time = 100.0 * video_time / duration
205
+
206
+ for idx in range(0, duration):
207
+ cls_anc = output_cls[video_name][idx]
208
+ reg_anc = output_reg[video_name][idx]
209
+
210
+ proposal_anc_dict = []
211
+ for anc_idx in range(0, len(anchors)):
212
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
213
+
214
+ if len(cls) == 0:
215
+ continue
216
+
217
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
218
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
219
+ st = ed - length
220
+
221
+ for cidx in range(0, len(cls)):
222
+ label = cls[cidx]
223
+ tmp_dict = {}
224
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
225
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
226
+ tmp_dict["label"] = dataset.label_name[label]
227
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
228
+ proposal_anc_dict.append(tmp_dict)
229
+
230
+ proposal_dict += proposal_anc_dict
231
+
232
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
233
+ result_dict[video_name] = proposal_dict
234
+ proposal_dict = []
235
+
236
+ return result_dict
237
+
238
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
239
+ model = SuppressNet(opt).cuda()
240
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
241
+ base_dict = checkpoint['state_dict']
242
+ model.load_state_dict(base_dict)
243
+ model.eval()
244
+
245
+ result_dict = {}
246
+ proposal_dict = []
247
+
248
+ num_class = opt["num_of_class"]
249
+ unit_size = opt['segment_size']
250
+ threshold = opt['threshold']
251
+ anchors = opt['anchors']
252
+
253
+ for video_name in dataset.video_list:
254
+ duration = dataset.video_len[video_name]
255
+ video_time = float(dataset.video_dict[video_name]["duration"])
256
+ frame_to_time = 100.0 * video_time / duration
257
+ conf_queue = torch.zeros((unit_size, num_class - 1))
258
+
259
+ for idx in range(0, duration):
260
+ cls_anc = output_cls[video_name][idx]
261
+ reg_anc = output_reg[video_name][idx]
262
+
263
+ proposal_anc_dict = []
264
+ for anc_idx in range(0, len(anchors)):
265
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
266
+
267
+ if len(cls) == 0:
268
+ continue
269
+
270
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
271
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
272
+ st = ed - length
273
+
274
+ for cidx in range(0, len(cls)):
275
+ label = cls[cidx]
276
+ tmp_dict = {}
277
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
278
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
279
+ tmp_dict["label"] = dataset.label_name[label]
280
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
281
+ proposal_anc_dict.append(tmp_dict)
282
+
283
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
284
+
285
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
286
+ conf_queue[-1, :] = 0
287
+ for proposal in proposal_anc_dict:
288
+ cls_idx = dataset.label_name.index(proposal['label'])
289
+ conf_queue[-1, cls_idx] = proposal["score"]
290
+
291
+ minput = conf_queue.unsqueeze(0)
292
+ suppress_conf = model(minput.cuda())
293
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
294
+
295
+ for cls in range(0, num_class - 1):
296
+ if suppress_conf[cls] > opt['sup_threshold']:
297
+ for proposal in proposal_anc_dict:
298
+ if proposal['label'] == dataset.label_name[cls]:
299
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
300
+ proposal_dict.append(proposal)
301
+
302
+ result_dict[video_name] = proposal_dict
303
+ proposal_dict = []
304
+
305
+ return result_dict
306
+
307
+ def test_frame(opt, video_name=None):
308
+ model = MYNET(opt).cuda()
309
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
310
+ base_dict = checkpoint['state_dict']
311
+ model.load_state_dict(base_dict)
312
+ model.eval()
313
+
314
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
315
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
316
+
317
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
318
+
319
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
320
+
321
+ for video_name in dataset.video_list:
322
+ o_cls = output_cls[video_name]
323
+ o_reg = output_reg[video_name]
324
+ l_cls = labels_cls[video_name]
325
+ l_reg = labels_reg[video_name]
326
+
327
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
328
+ dset_predcls[:, :] = o_cls[:, :]
329
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
330
+ dset_predreg[:, :] = o_reg[:, :]
331
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
332
+ dset_labelcls[:, :] = l_cls[:, :]
333
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
334
+ dset_labelreg[:, :] = l_reg[:, :]
335
+ outfile.close()
336
+
337
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
338
+ return cls_loss, reg_loss, tot_loss
339
+
340
+ def patch_attention(m):
341
+ forward_orig = m.forward
342
+
343
+ def wrap(*args, **kwargs):
344
+ kwargs["need_weights"] = True
345
+ kwargs["average_attn_weights"] = False
346
+ return forward_orig(*args, **kwargs)
347
+
348
+ m.forward = wrap
349
+
350
+ class SaveOutput:
351
+ def __init__(self):
352
+ self.outputs = []
353
+
354
+ def __call__(self, module, module_in, module_out):
355
+ self.outputs.append(module_out[1])
356
+
357
+ def clear(self):
358
+ self.outputs = []
359
+
360
+ def test(opt, video_name=None):
361
+ model = MYNET(opt).cuda()
362
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
363
+ base_dict = checkpoint['state_dict']
364
+ model.load_state_dict(base_dict)
365
+ model.eval()
366
+
367
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
368
+
369
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
370
+
371
+ if opt["pptype"] == "nms":
372
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
373
+ if opt["pptype"] == "net":
374
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
375
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
376
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
377
+ json.dump(output_dict, outfile, indent=2)
378
+ outfile.close()
379
+
380
+ mAP = evaluation_detection(opt)
381
+
382
+ # New: Compare predicted and ground truth action lengths
383
+ if video_name:
384
+ print("\nComparing Predicted and Ground Truth Action Lengths for Video:", video_name)
385
+ # Load ground truth annotations
386
+ with open(opt["video_anno"].format(opt["split"]), 'r') as f:
387
+ anno_data = json.load(f)
388
+ gt_annotations = anno_data['database'][video_name]['annotations']
389
+
390
+ # Extract ground truth segments
391
+ gt_segments = []
392
+ for anno in gt_annotations:
393
+ start, end = anno['segment']
394
+ label = anno['label']
395
+ duration = end - start
396
+ gt_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration})
397
+
398
+ # Extract predicted segments from result_dict
399
+ pred_segments = []
400
+ for pred in result_dict[video_name]:
401
+ start, end = pred['segment']
402
+ label = pred['label']
403
+ score = pred['score']
404
+ duration = end - start
405
+ pred_segments.append({'label': label, 'start': start, 'end': end, 'duration': duration, 'score': score})
406
+
407
+ # Match predictions to ground truth using IoU
408
+ matches = []
409
+ iou_threshold = 0.3 # Same as evaluation default for matching
410
+ used_gt_indices = set()
411
+ for pred in pred_segments:
412
+ best_iou = 0
413
+ best_gt_idx = None
414
+ for gt_idx, gt in enumerate(gt_segments):
415
+ if gt_idx in used_gt_indices:
416
+ continue
417
+ iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
418
+ if iou > best_iou and iou >= iou_threshold:
419
+ best_iou = iou
420
+ best_gt_idx = gt_idx
421
+ if best_gt_idx is not None:
422
+ matches.append({
423
+ 'pred': pred,
424
+ 'gt': gt_segments[best_gt_idx],
425
+ 'iou': best_iou
426
+ })
427
+ used_gt_indices.add(best_gt_idx)
428
+ else:
429
+ matches.append({'pred': pred, 'gt': None, 'iou': 0})
430
+
431
+ # Include unmatched ground truth segments
432
+ for gt_idx, gt in enumerate(gt_segments):
433
+ if gt_idx not in used_gt_indices:
434
+ matches.append({'pred': None, 'gt': gt, 'iou': 0})
435
+
436
+ # Print comparison table
437
+ print("\n{:<20} {:<30} {:<30} {:<15} {:<10}".format(
438
+ "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU"))
439
+ print("-" * 105)
440
+ for match in matches:
441
+ pred = match['pred']
442
+ gt = match['gt']
443
+ iou = match['iou']
444
+ if pred and gt:
445
+ label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
446
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
447
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
448
+ duration_diff = pred['duration'] - gt['duration']
449
+ print("{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}".format(
450
+ label, pred_str, gt_str, duration_diff, iou))
451
+ elif pred:
452
+ pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
453
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
454
+ pred['label'], pred_str, "None", "N/A", iou))
455
+ elif gt:
456
+ gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
457
+ print("{:<20} {:<30} {:<30} {:<15} {:<10.2f}".format(
458
+ gt['label'], "None", gt_str, "N/A", iou))
459
+
460
+ # Summarize
461
+ matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
462
+ avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count else 0.0
463
+ avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0.0
464
+ print(f"\nSummary:")
465
+ print(f"- Total Predictions: {len(pred_segments)}")
466
+ print(f"- Total Ground Truth: {len(gt_segments)}")
467
+ print(f"- Matched Segments: {matched_count}")
468
+ print(f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s")
469
+ print(f"- Average IoU (Matched): {avg_iou:.2f}")
470
+
471
+ return mAP
472
+
473
+ def test_online(opt, video_name=None):
474
+ model = MYNET(opt).cuda()
475
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
476
+ base_dict = checkpoint['state_dict']
477
+ model.load_state_dict(base_dict)
478
+ model.eval()
479
+
480
+ sup_model = SuppressNet(opt).cuda()
481
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
482
+ base_dict = checkpoint['state_dict']
483
+ sup_model.load_state_dict(base_dict)
484
+ sup_model.eval()
485
+
486
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
487
+ test_loader = torch.utils.data.DataLoader(dataset,
488
+ batch_size=1, shuffle=False,
489
+ num_workers=0, pin_memory=True, drop_last=False)
490
+
491
+ result_dict = {}
492
+ proposal_dict = []
493
+
494
+ num_class = opt["num_of_class"]
495
+ unit_size = opt['segment_size']
496
+ threshold = opt['threshold']
497
+ anchors = opt['anchors']
498
+
499
+ start_time = time.time()
500
+ total_frames = 0
501
+
502
+ for video_name in dataset.video_list:
503
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
504
+ sup_queue = torch.zeros((unit_size, num_class - 1))
505
+
506
+ duration = dataset.video_len[video_name]
507
+ video_time = float(dataset.video_dict[video_name]["duration"])
508
+ frame_to_time = 100.0 * video_time / duration
509
+
510
+ for idx in range(0, duration):
511
+ total_frames += 1
512
+ input_queue[:-1, :] = input_queue[1:, :].clone()
513
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
514
+
515
+ minput = input_queue.unsqueeze(0)
516
+ act_cls, act_reg, _ = model(minput.cuda())
517
+ act_cls = torch.softmax(act_cls, dim=-1)
518
+
519
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
520
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
521
+
522
+ proposal_anc_dict = []
523
+ for anc_idx in range(0, len(anchors)):
524
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
525
+
526
+ if len(cls) == 0:
527
+ continue
528
+
529
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
530
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
531
+ st = ed - length
532
+
533
+ for cidx in range(0, len(cls)):
534
+ label = cls[cidx]
535
+ tmp_dict = {}
536
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
537
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
538
+ tmp_dict["label"] = dataset.label_name[label]
539
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
540
+ proposal_anc_dict.append(tmp_dict)
541
+
542
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
543
+
544
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
545
+ sup_queue[-1, :] = 0
546
+ for proposal in proposal_anc_dict:
547
+ cls_idx = dataset.label_name.index(proposal['label'])
548
+ sup_queue[-1, cls_idx] = proposal["score"]
549
+
550
+ minput = sup_queue.unsqueeze(0)
551
+ suppress_conf = sup_model(minput.cuda())
552
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
553
+
554
+ for cls in range(0, num_class - 1):
555
+ if suppress_conf[cls] > opt['sup_threshold']:
556
+ for proposal in proposal_anc_dict:
557
+ if proposal['label'] == dataset.label_name[cls]:
558
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
559
+ proposal_dict.append(proposal)
560
+
561
+ result_dict[video_name] = proposal_dict
562
+ proposal_dict = []
563
+
564
+ end_time = time.time()
565
+ working_time = end_time - start_time
566
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
567
+
568
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
569
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
570
+ json.dump(output_dict, outfile, indent=2)
571
+ outfile.close()
572
+
573
+ mAP = evaluation_detection(opt)
574
+ return mAP
575
+
576
+ def main(opt, video_name=None):
577
+ max_perf = 0
578
+ if not video_name and 'video_name' in opt:
579
+ video_name = opt['video_name']
580
+
581
+ if opt['mode'] == 'train':
582
+ max_perf = train(opt)
583
+ if opt['mode'] == 'test':
584
+ max_perf = test(opt, video_name=video_name)
585
+ if opt['mode'] == 'test_frame':
586
+ max_perf = test_frame(opt, video_name=video_name)
587
+ if opt['mode'] == 'test_online':
588
+ max_perf = test_online(opt, video_name=video_name)
589
+ if opt['mode'] == 'eval':
590
+ max_perf = evaluation_detection(opt)
591
+
592
+ return max_perf
593
+
594
+ if __name__ == '__main__':
595
+ opt = opts.parse_opt()
596
+ opt = vars(opt)
597
+ if not os.path.exists(opt["checkpoint_path"]):
598
+ os.makedirs(opt["checkpoint_path"])
599
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
600
+ json.dump(opt, opt_file)
601
+ opt_file.close()
602
+
603
+ if opt['seed'] >= 0:
604
+ seed = opt['seed']
605
+ torch.manual_seed(seed)
606
+ np.random.seed(seed)
607
+
608
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
609
+
610
+ video_name = opt.get('video_name', None)
611
+ main(opt, video_name=video_name)
612
+ while(opt['wterm']):
613
+ pass
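For reference, a minimal, self-contained sketch of the anchor decoding and the (end, length) IoU convention used in the post-processing above; all numbers are made up and independent of the uploaded checkpoints.

import numpy as np

# Toy values standing in for one anchor's outputs at frame idx.
idx, anchor = 120, 8              # current frame index and anchor length (in frames)
reg = np.array([0.25, 0.40])      # predicted (end offset, log length ratio) for this anchor

ed = idx + anchor * reg[0]        # predicted end frame
length = anchor * np.exp(reg[1])  # predicted length in frames
st = ed - length                  # predicted start frame

def iou_end_len(a, b):
    # Same (end, length) box convention as calc_iou in dataset.py.
    st_a, ed_a = a[0] - a[1], a[0]
    st_b, ed_b = b[0] - b[1], b[0]
    inter = min(ed_a, ed_b) - max(st_a, st_b)
    union = max(ed_a, ed_b) - min(st_a, st_b)
    return inter / max(union, 1)

gt = (123.0, 12.0)                # hypothetical ground-truth (end, length)
print(round(st, 2), round(ed, 2), round(iou_end_len((ed, length), gt), 2))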
dataset.py ADDED
@@ -0,0 +1,533 @@
1
+ import numpy as np
2
+ import h5py
3
+ import json
4
+ import torch
5
+ import torch.utils.data as data
6
+ import os
7
+ import pickle
8
+ from multiprocessing import Pool
9
+
10
+ def load_json(file):
11
+ with open(file) as json_file:
12
+ data = json.load(json_file)
13
+ return data
14
+
15
+ def calc_iou(a, b):
16
+ st = a[0] - a[1]
17
+ ed = a[0]
18
+ target_st = b[0] - b[1]
19
+ target_ed = b[0]
20
+ sst = min(st, target_st)
21
+ led = max(ed, target_ed)
22
+ lst = max(st, target_st)
23
+ sed = min(ed, target_ed)
24
+ iou = (sed - lst) / max(led - sst, 1)
25
+ return iou
26
+
27
+ def box_include(y, target):
28
+ st = y[0] - y[1]
29
+ ed = y[0]
30
+ target_st = target[0] - target[1]
31
+ target_ed = target[0]
32
+ detection_point = target_st
33
+ if ed > detection_point and target_st < st and target_ed > ed:
34
+ return True
35
+ return False
36
+
37
+ class VideoDataSet(data.Dataset):
38
+ def __init__(self, opt, subset="train", video_name=None):
39
+ self.subset = subset
40
+ self.mode = opt["mode"]
41
+ self.predefined_fps = opt["predefined_fps"]
42
+ self.video_anno_path = opt["video_anno"].format(opt["split"])
43
+ self.video_len_path = opt["video_len_file"].format(self.subset + '_' + opt["setup"])
44
+ self.num_of_class = opt["num_of_class"]
45
+ self.segment_size = opt["segment_size"]
46
+ self.label_name = []
47
+ self.match_score = {}
48
+ self.match_score_end = {}
49
+ self.match_length = {}
50
+ self.gt_action = {}
51
+ self.cls_label = {}
52
+ self.reg_label = {}
53
+ self.snip_label = {}
54
+ self.inputs = []
55
+ self.inputs_all = []
56
+ self.data_rescale = opt["data_rescale"]
57
+ self.anchors = opt["anchors"]
58
+ self.pos_threshold = opt["pos_threshold"]
59
+ self.single_video_name = video_name
60
+
61
+ self._getDatasetDict()
62
+ self._loadFeaturelen(opt)
63
+ self._getMatchScore()
64
+ self._makeInputSeq()
65
+ self._loadPropLabel(opt['proposal_label_file'].format(self.subset + '_' + opt["setup"]))
66
+
67
+ if self.subset == "train":
68
+ if opt['data_format'] == "h5":
69
+ feature_rgb_file = h5py.File(opt["video_feature_rgb_train"], 'r')
70
+ self.feature_rgb_file = {}
71
+ keys = self.video_list
72
+ for vidx in range(len(keys)):
73
+ if keys[vidx] not in feature_rgb_file:
74
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_rgb_train']}")
75
+ self.feature_rgb_file[keys[vidx]] = np.array(feature_rgb_file[keys[vidx]][:])
76
+ if opt['rgb_only']:
77
+ self.feature_flow_file = None
78
+ else:
79
+ self.feature_flow_file = {}
80
+ feature_flow_file = h5py.File(opt["video_feature_flow_train"], 'r')
81
+ for vidx in range(len(keys)):
82
+ if keys[vidx] not in feature_flow_file:
83
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_flow_train']}")
84
+ self.feature_flow_file[keys[vidx]] = np.array(feature_flow_file[keys[vidx]][:])
85
+ elif opt['data_format'] == "pickle":
86
+ feature_All = pickle.load(open(opt["video_feature_all_train"], 'rb'))
87
+ self.feature_rgb_file = {}
88
+ self.feature_flow_file = {}
89
+ keys = self.video_list
90
+ for vidx in range(len(keys)):
91
+ if keys[vidx] not in feature_All:
92
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_all_train']}")
93
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
94
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
95
+ elif opt['data_format'] == "npz":
96
+ feature_All = {}
97
+ self.feature_rgb_file = {}
98
+ self.feature_flow_file = {}
99
+ for file in self.video_list:
100
+ feature_path = opt["video_feature_all_train"] + file + '.npz'
101
+ if not os.path.exists(feature_path):
102
+ raise ValueError(f"Feature file {feature_path} not found")
103
+ feature_All[file] = np.load(feature_path)['feats']
104
+ keys = self.video_list
105
+ for vidx in range(len(keys)):
106
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
107
+ self.feature_flow_file = None
108
+ elif opt['data_format'] == "npz_i3d":
109
+ feature_All = {}
110
+ self.feature_rgb_file = {}
111
+ self.feature_flow_file = {}
112
+ for file in self.video_list:
113
+ feature_path = opt["video_feature_all_train"] + file + '.npz'
114
+ if not os.path.exists(feature_path):
115
+ raise ValueError(f"Feature file {feature_path} not found")
116
+ feature_All[file] = np.load(feature_path)
117
+ keys = self.video_list
118
+ for vidx in range(len(keys)):
119
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
120
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
121
+ elif opt['data_format'] == "pt":
122
+ feature_All = {}
123
+ self.feature_rgb_file = {}
124
+ self.feature_flow_file = {}
125
+ for file in self.video_list:
126
+ feature_path = opt["video_feature_all_train"] + file + '.pt'
127
+ if not os.path.exists(feature_path):
128
+ raise ValueError(f"Feature file {feature_path} not found")
129
+ feature_All[file] = torch.load(feature_path)
130
+ keys = self.video_list
131
+ for vidx in range(len(keys)):
132
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
133
+ self.feature_flow_file = None
134
+ else:
135
+ if opt['data_format'] == "h5":
136
+ feature_rgb_file = h5py.File(opt["video_feature_rgb_test"], 'r')
137
+ self.feature_rgb_file = {}
138
+ keys = self.video_list
139
+ for vidx in range(len(keys)):
140
+ if keys[vidx] not in feature_rgb_file:
141
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_rgb_test']}")
142
+ self.feature_rgb_file[keys[vidx]] = np.array(feature_rgb_file[keys[vidx]][:])
143
+ if opt['rgb_only']:
144
+ self.feature_flow_file = None
145
+ else:
146
+ self.feature_flow_file = {}
147
+ feature_flow_file = h5py.File(opt["video_feature_flow_test"], 'r')
148
+ for vidx in range(len(keys)):
149
+ if keys[vidx] not in feature_flow_file:
150
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_flow_test']}")
151
+ self.feature_flow_file[keys[vidx]] = np.array(feature_flow_file[keys[vidx]][:])
152
+ elif opt['data_format'] == "pickle":
153
+ feature_All = pickle.load(open(opt["video_feature_all_test"], 'rb'))
154
+ self.feature_rgb_file = {}
155
+ self.feature_flow_file = {}
156
+ keys = self.video_list
157
+ for vidx in range(len(keys)):
158
+ if keys[vidx] not in feature_All:
159
+ raise ValueError(f"Features for video {keys[vidx]} not found in {opt['video_feature_all_test']}")
160
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
161
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
162
+ elif opt['data_format'] == "npz":
163
+ feature_All = {}
164
+ self.feature_rgb_file = {}
165
+ self.feature_flow_file = {}
166
+ for file in self.video_list:
167
+ feature_path = opt["video_feature_all_test"] + file + '.npz'
168
+ if not os.path.exists(feature_path):
169
+ raise ValueError(f"Feature file {feature_path} not found")
170
+ feature_All[file] = np.load(feature_path)['feats']
171
+ keys = self.video_list
172
+ for vidx in range(len(keys)):
173
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
174
+ self.feature_flow_file = None
175
+ elif opt['data_format'] == "npz_i3d":
176
+ feature_All = {}
177
+ self.feature_rgb_file = {}
178
+ self.feature_flow_file = {}
179
+ for file in self.video_list:
180
+ feature_path = opt["video_feature_all_test"] + file + '.npz'
181
+ if not os.path.exists(feature_path):
182
+ raise ValueError(f"Feature file {feature_path} not found")
183
+ feature_All[file] = np.load(feature_path)
184
+ keys = self.video_list
185
+ for vidx in range(len(keys)):
186
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]]['rgb']
187
+ self.feature_flow_file[keys[vidx]] = feature_All[keys[vidx]]['flow']
188
+ elif opt['data_format'] == "pt":
189
+ feature_All = {}
190
+ self.feature_rgb_file = {}
191
+ self.feature_flow_file = {}
192
+ for file in self.video_list:
193
+ feature_path = opt["video_feature_all_test"] + file + '.pt'
194
+ if not os.path.exists(feature_path):
195
+ raise ValueError(f"Feature file {feature_path} not found")
196
+ feature_All[file] = torch.load(feature_path)
197
+ keys = self.video_list
198
+ for vidx in range(len(keys)):
199
+ self.feature_rgb_file[keys[vidx]] = feature_All[keys[vidx]][:]
200
+ self.feature_flow_file = None
201
+
202
+ def _loadFeaturelen(self, opt):
203
+ if os.path.exists(self.video_len_path):
204
+ self.video_len = load_json(self.video_len_path)
205
+ return
206
+
207
+ self.video_len = {}
208
+ if self.subset == "train":
209
+ if opt['data_format'] == "h5":
210
+ feature_file = h5py.File(opt["video_feature_rgb_train"], 'r')
211
+ elif opt['data_format'] == "pickle":
212
+ feature_file = pickle.load(open(opt["video_feature_all_train"], 'rb'))
213
+ elif opt['data_format'] == "npz":
214
+ feature_file = {}
215
+ for file in self.video_list:
216
+ feature_file[file] = np.load(opt["video_feature_all_train"] + file + '.npz')['feats']
217
+ elif opt['data_format'] == "npz_i3d":
218
+ feature_file = {}
219
+ for file in self.video_list:
220
+ feature_file[file] = np.load(opt["video_feature_all_train"] + file + '.npz')
221
+ elif opt['data_format'] == "pt":
222
+ feature_file = {}
223
+ for file in self.video_list:
224
+ feature_file[file] = torch.load(opt["video_feature_all_train"] + file + '.pt')
225
+ else:
226
+ if opt['data_format'] == "h5":
227
+ feature_file = h5py.File(opt["video_feature_rgb_test"], 'r')
228
+ elif opt['data_format'] == "pickle":
229
+ feature_file = pickle.load(open(opt["video_feature_all_test"], 'rb'))
230
+ elif opt['data_format'] == "npz":
231
+ feature_file = {}
232
+ for file in self.video_list:
233
+ feature_file[file] = np.load(opt["video_feature_all_test"] + file + '.npz')['feats']
234
+ elif opt['data_format'] == "npz_i3d":
235
+ feature_file = {}
236
+ for file in self.video_list:
237
+ feature_file[file] = np.load(opt["video_feature_all_test"] + file + '.npz')
238
+ elif opt['data_format'] == "pt":
239
+ feature_file = {}
240
+ for file in self.video_list:
241
+ feature_file[file] = torch.load(opt["video_feature_all_test"] + file + '.pt')
242
+
243
+ keys = self.video_list
244
+ if opt['data_format'] == "h5":
245
+ for vidx in range(len(keys)):
246
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
247
+ elif opt['data_format'] == "pickle":
248
+ for vidx in range(len(keys)):
249
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
250
+ elif opt['data_format'] == "npz":
251
+ for vidx in range(len(keys)):
252
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
253
+ elif opt['data_format'] == "npz_i3d":
254
+ for vidx in range(len(keys)):
255
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
256
+ elif opt['data_format'] == "pt":
257
+ for vidx in range(len(keys)):
258
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
259
+ outfile = open(self.video_len_path, "w")
260
+ json.dump(self.video_len, outfile, indent=2)
261
+ outfile.close()
262
+
263
+ def _getDatasetDict(self):
264
+ anno_database = load_json(self.video_anno_path)
265
+ anno_database = anno_database['database']
266
+ self.video_dict = {}
267
+ if self.single_video_name:
268
+ if self.single_video_name in anno_database:
269
+ video_info = anno_database[self.single_video_name]
270
+ video_subset = video_info['subset']
271
+ if self.subset == "full" or self.subset in video_subset:
272
+ self.video_dict[self.single_video_name] = video_info
273
+ for seg in video_info['annotations']:
274
+ if not seg['label'] in self.label_name:
275
+ self.label_name.append(seg['label'])
276
+ else:
277
+ raise ValueError(f"Video {self.single_video_name} not found in annotation database")
278
+ else:
279
+ for video_name in anno_database:
280
+ video_info = anno_database[video_name]
281
+ video_subset = anno_database[video_name]['subset']
282
+ if self.subset == "full" or self.subset in video_subset:
283
+ self.video_dict[video_name] = video_info
284
+ for seg in video_info['annotations']:
285
+ if not seg['label'] in self.label_name:
286
+ self.label_name.append(seg['label'])
287
+
288
+ # Ensure all 22 EGTEA action classes are included
289
+ expected_labels = [
290
+ 'Clean/Wipe', 'Close', 'Compress', 'Crack', 'Cut', 'Divide/Pull Apart',
291
+ 'Dry', 'Inspect/Read', 'Mix', 'Move Around', 'Open', 'Operate', 'Other',
292
+ 'Pour', 'Put', 'Squeeze', 'Take', 'Transfer', 'Turn off', 'Turn on', 'Wash',
293
+ 'Spread' # Assumed missing label; replace with actual label if known
294
+ ]
295
+ for label in expected_labels:
296
+ if label not in self.label_name:
297
+ self.label_name.append(label)
298
+
299
+ self.label_name.sort()
300
+ self.video_list = list(self.video_dict.keys())
301
+ print(f"Labels in dataset.label_name: {self.label_name}")
302
+ print(f"Number of labels: {len(self.label_name)}, Expected: {self.num_of_class-1}")
303
+ print(f"{self.subset} subset video numbers: {len(self.video_list)}")
304
+
305
+ def _getMatchScore(self):
306
+ self.action_end_count = torch.zeros(2)
307
+ for index in range(0, len(self.video_list)):
308
+ video_name = self.video_list[index]
309
+ video_info = self.video_dict[video_name]
310
+ video_labels = video_info['annotations']
311
+ gt_bbox = []
312
+ gt_edlen = []
313
+
314
+ second_to_frame = self.video_len[video_name] / float(video_info['duration'])
315
+ for j in range(len(video_labels)):
316
+ tmp_info = video_labels[j]
317
+ tmp_start = tmp_info['segment'][0] * second_to_frame
318
+ tmp_end = tmp_info['segment'][1] * second_to_frame
319
+ tmp_label = self.label_name.index(tmp_info['label'])
320
+ gt_bbox.append([tmp_start, tmp_end, tmp_label])
321
+ gt_edlen.append([gt_bbox[-1][1], gt_bbox[-1][1] - gt_bbox[-1][0], tmp_label])
322
+
323
+ gt_bbox = np.array(gt_bbox)
324
+ gt_edlen = np.array(gt_edlen)
325
+ self.gt_action[video_name] = gt_edlen
326
+
327
+ match_score = np.zeros((self.video_len[video_name], self.num_of_class - 1), dtype=np.float32)
328
+ for idx in range(gt_bbox.shape[0]):
329
+ ed = int(gt_bbox[idx, 1]) + 1
330
+ st = int(gt_bbox[idx, 0])
331
+ match_score[st:ed, int(gt_bbox[idx, 2])] = idx + 1
332
+ self.match_score[video_name] = match_score
333
+
334
+ def _makeInputSeq(self):
335
+ data_idx = 0
336
+ for index in range(0, len(self.video_list)):
337
+ video_name = self.video_list[index]
338
+ duration = self.match_score[video_name].shape[0]
339
+ for i in range(1, duration + 1):
340
+ st = i - self.segment_size
341
+ ed = i
342
+ self.inputs_all.append([video_name, st, ed, data_idx])
343
+ data_idx += 1
344
+
345
+ self.inputs = self.inputs_all.copy()
346
+ print(f"{self.subset} subset seg numbers: {len(self.inputs)}")
347
+
348
+ def _makePropLabelUnit(self, i):
349
+ video_name = self.inputs_all[i][0]
350
+ st = self.inputs_all[i][1]
351
+ ed = self.inputs_all[i][2]
352
+ cls_anc = []
353
+ reg_anc = []
354
+
355
+ for j in range(0, len(self.anchors)):
356
+ v1 = np.zeros(self.num_of_class)
357
+ v1[-1] = 1
358
+ v2 = np.zeros(2)
359
+ v2[-1] = -1e3
360
+ y_box = [ed - 1, self.anchors[j]]
361
+
362
+ subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[j], ed)
363
+ idx_list = []
364
+ for ii in range(0, subset_label.shape[0]):
365
+ for jj in range(0, subset_label.shape[1]):
366
+ idx = int(subset_label[ii, jj])
367
+ if idx > 0 and idx - 1 not in idx_list:
368
+ idx_list.append(idx - 1)
369
+
370
+ for idx in idx_list:
371
+ target_box = self.gt_action[video_name][idx]
372
+ cls = int(target_box[2])
373
+ iou = calc_iou(y_box, target_box)
374
+ if iou >= self.pos_threshold or (j == len(self.anchors) - 1 and box_include(y_box, target_box)) or (j == 0 and box_include(target_box, y_box)):
375
+ v1[cls] = 1
376
+ v1[-1] = 0
377
+ v2[0] = 1.0 * (target_box[0] - y_box[0]) / self.anchors[j]
378
+ v2[1] = np.log(1.0 * max(1, target_box[1]) / y_box[1])
379
+
380
+ cls_anc.append(v1)
381
+ reg_anc.append(v2)
382
+
383
+ v0 = np.zeros(self.num_of_class)
384
+ v0[-1] = 1
385
+ segment_size = ed - st
386
+ y_box = [ed - 1, self.anchors[-1]]
387
+ subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[-1], ed)
388
+ idx_list = []
389
+ for ii in range(0, subset_label.shape[0]):
390
+ for jj in range(0, subset_label.shape[1]):
391
+ idx = int(subset_label[ii, jj])
392
+ if idx > 0 and idx - 1 not in idx_list:
393
+ idx_list.append(idx - 1)
394
+
395
+ for idx in idx_list:
396
+ target_box = self.gt_action[video_name][idx]
397
+ cls = int(target_box[2])
398
+ iou = calc_iou(y_box, target_box)
399
+ if iou >= 0:
400
+ v0[cls] = 1
401
+ v0[-1] = 0
402
+
403
+ cls_anc = np.stack(cls_anc, axis=0)
404
+ reg_anc = np.stack(reg_anc, axis=0)
405
+ cls_snip = np.array(v0)
406
+ return cls_anc, reg_anc, cls_snip
407
+
408
+ def _loadPropLabel(self, filename):
409
+ if os.path.exists(filename):
410
+ prop_label_file = h5py.File(filename, 'r')
411
+ self.cls_label = np.array(prop_label_file['cls_label'][:])
412
+ self.reg_label = np.array(prop_label_file['reg_label'][:])
413
+ self.snip_label = np.array(prop_label_file['snip_label'][:])
414
+ prop_label_file.close()
415
+ self.action_frame_count = np.sum(self.cls_label.reshape((-1, self.cls_label.shape[-1])), axis=0)
416
+ self.action_frame_count = torch.Tensor(self.action_frame_count)
417
+ return
418
+
419
+ pool = Pool(os.cpu_count() // 2)
420
+ labels = pool.map(self._makePropLabelUnit, range(0, len(self.inputs_all)))
421
+ pool.close()
422
+ pool.join()
423
+
424
+ cls_label = []
425
+ reg_label = []
426
+ snip_label = []
427
+ for i in range(0, len(labels)):
428
+ cls_label.append(labels[i][0])
429
+ reg_label.append(labels[i][1])
430
+ snip_label.append(labels[i][2])
431
+ self.cls_label = np.stack(cls_label, axis=0)
432
+ self.reg_label = np.stack(reg_label, axis=0)
433
+ self.snip_label = np.stack(snip_label, axis=0)
434
+
435
+ outfile = h5py.File(filename, 'w')
436
+ dset_cls = outfile.create_dataset('/cls_label', self.cls_label.shape, maxshape=self.cls_label.shape, chunks=True, dtype=np.float32)
437
+ dset_cls[:, :] = self.cls_label[:, :]
438
+ dset_reg = outfile.create_dataset('/reg_label', self.reg_label.shape, maxshape=self.reg_label.shape, chunks=True, dtype=np.float32)
439
+ dset_reg[:, :] = self.reg_label[:, :]
440
+ dset_snip = outfile.create_dataset('/snip_label', self.snip_label.shape, maxshape=self.snip_label.shape, chunks=True, dtype=np.float32)
441
+ dset_snip[:, :] = self.snip_label[:, :]
442
+ outfile.close()
443
+
444
+ return
445
+
446
+ def __getitem__(self, index):
447
+ video_name, st, ed, data_idx = self.inputs[index]
448
+ if st >= 0:
449
+ feature = self._get_base_data(video_name, st, ed)
450
+ else:
451
+ feature = self._get_base_data(video_name, 0, ed)
452
+ padfunc2d = torch.nn.ConstantPad2d((0, 0, -st, 0), 0)
453
+ feature = padfunc2d(feature)
454
+
455
+ cls_label = torch.Tensor(self.cls_label[data_idx])
456
+ reg_label = torch.Tensor(self.reg_label[data_idx])
457
+ snip_label = torch.Tensor(self.snip_label[data_idx])
458
+
459
+ return feature, cls_label, reg_label, snip_label
460
+
461
+ def _get_base_data(self, video_name, st, ed):
462
+ feature_rgb = self.feature_rgb_file[video_name]
463
+ feature_rgb = feature_rgb[st:ed, :]
464
+
465
+ if self.feature_flow_file is not None:
466
+ feature_flow = self.feature_flow_file[video_name]
467
+ feature_flow = feature_flow[st:ed, :]
468
+ feature = np.append(feature_rgb, feature_flow, axis=1)
469
+ else:
470
+ feature = feature_rgb
471
+ feature = torch.from_numpy(np.array(feature))
472
+
473
+ return feature
474
+
475
+ def _get_train_label_with_class(self, video_name, st, ed):
476
+ duration = len(self.match_score[video_name])
477
+ st_padding = 0
478
+ ed_padding = 0
479
+ if st < 0:
480
+ st_padding = -st
481
+ st = 0
482
+ if ed > duration:
483
+ ed_padding = ed - duration
484
+ ed = duration
485
+
486
+ match_score = torch.Tensor(self.match_score[video_name][st:ed])
487
+ if st_padding > 0:
488
+ padfunc2d = torch.nn.ConstantPad2d((0, 0, st_padding, 0), 0)
489
+ match_score = padfunc2d(match_score)
490
+ if ed_padding > 0:
491
+ padfunc2d = torch.nn.ConstantPad2d((0, 0, 0, ed_padding), 0)
492
+ match_score = padfunc2d(match_score)
493
+ return match_score
494
+
495
+ def __len__(self):
496
+ return len(self.inputs)
497
+
498
+ def reset_sample(self):
499
+ self.inputs = self.inputs_all.copy()
500
+
501
+ def select_sample(self, idx):
502
+ inputs = [self.inputs_all[i] for i in idx]
503
+ self.inputs = inputs.copy()
504
+ return
505
+
506
+ class SuppressDataSet(data.Dataset):
507
+ def __init__(self, opt, subset="train"):
508
+ self.subset = subset
509
+ self.mode = opt["mode"]
510
+ self.data_file = h5py.File(opt["suppress_label_file"].format(self.subset + "_" + opt['setup']), 'r')
511
+ self.video_list = list(self.data_file.keys())
512
+ self.inputs = []
513
+ for index in range(0, len(self.video_list)):
514
+ video_name = self.video_list[index]
515
+ duration = self.data_file[video_name + '/input'].shape[0]
516
+ for i in range(0, duration):
517
+ self.inputs.append([video_name, i])
518
+
519
+ print(f"{self.subset} subset seg numbers: {len(self.inputs)}")
520
+
521
+ def __getitem__(self, index):
522
+ video_name, idx = self.inputs[index]
523
+
524
+ input_seq = self.data_file[video_name + '/input'][idx]
525
+ label = self.data_file[video_name + '/label'][idx]
526
+
527
+ input_seq = torch.from_numpy(input_seq)
528
+ label = torch.from_numpy(label)
529
+
530
+ return input_seq, label
531
+
532
+ def __len__(self):
533
+ return len(self.inputs)
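A small sketch of the sliding-window indexing used by VideoDataSet.__getitem__: a window whose start index is negative is zero-padded at the front, mirroring the ConstantPad2d call above (toy feature tensor, hypothetical sizes).

import torch

segment_size, feat_dim = 8, 4
feats = torch.arange(20 * feat_dim, dtype=torch.float32).reshape(20, feat_dim)  # toy per-frame features

ed = 3                              # window ending after frame index 2
st = ed - segment_size              # -5: the window starts before the video does

window = feats[0:ed]                # keep only the frames that exist
pad = torch.nn.ConstantPad2d((0, 0, -st, 0), 0)  # prepend -st rows of zeros
window = pad(window)
print(window.shape)                 # torch.Size([8, 4]); the first 5 rows are zeros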
main.py ADDED
@@ -0,0 +1,523 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+
11
+ import time
12
+ import h5py
13
+ from tqdm import tqdm
14
+ from iou_utils import *
15
+ from eval import evaluation_detection
16
+ from tensorboardX import SummaryWriter
17
+ from dataset import VideoDataSet
18
+ from models import MYNET, SuppressNet
19
+ from loss_func import cls_loss_func, cls_loss_func_, regress_loss_func
20
+ from loss_func import MultiCrossEntropyLoss
21
+ from functools import *
22
+
23
+ def train_one_epoch(opt, model, train_dataset, optimizer, warmup=False):
24
+ train_loader = torch.utils.data.DataLoader(train_dataset,
25
+ batch_size=opt['batch_size'], shuffle=True,
26
+ num_workers=0, pin_memory=True, drop_last=False)
27
+ epoch_cost = 0
28
+ epoch_cost_cls = 0
29
+ epoch_cost_reg = 0
30
+ epoch_cost_snip = 0
31
+
32
+ total_iter = len(train_dataset) // opt['batch_size']
33
+ cls_loss = MultiCrossEntropyLoss(focal=True)
34
+ snip_loss = MultiCrossEntropyLoss(focal=True)
35
+ for n_iter, (input_data, cls_label, reg_label, snip_label) in enumerate(tqdm(train_loader)):
36
+ if warmup:
37
+ for g in optimizer.param_groups:
38
+ g['lr'] = n_iter * (opt['lr']) / total_iter
39
+
40
+ act_cls, act_reg, snip_cls = model(input_data.float().cuda())
41
+
42
+ act_cls.register_hook(partial(cls_loss.collect_grad, cls_label))
43
+ snip_cls.register_hook(partial(snip_loss.collect_grad, snip_label))
44
+
45
+ cost_reg = 0
46
+ cost_cls = 0
47
+
48
+ loss = cls_loss_func_(cls_loss, cls_label, act_cls)
49
+ cost_cls = loss
50
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
51
+
52
+ loss = regress_loss_func(reg_label, act_reg)
53
+ cost_reg = loss
54
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
55
+
56
+ loss = cls_loss_func_(snip_loss, snip_label, snip_cls)
57
+ cost_snip = loss
58
+ epoch_cost_snip += cost_snip.detach().cpu().numpy()
59
+
60
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg + opt['gamma'] * cost_snip
61
+ epoch_cost += cost.detach().cpu().numpy()
62
+
63
+ optimizer.zero_grad()
64
+ cost.backward()
65
+ optimizer.step()
66
+
67
+ return n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip
68
+
69
+ def eval_one_epoch(opt, model, test_dataset):
70
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, test_dataset)
71
+
72
+ result_dict = eval_map_nms(opt, test_dataset, output_cls, output_reg, labels_cls, labels_reg)
73
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
74
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
75
+ json.dump(output_dict, outfile, indent=2)
76
+ outfile.close()
77
+
78
+ IoUmAP = evaluation_detection(opt, verbose=False)
79
+ IoUmAP_5 = sum(IoUmAP[0:]) / len(IoUmAP[0:])
80
+
81
+ return cls_loss, reg_loss, tot_loss, IoUmAP_5
82
+
83
+ def train(opt):
84
+ writer = SummaryWriter()
85
+ model = MYNET(opt).cuda()
86
+
87
+ rest_of_model_params = [param for name, param in model.named_parameters() if "history_unit" not in name]
88
+ optimizer = optim.Adam([{'params': model.history_unit.parameters(), 'lr': 1e-6}, {'params': rest_of_model_params}], lr=opt["lr"], weight_decay=opt["weight_decay"])
89
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt["lr_step"])
90
+
91
+ train_dataset = VideoDataSet(opt, subset="train")
92
+ test_dataset = VideoDataSet(opt, subset=opt['inference_subset'])
93
+
94
+ warmup = False
95
+
96
+ for n_epoch in range(opt['epoch']):
97
+ if n_epoch >= 1:
98
+ warmup = False
99
+
100
+ n_iter, epoch_cost, epoch_cost_cls, epoch_cost_reg, epoch_cost_snip = train_one_epoch(opt, model, train_dataset, optimizer, warmup)
101
+
102
+ writer.add_scalars('data/cost', {'train': epoch_cost / (n_iter + 1)}, n_epoch)
103
+ print("training loss(epoch %d): %.03f, cls - %f, reg - %f, snip - %f, lr - %f" % (n_epoch,
104
+ epoch_cost / (n_iter + 1),
105
+ epoch_cost_cls / (n_iter + 1),
106
+ epoch_cost_reg / (n_iter + 1),
107
+ epoch_cost_snip / (n_iter + 1),
108
+ optimizer.param_groups[-1]["lr"]))
109
+
110
+ scheduler.step()
111
+ model.eval()
112
+
113
+ cls_loss, reg_loss, tot_loss, IoUmAP_5 = eval_one_epoch(opt, model, test_dataset)
114
+
115
+ writer.add_scalars('data/mAP', {'test': IoUmAP_5}, n_epoch)
116
+ print("testing loss(epoch %d): %.03f, cls - %f, reg - %f, mAP Avg - %f" % (n_epoch, tot_loss, cls_loss, reg_loss, IoUmAP_5))
117
+
118
+ state = {'epoch': n_epoch + 1, 'state_dict': model.state_dict()}
119
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_checkpoint_" + str(n_epoch + 1) + ".pth.tar")
120
+ if IoUmAP_5 > model.best_map:
121
+ model.best_map = IoUmAP_5
122
+ torch.save(state, opt["checkpoint_path"] + "/" + opt["exp"] + "_ckp_best.pth.tar")
123
+
124
+ model.train()
125
+
126
+ writer.close()
127
+ return model.best_map
128
+
129
+ def eval_frame(opt, model, dataset):
130
+ test_loader = torch.utils.data.DataLoader(dataset,
131
+ batch_size=opt['batch_size'], shuffle=False,
132
+ num_workers=0, pin_memory=True, drop_last=False)
133
+
134
+ labels_cls = {}
135
+ labels_reg = {}
136
+ output_cls = {}
137
+ output_reg = {}
138
+ for video_name in dataset.video_list:
139
+ labels_cls[video_name] = []
140
+ labels_reg[video_name] = []
141
+ output_cls[video_name] = []
142
+ output_reg[video_name] = []
143
+
144
+ start_time = time.time()
145
+ total_frames = 0
146
+ epoch_cost = 0
147
+ epoch_cost_cls = 0
148
+ epoch_cost_reg = 0
149
+
150
+ for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
151
+ act_cls, act_reg, _ = model(input_data.float().cuda())
152
+ cost_reg = 0
153
+ cost_cls = 0
154
+
155
+ loss = cls_loss_func(cls_label, act_cls)
156
+ cost_cls = loss
157
+ epoch_cost_cls += cost_cls.detach().cpu().numpy()
158
+
159
+ loss = regress_loss_func(reg_label, act_reg)
160
+ cost_reg = loss
161
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
162
+
163
+ cost = opt['alpha'] * cost_cls + opt['beta'] * cost_reg
164
+ epoch_cost += cost.detach().cpu().numpy()
165
+
166
+ act_cls = torch.softmax(act_cls, dim=-1)
167
+
168
+ total_frames += input_data.size(0)
169
+
170
+ for b in range(0, input_data.size(0)):
171
+ video_name, st, ed, data_idx = dataset.inputs[n_iter * opt['batch_size'] + b]
172
+ output_cls[video_name] += [act_cls[b, :].detach().cpu().numpy()]
173
+ output_reg[video_name] += [act_reg[b, :].detach().cpu().numpy()]
174
+ labels_cls[video_name] += [cls_label[b, :].numpy()]
175
+ labels_reg[video_name] += [reg_label[b, :].numpy()]
176
+
177
+ end_time = time.time()
178
+ working_time = end_time - start_time
179
+
180
+ for video_name in dataset.video_list:
181
+ labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
182
+ labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
183
+ output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
184
+ output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
185
+
186
+ cls_loss = epoch_cost_cls / n_iter
187
+ reg_loss = epoch_cost_reg / n_iter
188
+ tot_loss = epoch_cost / n_iter
189
+
190
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
191
+
192
+ def eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
193
+ result_dict = {}
194
+ proposal_dict = []
195
+
196
+ num_class = opt["num_of_class"]
197
+ unit_size = opt['segment_size']
198
+ threshold = opt['threshold']
199
+ anchors = opt['anchors']
200
+
201
+ for video_name in dataset.video_list:
202
+ duration = dataset.video_len[video_name]
203
+ video_time = float(dataset.video_dict[video_name]["duration"])
204
+ frame_to_time = 100.0 * video_time / duration
205
+
206
+ for idx in range(0, duration):
207
+ cls_anc = output_cls[video_name][idx]
208
+ reg_anc = output_reg[video_name][idx]
209
+
210
+ proposal_anc_dict = []
211
+ for anc_idx in range(0, len(anchors)):
212
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
213
+
214
+ if len(cls) == 0:
215
+ continue
216
+
217
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
218
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
219
+ st = ed - length
220
+
221
+ for cidx in range(0, len(cls)):
222
+ label = cls[cidx]
223
+ tmp_dict = {}
224
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
225
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
226
+ tmp_dict["label"] = dataset.label_name[label]
227
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
228
+ proposal_anc_dict.append(tmp_dict)
229
+
230
+ proposal_dict += proposal_anc_dict
231
+
232
+ proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
233
+ result_dict[video_name] = proposal_dict
234
+ proposal_dict = []
235
+
236
+ return result_dict
237
+
238
+ def eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg):
239
+ model = SuppressNet(opt).cuda()
240
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
241
+ base_dict = checkpoint['state_dict']
242
+ model.load_state_dict(base_dict)
243
+ model.eval()
244
+
245
+ result_dict = {}
246
+ proposal_dict = []
247
+
248
+ num_class = opt["num_of_class"]
249
+ unit_size = opt['segment_size']
250
+ threshold = opt['threshold']
251
+ anchors = opt['anchors']
252
+
253
+ for video_name in dataset.video_list:
254
+ duration = dataset.video_len[video_name]
255
+ video_time = float(dataset.video_dict[video_name]["duration"])
256
+ frame_to_time = 100.0 * video_time / duration
257
+ conf_queue = torch.zeros((unit_size, num_class - 1))
258
+
259
+ for idx in range(0, duration):
260
+ cls_anc = output_cls[video_name][idx]
261
+ reg_anc = output_reg[video_name][idx]
262
+
263
+ proposal_anc_dict = []
264
+ for anc_idx in range(0, len(anchors)):
265
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
266
+
267
+ if len(cls) == 0:
268
+ continue
269
+
270
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
271
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
272
+ st = ed - length
273
+
274
+ for cidx in range(0, len(cls)):
275
+ label = cls[cidx]
276
+ tmp_dict = {}
277
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
278
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
279
+ tmp_dict["label"] = dataset.label_name[label]
280
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
281
+ proposal_anc_dict.append(tmp_dict)
282
+
283
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
284
+
285
+ conf_queue[:-1, :] = conf_queue[1:, :].clone()
286
+ conf_queue[-1, :] = 0
287
+ for proposal in proposal_anc_dict:
288
+ cls_idx = dataset.label_name.index(proposal['label'])
289
+ conf_queue[-1, cls_idx] = proposal["score"]
290
+
291
+ minput = conf_queue.unsqueeze(0)
292
+ suppress_conf = model(minput.cuda())
293
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
294
+
295
+ for cls in range(0, num_class - 1):
296
+ if suppress_conf[cls] > opt['sup_threshold']:
297
+ for proposal in proposal_anc_dict:
298
+ if proposal['label'] == dataset.label_name[cls]:
299
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
300
+ proposal_dict.append(proposal)
301
+
302
+ result_dict[video_name] = proposal_dict
303
+ proposal_dict = []
304
+
305
+ return result_dict
306
+
307
+ def test_frame(opt, video_name=None):
308
+ model = MYNET(opt).cuda()
309
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
310
+ base_dict = checkpoint['state_dict']
311
+ model.load_state_dict(base_dict)
312
+ model.eval()
313
+
314
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
315
+ outfile = h5py.File(opt['frame_result_file'].format(opt['exp']), 'w')
316
+
317
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
318
+
319
+ print("testing loss: %f, cls_loss: %f, reg_loss: %f" % (tot_loss, cls_loss, reg_loss))
320
+
321
+ for video_name in dataset.video_list:
322
+ o_cls = output_cls[video_name]
323
+ o_reg = output_reg[video_name]
324
+ l_cls = labels_cls[video_name]
325
+ l_reg = labels_reg[video_name]
326
+
327
+ dset_predcls = outfile.create_dataset(video_name + '/pred_cls', o_cls.shape, maxshape=o_cls.shape, chunks=True, dtype=np.float32)
328
+ dset_predcls[:, :] = o_cls[:, :]
329
+ dset_predreg = outfile.create_dataset(video_name + '/pred_reg', o_reg.shape, maxshape=o_reg.shape, chunks=True, dtype=np.float32)
330
+ dset_predreg[:, :] = o_reg[:, :]
331
+ dset_labelcls = outfile.create_dataset(video_name + '/label_cls', l_cls.shape, maxshape=l_cls.shape, chunks=True, dtype=np.float32)
332
+ dset_labelcls[:, :] = l_cls[:, :]
333
+ dset_labelreg = outfile.create_dataset(video_name + '/label_reg', l_reg.shape, maxshape=l_reg.shape, chunks=True, dtype=np.float32)
334
+ dset_labelreg[:, :] = l_reg[:, :]
335
+ outfile.close()
336
+
337
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
338
+ return cls_loss, reg_loss, tot_loss
339
+
340
+ def patch_attention(m):
341
+ forward_orig = m.forward
342
+
343
+ def wrap(*args, **kwargs):
344
+ kwargs["need_weights"] = True
345
+ kwargs["average_attn_weights"] = False
346
+ return forward_orig(*args, **kwargs)
347
+
348
+ m.forward = wrap
349
+
350
+ class SaveOutput:
351
+ def __init__(self):
352
+ self.outputs = []
353
+
354
+ def __call__(self, module, module_in, module_out):
355
+ self.outputs.append(module_out[1])
356
+
357
+ def clear(self):
358
+ self.outputs = []
359
+
360
+ def test(opt, video_name=None):
361
+ model = MYNET(opt).cuda()
362
+ checkpoint = torch.load(opt["checkpoint_path"] + "/" + opt['exp'] + "_ckp_best.pth.tar")
363
+ base_dict = checkpoint['state_dict']
364
+ model.load_state_dict(base_dict)
365
+ model.eval()
366
+
367
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
368
+
369
+ cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames = eval_frame(opt, model, dataset)
370
+
371
+ if opt["pptype"] == "nms":
372
+ result_dict = eval_map_nms(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
373
+ if opt["pptype"] == "net":
374
+ result_dict = eval_map_supnet(opt, dataset, output_cls, output_reg, labels_cls, labels_reg)
375
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
376
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
377
+ json.dump(output_dict, outfile, indent=2)
378
+ outfile.close()
379
+
380
+ mAP = evaluation_detection(opt)
381
+ return mAP
382
+
383
+ def test_online(opt, video_name=None):
384
+ model = MYNET(opt).cuda()
385
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best.pth.tar")
386
+ base_dict = checkpoint['state_dict']
387
+ model.load_state_dict(base_dict)
388
+ model.eval()
389
+
390
+ sup_model = SuppressNet(opt).cuda()
391
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")
392
+ base_dict = checkpoint['state_dict']
393
+ sup_model.load_state_dict(base_dict)
394
+ sup_model.eval()
395
+
396
+ dataset = VideoDataSet(opt, subset=opt['inference_subset'], video_name=video_name)
397
+ test_loader = torch.utils.data.DataLoader(dataset,
398
+ batch_size=1, shuffle=False,
399
+ num_workers=0, pin_memory=True, drop_last=False)
400
+
401
+ result_dict = {}
402
+ proposal_dict = []
403
+
404
+ num_class = opt["num_of_class"]
405
+ unit_size = opt['segment_size']
406
+ threshold = opt['threshold']
407
+ anchors = opt['anchors']
408
+
409
+ start_time = time.time()
410
+ total_frames = 0
411
+
412
+ for video_name in dataset.video_list:
413
+ input_queue = torch.zeros((unit_size, opt['feat_dim']))
414
+ sup_queue = torch.zeros((unit_size, num_class - 1))
415
+
416
+ duration = dataset.video_len[video_name]
417
+ video_time = float(dataset.video_dict[video_name]["duration"])
418
+ frame_to_time = 100.0 * video_time / duration
419
+
420
+ for idx in range(0, duration):
421
+ total_frames += 1
422
+ input_queue[:-1, :] = input_queue[1:, :].clone()
423
+ input_queue[-1:, :] = dataset._get_base_data(video_name, idx, idx + 1)
424
+
425
+ minput = input_queue.unsqueeze(0)
426
+ act_cls, act_reg, _ = model(minput.cuda())
427
+ act_cls = torch.softmax(act_cls, dim=-1)
428
+
429
+ cls_anc = act_cls.squeeze(0).detach().cpu().numpy()
430
+ reg_anc = act_reg.squeeze(0).detach().cpu().numpy()
431
+
432
+ proposal_anc_dict = []
433
+ for anc_idx in range(0, len(anchors)):
434
+ cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
435
+
436
+ if len(cls) == 0:
437
+ continue
438
+
439
+ ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
440
+ length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
441
+ st = ed - length
442
+
443
+ for cidx in range(0, len(cls)):
444
+ label = cls[cidx]
445
+ tmp_dict = {}
446
+ tmp_dict["segment"] = [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)]
447
+ tmp_dict["score"] = float(cls_anc[anc_idx][label])
448
+ tmp_dict["label"] = dataset.label_name[label]
449
+ tmp_dict["gentime"] = float(idx * frame_to_time / 100.0)
450
+ proposal_anc_dict.append(tmp_dict)
451
+
452
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
453
+
454
+ sup_queue[:-1, :] = sup_queue[1:, :].clone()
455
+ sup_queue[-1, :] = 0
456
+ for proposal in proposal_anc_dict:
457
+ cls_idx = dataset.label_name.index(proposal['label'])
458
+ sup_queue[-1, cls_idx] = proposal["score"]
459
+
460
+ minput = sup_queue.unsqueeze(0)
461
+ suppress_conf = sup_model(minput.cuda())
462
+ suppress_conf = suppress_conf.squeeze(0).detach().cpu().numpy()
463
+
464
+ for cls in range(0, num_class - 1):
465
+ if suppress_conf[cls] > opt['sup_threshold']:
466
+ for proposal in proposal_anc_dict:
467
+ if proposal['label'] == dataset.label_name[cls]:
468
+ if check_overlap_proposal(proposal_dict, proposal, overlapThresh=opt['soft_nms']) is None:
469
+ proposal_dict.append(proposal)
470
+
471
+ result_dict[video_name] = proposal_dict
472
+ proposal_dict = []
473
+
474
+ end_time = time.time()
475
+ working_time = end_time - start_time
476
+ print("working time : {}s, {}fps, {} frames".format(working_time, total_frames / working_time, total_frames))
477
+
478
+ output_dict = {"version": "VERSION 1.3", "results": result_dict, "external_data": {}}
479
+ outfile = open(opt["result_file"].format(opt['exp']), "w")
480
+ json.dump(output_dict, outfile, indent=2)
481
+ outfile.close()
482
+
483
+ mAP = evaluation_detection(opt)
484
+ return mAP
485
+
486
+ def main(opt, video_name=None):
487
+ max_perf = 0
488
+ if not video_name and 'video_name' in opt:
489
+ video_name = opt['video_name']
490
+
491
+ if opt['mode'] == 'train':
492
+ max_perf = train(opt)
493
+ if opt['mode'] == 'test':
494
+ max_perf = test(opt, video_name=video_name)
495
+ if opt['mode'] == 'test_frame':
496
+ max_perf = test_frame(opt, video_name=video_name)
497
+ if opt['mode'] == 'test_online':
498
+ max_perf = test_online(opt, video_name=video_name)
499
+ if opt['mode'] == 'eval':
500
+ max_perf = evaluation_detection(opt)
501
+
502
+ return max_perf
503
+
504
+ if __name__ == '__main__':
505
+ opt = opts.parse_opt()
506
+ opt = vars(opt)
507
+ if not os.path.exists(opt["checkpoint_path"]):
508
+ os.makedirs(opt["checkpoint_path"])
509
+ opt_file = open(opt["checkpoint_path"] + "/" + opt["exp"] + "_opts.json", "w")
510
+ json.dump(opt, opt_file)
511
+ opt_file.close()
512
+
513
+ if opt['seed'] >= 0:
514
+ seed = opt['seed']
515
+ torch.manual_seed(seed)
516
+ np.random.seed(seed)
517
+
518
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
519
+
520
+ video_name = opt.get('video_name', None)
521
+ main(opt, video_name=video_name)
522
+ while(opt['wterm']):
523
+ pass
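The JSON written by test() and test_online() follows the ActivityNet-style result layout produced above. A short sketch for inspecting it afterwards; the path is hypothetical and assumes exp "01" with the default --result_file template.

import json

with open("./output/result_proposal01.json") as f:   # hypothetical path for --exp 01
    output = json.load(f)

for video_name, proposals in output["results"].items():
    top = sorted(proposals, key=lambda p: p["score"], reverse=True)[:5]
    for p in top:
        st, ed = p["segment"]
        print(f"{video_name}: {p['label']} [{st:.2f}s, {ed:.2f}s] "
              f"score={p['score']:.3f} gentime={p['gentime']:.2f}s")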
opts_egtea.py ADDED
@@ -0,0 +1,198 @@
1
+ import argparse
2
+
3
+ def parse_opt():
4
+ parser = argparse.ArgumentParser()
5
+ # Overall settings
6
+ parser.add_argument(
7
+ '--mode',
8
+ type=str,
9
+ default='train')
10
+ parser.add_argument(
11
+ '--video_name',
12
+ type=str,
13
+ default=None,
14
+ help='Name of the single video to evaluate')
15
+ parser.add_argument(
16
+ '--checkpoint_path',
17
+ type=str,
18
+ default='./checkpoint')
19
+ parser.add_argument(
20
+ '--segment_size',
21
+ type=int,
22
+ default=64)
23
+ parser.add_argument(
24
+ '--anchors',
25
+ type=str,
26
+ default='2,4,6,8,12,16')
27
+ parser.add_argument(
28
+ '--seed',
29
+ default=7,
30
+ type=int,
31
+ help='random seed for reproducibility')
32
+
33
+ # Overall Dataset settings
34
+ parser.add_argument(
35
+ '--num_of_class',
36
+ type=int,
37
+ default=23)
38
+ parser.add_argument(
39
+ '--data_format',
40
+ type=str,
41
+ default="npz_i3d")
42
+ parser.add_argument(
43
+ '--data_rescale',
44
+ default=False,
45
+ action='store_true')
46
+ parser.add_argument(
47
+ '--predefined_fps',
48
+ default=None,
49
+ type=float)
50
+ parser.add_argument(
51
+ '--rgb_only',
52
+ default=False,
53
+ action='store_true')
54
+ parser.add_argument(
55
+ '--video_anno',
56
+ type=str,
57
+ default="./data/egtea_annotations_split{}.json")
58
+ parser.add_argument(
59
+ '--video_feature_all_train',
60
+ type=str,
61
+ default="./data/I3D/")
62
+ parser.add_argument(
63
+ '--video_feature_all_test',
64
+ type=str,
65
+ default="./data/I3D/")
66
+ parser.add_argument(
67
+ '--setup',
68
+ type=str,
69
+ default="")
70
+ parser.add_argument(
71
+ '--exp',
72
+ type=str,
73
+ default="01")
74
+ parser.add_argument(
75
+ '--split',
76
+ type=str,
77
+ default="1")
78
+
79
+ # Network
80
+ parser.add_argument(
81
+ '--feat_dim',
82
+ type=int,
83
+ default=2048)
84
+ parser.add_argument(
85
+ '--hidden_dim',
86
+ type=int,
87
+ default=1024)
88
+ parser.add_argument(
89
+ '--out_dim',
90
+ type=int,
91
+ default=23)
92
+ parser.add_argument(
93
+ '--enc_layer',
94
+ type=int,
95
+ default=3)
96
+ parser.add_argument(
97
+ '--enc_head',
98
+ type=int,
99
+ default=8)
100
+ parser.add_argument(
101
+ '--dec_layer',
102
+ type=int,
103
+ default=5)
104
+ parser.add_argument(
105
+ '--dec_head',
106
+ type=int,
107
+ default=4)
108
+
109
+ # Training settings
110
+ parser.add_argument(
111
+ '--batch_size',
112
+ type=int,
113
+ default=128)
114
+ parser.add_argument(
115
+ '--lr',
116
+ type=float,
117
+ default=1e-4)
118
+ parser.add_argument(
119
+ '--weight_decay',
120
+ type=float,
121
+ default=1e-4)
122
+ parser.add_argument(
123
+ '--epoch',
124
+ type=int,
125
+ default=5)
126
+ parser.add_argument(
127
+ '--lr_step',
128
+ type=int,
129
+ default=3)
130
+
131
+ # Post processing
132
+ parser.add_argument(
133
+ '--alpha',
134
+ type=float,
135
+ default=1)
136
+ parser.add_argument(
137
+ '--beta',
138
+ type=float,
139
+ default=1)
140
+ parser.add_argument(
141
+ '--gamma',
142
+ type=float,
143
+ default=0.2)
144
+ parser.add_argument(
145
+ '--pptype',
146
+ type=str,
147
+ default="net")
148
+ parser.add_argument(
149
+ '--pos_threshold',
150
+ type=float,
151
+ default=0.5)
152
+ parser.add_argument(
153
+ '--sup_threshold',
154
+ type=float,
155
+ default=0.1)
156
+ parser.add_argument(
157
+ '--threshold',
158
+ type=float,
159
+ default=0.1)
160
+ parser.add_argument(
161
+ '--inference_subset',
162
+ type=str,
163
+ default="test")
164
+ parser.add_argument(
165
+ '--soft_nms',
166
+ type=float,
167
+ default=0.3)
168
+ parser.add_argument(
169
+ '--video_len_file',
170
+ type=str,
171
+ default="./output/video_len_{}.json")
172
+ parser.add_argument(
173
+ '--proposal_label_file',
174
+ type=str,
175
+ default="./output/proposal_label_{}.h5")
176
+ parser.add_argument(
177
+ '--suppress_label_file',
178
+ type=str,
179
+ default="./output/suppress_label_{}.h5")
180
+ parser.add_argument(
181
+ '--suppress_result_file',
182
+ type=str,
183
+ default="./output/suppress_result{}.h5")
184
+ parser.add_argument(
185
+ '--frame_result_file',
186
+ type=str,
187
+ default="./output/frame_result{}.h5")
188
+ parser.add_argument(
189
+ '--result_file',
190
+ type=str,
191
+ default="./output/result_proposal{}.json")
192
+ parser.add_argument(
193
+ '--wterm',
194
+ default=False,
195
+ action='store_true')
196
+
197
+ args = parser.parse_args()
198
+ return args
supnet.py ADDED
@@ -0,0 +1,637 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torchvision
5
+ import torch.nn.parallel
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ import opts_egtea as opts
10
+ import time
11
+ import h5py
12
+ from iou_utils import *
13
+ from eval import evaluation_detection
14
+ from tensorboardX import SummaryWriter
15
+ from dataset import VideoDataSet, SuppressDataSet
16
+ from models import MYNET, SuppressNet
17
+ from loss_func import cls_loss_func, regress_loss_func, suppress_loss_func
18
+ from tqdm import tqdm
19
+
20
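+ # One training pass over the suppression dataset; returns the last batch index and the summed loss.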
+ def train_one_epoch(opt, model, train_dataset, optimizer):
21
+ train_loader = torch.utils.data.DataLoader(train_dataset,
22
+ batch_size=opt['batch_size'], shuffle=True,
23
+ num_workers=0, pin_memory=True,drop_last=False)
24
+ epoch_cost = 0
25
+
26
+ for n_iter,(input_data,label) in enumerate(tqdm(train_loader)):
27
+ suppress_conf = model(input_data.cuda())
28
+
29
+ loss = suppress_loss_func(label,suppress_conf)
30
+ epoch_cost+= loss.detach().cpu().numpy()
31
+
32
+ optimizer.zero_grad()
33
+ loss.backward()
34
+ optimizer.step()
35
+
36
+ return n_iter, epoch_cost
37
+
38
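+ # One evaluation pass (no optimizer step); returns the last batch index and the summed loss.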
+ def eval_one_epoch(opt, model, test_dataset):
39
+ test_loader = torch.utils.data.DataLoader(test_dataset,
40
+ batch_size=opt['batch_size'], shuffle=False,
41
+ num_workers=0, pin_memory=True,drop_last=False)
42
+ epoch_cost = 0
43
+
44
+ for n_iter,(input_data,label) in enumerate(tqdm(test_loader)):
45
+ suppress_conf = model(input_data.cuda())
46
+
47
+ loss = suppress_loss_func(label,suppress_conf)
48
+ epoch_cost+= loss.detach().cpu().numpy()
49
+
50
+ return n_iter, epoch_cost
51
+
52
+
53
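+ # Train SuppressNet with Adam + StepLR, evaluate each epoch, and checkpoint the best model by eval loss.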
+ def train(opt):
54
+ writer = SummaryWriter()
55
+ model = SuppressNet(opt).cuda()
56
+
57
+ optimizer = optim.Adam( model.parameters(),lr=opt["lr"],weight_decay = opt["weight_decay"])
58
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size = opt["lr_step"])
59
+
60
+ train_dataset = SuppressDataSet(opt,subset="train")
61
+ test_dataset = SuppressDataSet(opt,subset=opt['inference_subset'])
62
+
63
+ for n_epoch in range(opt['epoch']):
64
+ n_iter, epoch_cost = train_one_epoch(opt, model, train_dataset, optimizer)
65
+
66
+ writer.add_scalars('sup_data/cost', {'train': epoch_cost/(n_iter+1)}, n_epoch)
67
+ print("training loss(epoch %d): %f, lr - %f"%(n_epoch,
68
+ epoch_cost/(n_iter+1),
69
+ optimizer.param_groups[0]["lr"]) )
70
+
71
+ scheduler.step()
72
+ model.eval()
73
+
74
+ n_iter, eval_cost = eval_one_epoch(opt, model,test_dataset)
75
+
76
+ writer.add_scalars('sup_data/eval', {'test': eval_cost/(n_iter+1)}, n_epoch)
77
+ print("testing loss(epoch %d): %f"%(n_epoch,eval_cost/(n_iter+1)))
78
+
79
+ state = {'epoch': n_epoch + 1,
80
+ 'state_dict': model.state_dict()}
81
+ torch.save(state, opt["checkpoint_path"]+"/checkpoint_suppress_"+str(n_epoch+1)+".pth.tar" )
82
+ if eval_cost < model.best_loss:
83
+ model.best_loss = eval_cost
84
+ torch.save(state, opt["checkpoint_path"]+"/ckp_best_suppress.pth.tar" )
85
+
86
+ model.train()
87
+
88
+ writer.close()
89
+ return
90
+
91
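+ # Run the proposal model frame by frame, accumulating per-video classification/regression outputs, labels, and losses.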
+ def eval_frame(opt, model, dataset):
92
+ test_loader = torch.utils.data.DataLoader(dataset,
93
+ batch_size=opt['batch_size'], shuffle=False,
94
+ num_workers=0, pin_memory=True,drop_last=False)
95
+
96
+ labels_cls={}
97
+ labels_reg={}
98
+ output_cls={}
99
+ output_reg={}
100
+ for video_name in dataset.video_list:
101
+ labels_cls[video_name]=[]
102
+ labels_reg[video_name]=[]
103
+ output_cls[video_name]=[]
104
+ output_reg[video_name]=[]
105
+
106
+ start_time = time.time()
107
+ total_frames =0
108
+ epoch_cost = 0
109
+ epoch_cost_cls = 0
110
+ epoch_cost_reg = 0
111
+
112
+ for n_iter,(input_data,cls_label,reg_label, _) in enumerate(tqdm(test_loader)):
113
+ act_cls, act_reg, _ = model(input_data.cuda())
114
+
115
+ cost_reg = 0
116
+ cost_cls = 0
117
+
118
+ loss = cls_loss_func(cls_label,act_cls)
119
+ cost_cls = loss
120
+
121
+ epoch_cost_cls+= cost_cls.detach().cpu().numpy()
122
+
123
+ loss = regress_loss_func(reg_label,act_reg)
124
+ cost_reg = loss
125
+ epoch_cost_reg += cost_reg.detach().cpu().numpy()
126
+
127
+ cost= opt['alpha']*cost_cls +opt['beta']*cost_reg
128
+
129
+ epoch_cost += cost.detach().cpu().numpy()
130
+
131
+ act_cls = torch.softmax(act_cls, dim=-1)
132
+
133
+ total_frames+=input_data.size(0)
134
+
135
+ for b in range(0,input_data.size(0)):
136
+ video_name, st, ed, data_idx = dataset.inputs[n_iter*opt['batch_size']+b]
137
+ output_cls[video_name]+=[act_cls[b,:].detach().cpu().numpy()]
138
+ output_reg[video_name]+=[act_reg[b,:].detach().cpu().numpy()]
139
+ labels_cls[video_name]+=[cls_label[b,:].numpy()]
140
+ labels_reg[video_name]+=[reg_label[b,:].numpy()]
141
+
142
+ end_time = time.time()
143
+ working_time = end_time-start_time
144
+
145
+ for video_name in dataset.video_list:
146
+ labels_cls[video_name]=np.stack(labels_cls[video_name], axis=0)
147
+ labels_reg[video_name]=np.stack(labels_reg[video_name], axis=0)
148
+ output_cls[video_name]=np.stack(output_cls[video_name], axis=0)
149
+ output_reg[video_name]=np.stack(output_reg[video_name], axis=0)
150
+
151
+ cls_loss = epoch_cost_cls / (n_iter + 1)
152
+ reg_loss = epoch_cost_reg / (n_iter + 1)
153
+ tot_loss = epoch_cost / (n_iter + 1)
154
+
155
+ return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
156
+
157
+
158
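+ # Load the best SuppressNet checkpoint and write per-video predictions and labels to the suppress_result_file HDF5.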
+ def test(opt):
159
+ model = SuppressNet(opt).cuda()
160
+ checkpoint = torch.load(opt["checkpoint_path"] + "/ckp_best_suppress.pth.tar")  # path saved by train()
161
+ base_dict=checkpoint['state_dict']
162
+ model.load_state_dict(base_dict)
163
+ model.eval()
164
+
165
+ dataset = SuppressDataSet(opt,subset=opt['inference_subset'])
166
+
167
+ test_loader = torch.utils.data.DataLoader(dataset,
168
+ batch_size=opt['batch_size'], shuffle=False,
169
+ num_workers=0, pin_memory=True,drop_last=False)
170
+ labels={}
171
+ output={}
172
+ for video_name in dataset.video_list:
173
+ labels[video_name]=[]
174
+ output[video_name]=[]
175
+
176
+ for n_iter,(input_data,label) in enumerate(test_loader):
177
+ suppress_conf = model(input_data.cuda())
178
+
179
+ for b in range(0,input_data.size(0)):
180
+ video_name, idx = dataset.inputs[n_iter*opt['batch_size']+b]
181
+ output[video_name]+=[suppress_conf[b,:].detach().cpu().numpy()]
182
+ labels[video_name]+=[label[b,:].numpy()]
183
+
184
+ for video_name in dataset.video_list:
185
+ labels[video_name]=np.stack(labels[video_name], axis=0)
186
+ output[video_name]=np.stack(output[video_name], axis=0)
187
+
188
+ outfile = h5py.File(opt['suppress_result_file'].format(opt['exp']), 'w')
189
+
190
+ for video_name in dataset.video_list:
191
+ o=output[video_name]
192
+ l=labels[video_name]
193
+
194
+ dset_pred = outfile.create_dataset(video_name+'/pred', o.shape, maxshape=o.shape, chunks=True, dtype=np.float32)
195
+ dset_pred[:,:] = o[:,:]
196
+ dset_label = outfile.create_dataset(video_name+'/label', l.shape, maxshape=l.shape, chunks=True, dtype=np.float32)
197
+ dset_label[:,:] = l[:,:]
198
+ outfile.close()
199
+ print('complete')
200
+
201
+
202
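+ # Generate SuppressNet training data: decode anchor proposals from the frame-level model, run NMS, and save input/label tables to HDF5.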
+ def make_dataset(opt):
203
+
204
+ model = MYNET(opt).cuda()
205
+ checkpoint = torch.load(opt["checkpoint_path"]+"/"+opt['exp']+"_ckp_best.pth.tar")
206
+ base_dict=checkpoint['state_dict']
207
+ model.load_state_dict(base_dict)
208
+ model.eval()
209
+
210
+ dataset = VideoDataSet(opt,subset=opt['inference_subset'])
211
+
212
+ _, _, _, output_cls, output_reg, labels_cls, labels_reg, _, _ = eval_frame(opt, model,dataset)
213
+
214
+ proposal_dict=[]
215
+
216
+ outfile = h5py.File(opt['suppress_label_file'].format(opt['inference_subset']+'_'+opt['setup']), 'w')
217
+
218
+ num_class = opt["num_of_class"]-1
219
+ unit_size = opt['segment_size']
220
+ threshold=opt['threshold']
221
+ anchors=opt['anchors']
222
+
223
+ for video_name in dataset.video_list:
224
+ duration = dataset.video_len[video_name]
225
+
226
+ for idx in range(0,duration):
227
+ cls_anc = output_cls[video_name][idx]
228
+ reg_anc = output_reg[video_name][idx]
229
+
230
+ proposal_anc_dict=[]
231
+ for anc_idx in range(0,len(anchors)):
232
+ cls = np.argwhere(cls_anc[anc_idx][:-1]>opt['threshold']).reshape(-1)
233
+
234
+ if len(cls) == 0:
235
+ continue
236
+
237
+ ed= idx + anchors[anc_idx] * reg_anc[anc_idx][0]
238
+ length = anchors[anc_idx]* np.exp(reg_anc[anc_idx][1])
239
+ st= ed-length
240
+
241
+ for cidx in range(0,len(cls)):
242
+ label=cls[cidx]
243
+ tmp_dict={}
244
+ tmp_dict["segment"] = [st, ed]
245
+ tmp_dict["score"]= cls_anc[anc_idx][label]
246
+ tmp_dict["label"]=label
247
+ tmp_dict["gentime"]= idx
248
+ proposal_anc_dict.append(tmp_dict)
249
+
250
+ proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
251
+ proposal_dict+=proposal_anc_dict
252
+
253
+ nms_dict=non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
254
+
255
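+ # input_table: raw proposal confidences laid out over a unit_size sliding window per frame; label_table: 1 for the 3 frames from each NMS-kept proposal's generation time.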
+ input_table = np.zeros((duration,unit_size,num_class), dtype=np.float32)
256
+ label_table = np.zeros((duration,num_class), dtype=np.float32)
257
+
258
+ for proposal in proposal_dict:
259
+ idx = proposal["gentime"]
260
+ conf = proposal["score"]
261
+ cls = proposal["label"]
262
+ for i in range(0,unit_size):
263
+ if idx+i < duration:
264
+ input_table[idx+i,unit_size-1-i,cls]=conf
265
+
266
+ for proposal in nms_dict:
267
+ idx = proposal["gentime"]
268
+ cls = proposal["label"]
269
+ label_table[idx:idx+3,cls]=1
270
+
271
+ dset_input_table = outfile.create_dataset(video_name+'/input', input_table.shape, maxshape=input_table.shape, chunks=True, dtype=np.float32)
272
+ dset_label_table = outfile.create_dataset(video_name+'/label', label_table.shape, maxshape=label_table.shape, chunks=True, dtype=np.float32)
273
+
274
+ dset_input_table[:]=input_table
275
+ dset_label_table[:]=label_table
276
+
277
+ proposal_dict=[]
278
+
279
+ print('complete')
280
+ return
281
+
282
+
283
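+ # Dispatch on --mode: 'train', 'test', or 'make' (build the suppression dataset).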
+ def main(opt):
284
+ if opt['mode'] == 'train':
285
+ train(opt)
286
+ if opt['mode'] == 'test':
287
+ test(opt)
288
+ if opt['mode'] == 'make':
289
+ make_dataset(opt)
290
+
291
+ return
292
+
293
+ if __name__ == '__main__':
294
+ opt = opts.parse_opt()
295
+ opt = vars(opt)
296
+ if not os.path.exists(opt["checkpoint_path"]):
297
+ os.makedirs(opt["checkpoint_path"])
298
+ opt_file=open(opt["checkpoint_path"]+"/"+opt['exp']+"_opts.json","w")
299
+ json.dump(opt,opt_file)
300
+ opt_file.close()
301
+
302
+ if opt['seed'] >= 0:
303
+ seed = opt['seed']
304
+ torch.manual_seed(seed)
305
+ np.random.seed(seed)
306
+ #random.seed(seed)
307
+
308
+ opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
309
+
310
+ main(opt)
311
+ while(opt['wterm']):
312
+ pass
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+
321
+ # import os
322
+ # import json
323
+ # import torch
324
+ # import torchvision
325
+ # import torch.nn.parallel
326
+ # import torch.nn.functional as F
327
+ # import torch.optim as optim
328
+ # import numpy as np
329
+ # # import opts_egtea as opts
330
+ # import opts_thumos as opts
331
+ # import time
332
+ # import h5py
333
+ # from iou_utils import *
334
+ # from eval import evaluation_detection
335
+ # from tensorboardX import SummaryWriter
336
+ # from dataset import VideoDataSet, SuppressDataSet
337
+ # from models import MYNET, SuppressNet
338
+ # from loss_func import cls_loss_func, regress_loss_func, suppress_loss_func
339
+ # from tqdm import tqdm
340
+
341
+ # def train_one_epoch(opt, model, train_dataset, optimizer):
342
+ # train_loader = torch.utils.data.DataLoader(train_dataset,
343
+ # batch_size=opt['batch_size'], shuffle=True,
344
+ # num_workers=0, pin_memory=True,drop_last=False)
345
+ # epoch_cost = 0
346
+
347
+ # for n_iter,(input_data,label) in enumerate(tqdm(train_loader)):
348
+ # suppress_conf = model(input_data.cuda())
349
+
350
+ # loss = suppress_loss_func(label,suppress_conf)
351
+ # epoch_cost+= loss.detach().cpu().numpy()
352
+
353
+ # optimizer.zero_grad()
354
+ # loss.backward()
355
+ # optimizer.step()
356
+
357
+ # return n_iter, epoch_cost
358
+
359
+ # def eval_one_epoch(opt, model, test_dataset):
360
+ # test_loader = torch.utils.data.DataLoader(test_dataset,
361
+ # batch_size=opt['batch_size'], shuffle=False,
362
+ # num_workers=0, pin_memory=True,drop_last=False)
363
+ # epoch_cost = 0
364
+
365
+ # for n_iter,(input_data,label) in enumerate(tqdm(test_loader)):
366
+ # suppress_conf = model(input_data.cuda())
367
+
368
+ # loss = suppress_loss_func(label,suppress_conf)
369
+ # epoch_cost+= loss.detach().cpu().numpy()
370
+
371
+ # return n_iter, epoch_cost
372
+
373
+
374
+ # def train(opt):
375
+ # writer = SummaryWriter()
376
+ # model = SuppressNet(opt).cuda()
377
+
378
+ # optimizer = optim.Adam( model.parameters(),lr=opt["lr"],weight_decay = opt["weight_decay"])
379
+ # scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size = opt["lr_step"])
380
+
381
+ # train_dataset = SuppressDataSet(opt,subset="train")
382
+ # test_dataset = SuppressDataSet(opt,subset=opt['inference_subset'])
383
+
384
+ # for n_epoch in range(opt['epoch']):
385
+ # n_iter, epoch_cost = train_one_epoch(opt, model, train_dataset, optimizer)
386
+
387
+ # writer.add_scalars('sup_data/cost', {'train': epoch_cost/(n_iter+1)}, n_epoch)
388
+ # print("training loss(epoch %d): %f, lr - %f"%(n_epoch,
389
+ # epoch_cost/(n_iter+1),
390
+ # optimizer.param_groups[0]["lr"]) )
391
+
392
+ # scheduler.step()
393
+ # model.eval()
394
+
395
+ # n_iter, eval_cost = eval_one_epoch(opt, model,test_dataset)
396
+
397
+ # writer.add_scalars('sup_data/eval', {'test': eval_cost/(n_iter+1)}, n_epoch)
398
+ # print("testing loss(epoch %d): %f"%(n_epoch,eval_cost/(n_iter+1)))
399
+
400
+ # state = {'epoch': n_epoch + 1,
401
+ # 'state_dict': model.state_dict()}
402
+ # torch.save(state, opt["checkpoint_path"]+"/checkpoint_suppress_"+str(n_epoch+1)+".pth.tar" )
403
+ # if eval_cost < model.best_loss:
404
+ # model.best_loss = eval_cost
405
+ # torch.save(state, opt["checkpoint_path"]+"/ckp_best_suppress.pth.tar" )
406
+
407
+ # model.train()
408
+
409
+ # writer.close()
410
+ # return
411
+
412
+ # def eval_frame(opt, model, dataset):
413
+ # test_loader = torch.utils.data.DataLoader(dataset,
414
+ # batch_size=opt['batch_size'], shuffle=False,
415
+ # num_workers=0, pin_memory=True,drop_last=False)
416
+
417
+ # labels_cls={}
418
+ # labels_reg={}
419
+ # output_cls={}
420
+ # output_reg={}
421
+ # for video_name in dataset.video_list:
422
+ # labels_cls[video_name]=[]
423
+ # labels_reg[video_name]=[]
424
+ # output_cls[video_name]=[]
425
+ # output_reg[video_name]=[]
426
+
427
+ # start_time = time.time()
428
+ # total_frames =0
429
+ # epoch_cost = 0
430
+ # epoch_cost_cls = 0
431
+ # epoch_cost_reg = 0
432
+
433
+ # for n_iter,(input_data,cls_label,reg_label, _) in enumerate(tqdm(test_loader)):
434
+ # act_cls, act_reg, _ = model(input_data.cuda())
435
+
436
+ # cost_reg = 0
437
+ # cost_cls = 0
438
+
439
+ # loss = cls_loss_func(cls_label,act_cls)
440
+ # cost_cls = loss
441
+
442
+ # epoch_cost_cls+= cost_cls.detach().cpu().numpy()
443
+
444
+ # loss = regress_loss_func(reg_label,act_reg)
445
+ # cost_reg = loss
446
+ # epoch_cost_reg += cost_reg.detach().cpu().numpy()
447
+
448
+ # cost= opt['alpha']*cost_cls +opt['beta']*cost_reg
449
+
450
+ # epoch_cost += cost.detach().cpu().numpy()
451
+
452
+ # act_cls = torch.softmax(act_cls, dim=-1)
453
+
454
+ # total_frames+=input_data.size(0)
455
+
456
+ # for b in range(0,input_data.size(0)):
457
+ # video_name, st, ed, data_idx = dataset.inputs[n_iter*opt['batch_size']+b]
458
+ # output_cls[video_name]+=[act_cls[b,:].detach().cpu().numpy()]
459
+ # output_reg[video_name]+=[act_reg[b,:].detach().cpu().numpy()]
460
+ # labels_cls[video_name]+=[cls_label[b,:].numpy()]
461
+ # labels_reg[video_name]+=[reg_label[b,:].numpy()]
462
+
463
+ # end_time = time.time()
464
+ # working_time = end_time-start_time
465
+
466
+ # for video_name in dataset.video_list:
467
+ # labels_cls[video_name]=np.stack(labels_cls[video_name], axis=0)
468
+ # labels_reg[video_name]=np.stack(labels_reg[video_name], axis=0)
469
+ # output_cls[video_name]=np.stack(output_cls[video_name], axis=0)
470
+ # output_reg[video_name]=np.stack(output_reg[video_name], axis=0)
471
+
472
+ # cls_loss=epoch_cost_cls/n_iter
473
+ # reg_loss=epoch_cost_reg/n_iter
474
+ # tot_loss=epoch_cost/n_iter
475
+
476
+ # return cls_loss, reg_loss, tot_loss, output_cls, output_reg, labels_cls, labels_reg, working_time, total_frames
477
+
478
+
479
+ # def test(opt):
480
+ # model = SuppressNet(opt).cuda()
481
+ # checkpoint = torch.load(opt["checkpoint_path"]+"/" + opt['exp'] + "ckp_best_suppress.pth.tar")
482
+ # base_dict=checkpoint['state_dict']
483
+ # model.load_state_dict(base_dict)
484
+ # model.eval()
485
+
486
+ # dataset = SuppressDataSet(opt,subset=opt['inference_subset'])
487
+
488
+ # test_loader = torch.utils.data.DataLoader(dataset,
489
+ # batch_size=opt['batch_size'], shuffle=False,
490
+ # num_workers=0, pin_memory=True,drop_last=False)
491
+ # labels={}
492
+ # output={}
493
+ # for video_name in dataset.video_list:
494
+ # labels[video_name]=[]
495
+ # output[video_name]=[]
496
+
497
+ # for n_iter,(input_data,label) in enumerate(test_loader):
498
+ # suppress_conf = model(input_data.cuda())
499
+
500
+ # for b in range(0,input_data.size(0)):
501
+ # video_name, idx = dataset.inputs[n_iter*opt['batch_size']+b]
502
+ # output[video_name]+=[suppress_conf[b,:].detach().cpu().numpy()]
503
+ # labels[video_name]+=[label[b,:].numpy()]
504
+
505
+ # for video_name in dataset.video_list:
506
+ # labels[video_name]=np.stack(labels[video_name], axis=0)
507
+ # output[video_name]=np.stack(output[video_name], axis=0)
508
+
509
+ # outfile = h5py.File(opt['suppress_result_file'].format(opt['exp']), 'w')
510
+
511
+ # for video_name in dataset.video_list:
512
+ # o=output[video_name]
513
+ # l=labels[video_name]
514
+
515
+ # dset_pred = outfile.create_dataset(video_name+'/pred', o.shape, maxshape=o.shape, chunks=True, dtype=np.float32)
516
+ # dset_pred[:,:] = o[:,:]
517
+ # dset_label = outfile.create_dataset(video_name+'/label', l.shape, maxshape=l.shape, chunks=True, dtype=np.float32)
518
+ # dset_label[:,:] = l[:,:]
519
+ # outfile.close()
520
+ # print('complete')
521
+
522
+
523
+ # def make_dataset(opt):
524
+
525
+ # model = MYNET(opt).cuda()
526
+ # checkpoint = torch.load(opt["checkpoint_path"]+"/"+opt['exp']+"_ckp_best.pth.tar")
527
+ # base_dict=checkpoint['state_dict']
528
+ # model.load_state_dict(base_dict)
529
+ # model.eval()
530
+
531
+ # # Fix: Set the 'split' key to match 'inference_subset'
532
+ # opt['split'] = opt['inference_subset']
533
+
534
+ # dataset = VideoDataSet(opt,subset=opt['inference_subset'])
535
+
536
+ # _, _, _, output_cls, output_reg, labels_cls, labels_reg, _, _ = eval_frame(opt, model,dataset)
537
+
538
+ # proposal_dict=[]
539
+
540
+ # outfile = h5py.File(opt['suppress_label_file'].format(opt['inference_subset']+'_'+opt['setup']), 'w')
541
+
542
+ # num_class = opt["num_of_class"]-1
543
+ # unit_size = opt['segment_size']
544
+ # threshold=opt['threshold']
545
+ # anchors=opt['anchors']
546
+
547
+ # for video_name in dataset.video_list:
548
+ # duration = dataset.video_len[video_name]
549
+
550
+ # for idx in range(0,duration):
551
+ # cls_anc = output_cls[video_name][idx]
552
+ # reg_anc = output_reg[video_name][idx]
553
+
554
+ # proposal_anc_dict=[]
555
+ # for anc_idx in range(0,len(anchors)):
556
+ # cls = np.argwhere(cls_anc[anc_idx][:-1]>opt['threshold']).reshape(-1)
557
+
558
+ # if len(cls) == 0:
559
+ # continue
560
+
561
+ # ed= idx + anchors[anc_idx] * reg_anc[anc_idx][0]
562
+ # length = anchors[anc_idx]* np.exp(reg_anc[anc_idx][1])
563
+ # st= ed-length
564
+
565
+ # for cidx in range(0,len(cls)):
566
+ # label=cls[cidx]
567
+ # tmp_dict={}
568
+ # tmp_dict["segment"] = [st, ed]
569
+ # tmp_dict["score"]= cls_anc[anc_idx][label]
570
+ # tmp_dict["label"]=label
571
+ # tmp_dict["gentime"]= idx
572
+ # proposal_anc_dict.append(tmp_dict)
573
+
574
+ # proposal_anc_dict = non_max_suppression(proposal_anc_dict, overlapThresh=opt['soft_nms'])
575
+ # proposal_dict+=proposal_anc_dict
576
+
577
+ # nms_dict=non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
578
+
579
+ # input_table = np.zeros((duration,unit_size,num_class), dtype=np.float32)
580
+ # label_table = np.zeros((duration,num_class), dtype=np.float32)
581
+
582
+ # for proposal in proposal_dict:
583
+ # idx = proposal["gentime"]
584
+ # conf = proposal["score"]
585
+ # cls = proposal["label"]
586
+ # for i in range(0,unit_size):
587
+ # if idx+i < duration:
588
+ # input_table[idx+i,unit_size-1-i,cls]=conf
589
+
590
+ # for proposal in nms_dict:
591
+ # idx = proposal["gentime"]
592
+ # cls = proposal["label"]
593
+ # label_table[idx:idx+3,cls]=1
594
+
595
+ # dset_input_table = outfile.create_dataset(video_name+'/input', input_table.shape, maxshape=input_table.shape, chunks=True, dtype=np.float32)
596
+ # dset_label_table = outfile.create_dataset(video_name+'/label', label_table.shape, maxshape=label_table.shape, chunks=True, dtype=np.float32)
597
+
598
+ # dset_input_table[:]=input_table
599
+ # dset_label_table[:]=label_table
600
+
601
+ # proposal_dict=[]
602
+
603
+ # outfile.close() # Added missing close() call
604
+ # print('complete')
605
+ # return
606
+
607
+
608
+ # def main(opt):
609
+ # if opt['mode'] == 'train':
610
+ # train(opt)
611
+ # if opt['mode'] == 'test':
612
+ # test(opt)
613
+ # if opt['mode'] == 'make':
614
+ # make_dataset(opt)
615
+
616
+ # return
617
+
618
+ # if __name__ == '__main__':
619
+ # opt = opts.parse_opt()
620
+ # opt = vars(opt)
621
+ # if not os.path.exists(opt["checkpoint_path"]):
622
+ # os.makedirs(opt["checkpoint_path"])
623
+ # opt_file=open(opt["checkpoint_path"]+"/"+opt['exp']+"_opts.json","w")
624
+ # json.dump(opt,opt_file)
625
+ # opt_file.close()
626
+
627
+ # if opt['seed'] >= 0:
628
+ # seed = opt['seed']
629
+ # torch.manual_seed(seed)
630
+ # np.random.seed(seed)
631
+ # #random.seed(seed)
632
+
633
+ # opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
634
+
635
+ # main(opt)
636
+ # while(opt['wterm']):
637
+ # pass