Darknsu committed
Commit 9bbcc8b · verified · 1 Parent(s): e1fb045

Rename short main.py to main.py

Files changed (2)
  1. main.py +636 -0
  2. short main.py +0 -0
main.py ADDED
@@ -0,0 +1,636 @@
+ import os
+ import json
+ import torch
+ import numpy as np
+ import cv2
+ from tqdm import tqdm
+ import gradio as gr
+ import opts_egtea as opts
+ from dataset import VideoDataSet, calc_iou
+ from models import MYNET, SuppressNet
+ from loss_func import cls_loss_func, regress_loss_func
+ from eval import evaluation_detection
+ from iou_utils import non_max_suppression, check_overlap_proposal
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from PIL import Image, ImageDraw, ImageFont
+ from typing import List, Dict, Optional
+
+ # Visualization Configuration
+ VIS_CONFIG = {
+     'frame_interval': 1.0,
+     'max_frames': 20,
+     'save_dir': './output/visualizations',
+     'video_save_dir': './output/videos',
+     'gt_color': '#1f77b4',    # Blue for ground truth
+     'pred_color': '#ff7f0e',  # Orange for predictions
+     'fontsize_label': 10,
+     'fontsize_title': 14,
+     'frame_highlight_both': 'green',
+     'frame_highlight_gt': 'red',
+     'frame_highlight_pred': 'black',
+     'iou_threshold': 0.3,
+     'frame_scale_factor': 0.8,
+     'video_text_scale': 0.5,
+     'video_gt_text_color': (180, 119, 31),    # BGR for OpenCV
+     'video_pred_text_color': (14, 127, 255),  # BGR for OpenCV
+     'video_text_thickness': 1,
+     'video_font_path': './data/Poppins ExtraBold Italic 800.ttf',
+     'video_font_fallback': '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf',
+     'video_pred_text_y': 0.45,
+     'video_gt_text_y': 0.55,
+     'video_footer_height': 150,
+     'video_gt_bar_y': 0.5,
+     'video_pred_bar_y': 0.8,
+     'video_bar_height': 0.15,
+     'video_bar_text_scale': 0.7,
+     'min_segment_duration': 1.0,
+     'video_frame_text_y': 0.05,
+     'video_bar_label_x': 10,
+     'video_bar_label_scale': 0.5,
+     'scroll_window_duration': 20.0,
+     'scroll_speed': 0.2,
+ }
+
+ # Determine device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
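+ # Writes a copy of the source video with a white footer strip: the currently
+ # active GT/Pred action labels are drawn on each frame, and color-coded
+ # GT/Pred timeline bars fill the footer as playback advances. Returns the
+ # H.264 MP4 path if the ffmpeg re-encode succeeds, else the XVID .avi path.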
+ def annotate_video_with_actions(
+     video_id: str,
+     pred_segments: List[Dict],
+     gt_segments: List[Dict],
+     video_path: str,
+     save_dir: str = VIS_CONFIG['video_save_dir'],
+     text_scale: float = VIS_CONFIG['video_text_scale'] * 1.2,
+     gt_text_color: tuple = VIS_CONFIG['video_gt_text_color'],
+     pred_text_color: tuple = VIS_CONFIG['video_pred_text_color'],
+     text_thickness: int = VIS_CONFIG['video_text_thickness']
+ ) -> str:
+     os.makedirs(save_dir, exist_ok=True)
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return f"Error: Could not open video {video_path}."
+
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     duration = total_frames / fps
+
+     footer_height = VIS_CONFIG['video_footer_height']
+     output_height = frame_height + footer_height
+     output_path = os.path.join(save_dir, f"annotated_{video_id}.avi")
+     fourcc = cv2.VideoWriter_fourcc(*'XVID')
+     out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, output_height))
+
+     if not out.isOpened():
+         cap.release()
+         return f"Error: Could not initialize video writer for {output_path}."
+
+     min_duration = VIS_CONFIG['min_segment_duration']
+     gt_segments = [seg for seg in gt_segments if seg['duration'] >= min_duration]
+     pred_segments = [seg for seg in pred_segments if seg['duration'] >= min_duration]
+
+     color_palette = [
+         (128, 0, 0), (60, 20, 220), (0, 128, 0), (128, 0, 128), (79, 69, 54),
+         (128, 128, 0), (0, 0, 128), (130, 0, 75), (34, 139, 34), (0, 85, 204),
+         (149, 146, 209), (235, 206, 135), (250, 230, 230), (191, 226, 159),
+         (185, 218, 255), (255, 204, 204), (193, 182, 255), (201, 252, 189),
+         (144, 128, 112), (112, 25, 25), (102, 51, 102), (0, 128, 128), (171, 71, 0)
+     ]
+     action_labels = set(seg['label'] for seg in gt_segments).union(seg['label'] for seg in pred_segments)
+     action_color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(action_labels)}
+
+     gt_color_rgb = (gt_text_color[2], gt_text_color[1], gt_text_color[0])
+     pred_color_rgb = (pred_text_color[2], pred_text_color[1], pred_text_color[0])
+
+     font_path = VIS_CONFIG['video_font_path']
+     font_fallback = VIS_CONFIG['video_font_fallback']
+     font_size = int(20 * text_scale)
+     bar_font_size = int(20 * VIS_CONFIG['video_bar_text_scale'])
+     font = None
+     bar_font = None
+     try:
+         font = ImageFont.truetype(font_path, font_size)
+         bar_font = ImageFont.truetype(font_path, bar_font_size)
+     except IOError:
+         try:
+             font = ImageFont.truetype(font_fallback, font_size)
+             bar_font = ImageFont.truetype(font_fallback, bar_font_size)
+         except IOError:
+             font = None
+             bar_font = None
+
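+     # Timeline bookkeeping: the footer shows one fixed `window_size`-second
+     # window at a time; segment bars are clipped at the current timestamp so
+     # they fill in as the frame index advances through the window.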
+     window_size = 20.0
+     num_windows = int(np.ceil(duration / window_size))
+     text_bar_gap = 48
+     text_x = 10
+
+     frame_idx = 0
+     written_frames = 0
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         extended_frame = np.zeros((output_height, frame_width, 3), dtype=np.uint8)
+         extended_frame[:frame_height, :, :] = frame
+         extended_frame[frame_height:, :, :] = 255
+
+         timestamp = frame_idx / fps
+         window_idx = int(timestamp // window_size)
+         window_start = window_idx * window_size
+         window_end = min(window_start + window_size, duration)
+         window_duration = window_end - window_start
+         window_timestamp = timestamp - window_start
+
+         gt_labels = [seg['label'] for seg in gt_segments if seg['start'] <= timestamp <= seg['end']]
+         gt_text = "GT: " + ", ".join(gt_labels) if gt_labels else ""
+         pred_labels = [seg['label'] for seg in pred_segments if seg['start'] <= timestamp <= seg['end']]
+         pred_text = "Pred: " + ", ".join(pred_labels) if pred_labels else ""
+
+         footer_y = frame_height
+         gt_bar_y = footer_y + int(0.2 * footer_height)
+         pred_bar_y = footer_y + int(0.5 * footer_height)
+         bar_height = int(VIS_CONFIG['video_bar_height'] * footer_height)
+
+         if font:
+             gt_text_bbox = bar_font.getbbox("GT")
+             pred_text_bbox = bar_font.getbbox("Pred")
+             gt_text_width = gt_text_bbox[2] - gt_text_bbox[0]
+             pred_text_width = pred_text_bbox[2] - pred_text_bbox[0]
+         else:
+             gt_text_size, _ = cv2.getTextSize("GT", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
+             pred_text_size, _ = cv2.getTextSize("Pred", cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
+             gt_text_width = gt_text_size[0]
+             pred_text_width = pred_text_size[0]
+         max_text_width = max(gt_text_width, pred_text_width)
+         bar_start_x = text_x + max_text_width + text_bar_gap
+         bar_width = frame_width - bar_start_x
+
+         for seg in gt_segments:
+             if seg['start'] <= window_end and seg['end'] >= window_start:
+                 start_t = max(seg['start'], window_start)
+                 end_t = min(seg['end'], window_start + window_timestamp)
+                 start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
+                 end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
+                 if end_x > start_x:
+                     cv2.rectangle(
+                         extended_frame,
+                         (start_x, gt_bar_y),
+                         (end_x, gt_bar_y + bar_height),
+                         action_color_map[seg['label']],
+                         -1
+                     )
+
+         for seg in pred_segments:
+             if seg['start'] <= window_end and seg['end'] >= window_start:
+                 start_t = max(seg['start'], window_start)
+                 end_t = min(seg['end'], window_start + window_timestamp)
+                 start_x = bar_start_x + int(((start_t - window_start) / window_duration) * bar_width)
+                 end_x = bar_start_x + int(((end_t - window_start) / window_duration) * bar_width)
+                 if end_x > start_x:
+                     cv2.rectangle(
+                         extended_frame,
+                         (start_x, pred_bar_y),
+                         (end_x, pred_bar_y + bar_height),
+                         action_color_map[seg['label']],
+                         -1
+                     )
+
+         if font:
+             frame_rgb = cv2.cvtColor(extended_frame, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(frame_rgb)
+             draw = ImageDraw.Draw(pil_image)
+
+             frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
+             frame_text_bbox = draw.textbbox((0, 0), frame_info, font=font)
+             frame_text_width = frame_text_bbox[2] - frame_text_bbox[0]
+             frame_text_x = (frame_width - frame_text_width) // 2
+             draw.text((frame_text_x, 10), frame_info, font=font, fill=(0, 0, 0))
+
+             window_info = f"{window_start:.1f}s - {window_end:.1f}s"
+             window_text_bbox = draw.textbbox((0, 0), window_info, font=bar_font)
+             window_text_width = window_text_bbox[2] - window_text_bbox[0]
+             window_text_x = (frame_width - window_text_width) // 2
+             draw.text((window_text_x, footer_y + 10), window_info, font=bar_font, fill=(0, 0, 0))
+
+             if gt_text:
+                 gt_y = int(frame_height * VIS_CONFIG['video_gt_text_y'])
+                 draw.text((10, gt_y), gt_text, font=font, fill=gt_color_rgb)
+
+             if pred_text:
+                 pred_y = int(frame_height * VIS_CONFIG['video_pred_text_y'])
+                 draw.text((10, pred_y), pred_text, font=font, fill=pred_color_rgb)
+
+             draw.text((text_x, gt_bar_y + bar_height // 2), "GT", font=bar_font, fill=gt_color_rgb)
+             draw.text((text_x, pred_bar_y + bar_height // 2), "Pred", font=bar_font, fill=pred_color_rgb)
+
+             extended_frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+         else:
+             frame_info = f"Frame: {frame_idx} | FPS: {fps:.2f}"
+             text_size, _ = cv2.getTextSize(frame_info, cv2.FONT_HERSHEY_DUPLEX, text_scale, text_thickness)
+             frame_text_x = (frame_width - text_size[0]) // 2
+             cv2.putText(
+                 extended_frame,
+                 frame_info,
+                 (frame_text_x, 30),
+                 cv2.FONT_HERSHEY_DUPLEX,
+                 text_scale,
+                 (0, 0, 0),
+                 text_thickness,
+                 cv2.LINE_AA
+             )
+             window_info = f"{window_start:.1f}s - {window_end:.1f}s"
+             window_text_size, _ = cv2.getTextSize(window_info, cv2.FONT_HERSHEY_DUPLEX, VIS_CONFIG['video_bar_text_scale'], 1)
+             window_text_x = (frame_width - window_text_size[0]) // 2
+             cv2.putText(
+                 extended_frame,
+                 window_info,
+                 (window_text_x, footer_y + 20),
+                 cv2.FONT_HERSHEY_DUPLEX,
+                 VIS_CONFIG['video_bar_text_scale'],
+                 (0, 0, 0),
+                 1,
+                 cv2.LINE_AA
+             )
+             if gt_text:
+                 cv2.putText(
+                     extended_frame,
+                     gt_text,
+                     (10, int(frame_height * VIS_CONFIG['video_gt_text_y'])),
+                     cv2.FONT_HERSHEY_DUPLEX,
+                     text_scale,
+                     gt_text_color,
+                     text_thickness,
+                     cv2.LINE_AA
+                 )
+             if pred_text:
+                 cv2.putText(
+                     extended_frame,
+                     pred_text,
+                     (10, int(frame_height * VIS_CONFIG['video_pred_text_y'])),
+                     cv2.FONT_HERSHEY_DUPLEX,
+                     text_scale,
+                     pred_text_color,
+                     text_thickness,
+                     cv2.LINE_AA
+                 )
+             cv2.putText(
+                 extended_frame,
+                 "GT",
+                 (text_x, gt_bar_y + bar_height // 2 + 5),
+                 cv2.FONT_HERSHEY_DUPLEX,
+                 VIS_CONFIG['video_bar_text_scale'],
+                 gt_text_color,
+                 1,
+                 cv2.LINE_AA
+             )
+             cv2.putText(
+                 extended_frame,
+                 "Pred",
+                 (text_x, pred_bar_y + bar_height // 2 + 5),
+                 cv2.FONT_HERSHEY_DUPLEX,
+                 VIS_CONFIG['video_bar_text_scale'],
+                 pred_text_color,
+                 1,
+                 cv2.LINE_AA
+             )
+
+         out.write(extended_frame)
+         written_frames += 1
+         frame_idx += 1
+
+     cap.release()
+     out.release()
+     mp4_path = os.path.splitext(output_path)[0] + '.mp4'
+     # Re-encode to H.264 MP4 for browser playback; paths are quoted and -y
+     # precedes the output so the overwrite flag is not treated as trailing.
+     os.system(f'ffmpeg -i "{output_path}" -vcodec libx264 -acodec aac -y "{mp4_path}"')
+     return mp4_path if os.path.exists(mp4_path) else output_path
+
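+ # Builds a static summary figure: a strip of up to `max_frames` sampled frames
+ # whose borders are color-coded by whether GT and/or predicted segments are
+ # active at that instant, above two timeline rows showing the labeled GT and
+ # prediction segments. Saves the figure as a PNG and returns its path.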
+ def visualize_action_lengths(
+     video_id: str,
+     pred_segments: List[Dict],
+     gt_segments: List[Dict],
+     video_path: str,
+     duration: float,
+     save_dir: str = VIS_CONFIG['save_dir'],
+     frame_interval: float = VIS_CONFIG['frame_interval']
+ ) -> str:
+     os.makedirs(save_dir, exist_ok=True)
+     num_frames = int(duration / frame_interval) + 1
+     if num_frames > VIS_CONFIG['max_frames']:
+         frame_interval = duration / (VIS_CONFIG['max_frames'] - 1)
+         num_frames = VIS_CONFIG['max_frames']
+
+     frame_times = np.linspace(0, duration, num_frames, endpoint=False)
+     frames = []
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         frames = [np.ones((100, 100, 3), dtype=np.uint8) * 255 for _ in frame_times]
+     else:
+         for t in frame_times:
+             cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
+             ret, frame = cap.read()
+             if ret:
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 frame = cv2.resize(frame, (int(frame.shape[1] * 0.5), int(frame.shape[0] * 0.5)))
+                 frames.append(frame)
+             else:
+                 frames.append(np.ones((100, 100, 3), dtype=np.uint8) * 255)
+     cap.release()
+
+     fig = plt.figure(figsize=(num_frames * VIS_CONFIG['frame_scale_factor'], 6), constrained_layout=True)
+     gs = fig.add_gridspec(3, num_frames, height_ratios=[3, 1, 1])
+
+     for i, (t, frame) in enumerate(zip(frame_times, frames)):
+         ax = fig.add_subplot(gs[0, i])
+         gt_hit = any(seg['start'] <= t <= seg['end'] for seg in gt_segments)
+         pred_hit = any(seg['start'] <= t <= seg['end'] for seg in pred_segments)
+         border_color = None
+         if gt_hit and pred_hit:
+             border_color = VIS_CONFIG['frame_highlight_both']
+         elif gt_hit:
+             border_color = VIS_CONFIG['frame_highlight_gt']
+         elif pred_hit:
+             border_color = VIS_CONFIG['frame_highlight_pred']
+
+         ax.imshow(frame)
+         # Hide ticks but keep the spines, so the border highlight stays
+         # visible (ax.axis('off') would hide the spines as well).
+         ax.set_xticks([])
+         ax.set_yticks([])
+         if border_color:
+             for spine in ax.spines.values():
+                 spine.set_edgecolor(border_color)
+                 spine.set_linewidth(2)
+         ax.set_title(f"{t:.1f}s", fontsize=VIS_CONFIG['fontsize_label'], color=border_color or 'black')
+
+     ax_gt = fig.add_subplot(gs[1, :])
+     ax_gt.set_xlim(0, duration)
+     ax_gt.set_ylim(0, 1)
+     ax_gt.axis('off')
+     ax_gt.text(-0.02 * duration, 0.5, "Ground Truth", fontsize=VIS_CONFIG['fontsize_title'], va='center', ha='right', weight='bold')
+
+     for seg in gt_segments:
+         start, end = seg['start'], seg['end']
+         width = end - start
+         label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
+         ax_gt.add_patch(patches.Rectangle(
+             (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['gt_color'], edgecolor='black', alpha=0.8
+         ))
+         ax_gt.text((start + end) / 2, 0.5, label, ha='center', va='center', fontsize=VIS_CONFIG['fontsize_label'], color='white')
+         ax_gt.text(start, 0.2, f"{start:.1f}", ha='center', fontsize=8, color='black')
+         ax_gt.text(end, 0.2, f"{end:.1f}", ha='center', fontsize=8, color='black')
+
+     ax_pred = fig.add_subplot(gs[2, :])
+     ax_pred.set_xlim(0, duration)
+     ax_pred.set_ylim(0, 1)
+     ax_pred.axis('off')
+     ax_pred.text(-0.02 * duration, 0.5, "Prediction", fontsize=VIS_CONFIG['fontsize_title'], va='center', ha='right', weight='bold')
+
+     for seg in pred_segments:
+         start, end = seg['start'], seg['end']
+         width = end - start
+         label = seg['label'][:10] + '...' if len(seg['label']) > 10 else seg['label']
+         ax_pred.add_patch(patches.Rectangle(
+             (start, 0.3), width, 0.4, facecolor=VIS_CONFIG['pred_color'], edgecolor='black', alpha=0.8
+         ))
+         ax_pred.text((start + end) / 2, 0.5, label, ha='center', va='center', fontsize=VIS_CONFIG['fontsize_label'], color='white')
+         ax_pred.text(start, 0.8, f"{start:.1f}", ha='center', fontsize=8, color='black')
+         ax_pred.text(end, 0.8, f"{end:.1f}", ha='center', fontsize=8, color='black')
+
+     png_path = os.path.join(save_dir, f"viz_{video_id}.png")
+     plt.savefig(png_path, dpi=100, bbox_inches='tight')
+     plt.close()
+     return png_path
+
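+ # Runs the model over every frame-level input of the dataset and regroups the
+ # per-frame classification scores (after softmax) and regression outputs,
+ # along with their labels, into per-video arrays keyed by video name.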
+ def eval_frame(opt, model, dataset):
+     test_loader = torch.utils.data.DataLoader(dataset, batch_size=opt['batch_size'], shuffle=False, num_workers=0, pin_memory=False)
+     labels_cls = {video_name: [] for video_name in dataset.video_list}
+     labels_reg = {video_name: [] for video_name in dataset.video_list}
+     output_cls = {video_name: [] for video_name in dataset.video_list}
+     output_reg = {video_name: [] for video_name in dataset.video_list}
+
+     total_frames = 0
+     for n_iter, (input_data, cls_label, reg_label, _) in enumerate(tqdm(test_loader)):
+         input_data = input_data.to(device)
+         cls_label = cls_label.to(device)
+         reg_label = reg_label.to(device)
+         act_cls, act_reg, _ = model(input_data.float())
+
+         act_cls = torch.softmax(act_cls, dim=-1)
+         total_frames += input_data.size(0)
+
+         for b in range(input_data.size(0)):
+             video_name, _, _, _ = dataset.inputs[n_iter * opt['batch_size'] + b]
+             output_cls[video_name].append(act_cls[b, :].detach().cpu().numpy())
+             output_reg[video_name].append(act_reg[b, :].detach().cpu().numpy())
+             labels_cls[video_name].append(cls_label[b, :].cpu().numpy())
+             labels_reg[video_name].append(reg_label[b, :].cpu().numpy())
+
+     for video_name in dataset.video_list:
+         labels_cls[video_name] = np.stack(labels_cls[video_name], axis=0)
+         labels_reg[video_name] = np.stack(labels_reg[video_name], axis=0)
+         output_cls[video_name] = np.stack(output_cls[video_name], axis=0)
+         output_reg[video_name] = np.stack(output_reg[video_name], axis=0)
+
+     return output_cls, output_reg, labels_cls, labels_reg
+
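+ # Decodes anchor-based proposals frame by frame, then applies NMS per video.
+ # For each anchor size A at frame idx, the regression head provides an end
+ # offset and a log-length: end = idx + A * reg[0], length = A * exp(reg[1]),
+ # start = end - length. Every non-background class whose score clears
+ # opt['threshold'] yields one proposal; frame_to_time / 100.0 converts frame
+ # indices to seconds (video_time / num_frames).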
+ def eval_map_nms(opt, dataset, output_cls, output_reg):
+     result_dict = {}
+     proposal_dict = []
+     anchors = opt['anchors']
+
+     for video_name in dataset.video_list:
+         duration = dataset.video_len[video_name]
+         video_time = float(dataset.video_dict[video_name]["duration"])
+         frame_to_time = 100.0 * video_time / duration
+
+         for idx in range(duration):
+             cls_anc = output_cls[video_name][idx]
+             reg_anc = output_reg[video_name][idx]
+             proposal_anc_dict = []
+
+             for anc_idx in range(len(anchors)):
+                 cls = np.argwhere(cls_anc[anc_idx][:-1] > opt['threshold']).reshape(-1)
+                 if len(cls) == 0:
+                     continue
+                 ed = idx + anchors[anc_idx] * reg_anc[anc_idx][0]
+                 length = anchors[anc_idx] * np.exp(reg_anc[anc_idx][1])
+                 st = ed - length
+                 for cidx in range(len(cls)):
+                     label = cls[cidx]
+                     tmp_dict = {
+                         "segment": [float(st * frame_to_time / 100.0), float(ed * frame_to_time / 100.0)],
+                         "score": float(cls_anc[anc_idx][label]),
+                         "label": dataset.label_name[label],
+                         "gentime": float(idx * frame_to_time / 100.0)
+                     }
+                     proposal_anc_dict.append(tmp_dict)
+
+             proposal_dict += proposal_anc_dict
+
+         proposal_dict = non_max_suppression(proposal_dict, overlapThresh=opt['soft_nms'])
+         result_dict[video_name] = proposal_dict
+         proposal_dict = []
+
+     return result_dict
+
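+ # Gradio handler: stages the uploaded I3D features where the dataset expects
+ # them, loads the checkpoint, runs per-frame inference plus NMS, and returns
+ # (comparison text, summary-figure path, annotated-video path).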
+ def process_input(video_file, npz_file, checkpoint_path, split_number):
+     # Parse options
+     opt = opts.parse_opt()
+     opt = vars(opt)
+     opt['mode'] = 'test'
+     opt['split'] = str(split_number)
+     opt['checkpoint_path'] = './checkpoints'
+     opt['video_feature_all_test'] = './data/I3D/'
+     opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
+     opt['batch_size'] = 1  # Single video processing
+     os.makedirs(opt['checkpoint_path'], exist_ok=True)
+     os.makedirs(opt['video_feature_all_test'], exist_ok=True)
+
+     # Handle input: real-time I3D feature extraction is not implemented, so
+     # precomputed .npz features are required even when a video is uploaded.
+     video_name = "user_upload"
+     if not npz_file:
+         # Return a 3-tuple so the error message matches the Gradio outputs.
+         return "Error: Real-time I3D feature extraction not supported. Please upload a .npz file.", None, None
+
+     npz_path = os.path.join(opt['video_feature_all_test'], f"{video_name}.npz")
+     os.makedirs(os.path.dirname(npz_path), exist_ok=True)
+     features = np.load(npz_file)
+     np.savez(npz_path, rgb=features['rgb'], flow=features['flow'])
+
+     # Load model
+     model = MYNET(opt).to(device)
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+     model.load_state_dict(checkpoint['state_dict'])
+     model.eval()
+
+     # Create dataset
+     dataset = VideoDataSet(opt, subset='test', video_name=video_name)
+
+     # Run inference
+     output_cls, output_reg, labels_cls, labels_reg = eval_frame(opt, model, dataset)
+     result_dict = eval_map_nms(opt, dataset, output_cls, output_reg)
+
+     # Load annotations if available
+     gt_segments = []
+     duration = 0
+     video_anno_file = opt["video_anno"].format(opt["split"])
+     if os.path.exists(video_anno_file):
+         with open(video_anno_file, 'r') as f:
+             anno_data = json.load(f)
+         if video_name in anno_data['database']:
+             gt_annotations = anno_data['database'][video_name]['annotations']
+             duration = anno_data['database'][video_name]['duration']
+             for anno in gt_annotations:
+                 start, end = anno['segment']
+                 gt_segments.append({'label': anno['label'], 'start': start, 'end': end, 'duration': end - start})
+
+     pred_segments = []
+     for pred in result_dict.get(video_name, []):
+         start, end = pred['segment']
+         pred_segments.append({
+             'label': pred['label'],
+             'start': start,
+             'end': end,
+             'duration': end - start,
+             'score': pred['score']
+         })
+
+     # Generate comparison table
+     output_text = f"Predicted Actions for Video: {video_name}\n\n"
+     if gt_segments:
+         matches = []
+         iou_threshold = VIS_CONFIG['iou_threshold']
+         used_gt_indices = set()
+         for pred in pred_segments:
+             best_iou = 0
+             best_gt_idx = None
+             for gt_idx, gt in enumerate(gt_segments):
+                 if gt_idx in used_gt_indices:
+                     continue
+                 iou = calc_iou([pred['end'], pred['duration']], [gt['end'], gt['duration']])
+                 if iou > best_iou and iou >= iou_threshold:
+                     best_iou = iou
+                     best_gt_idx = gt_idx
+             if best_gt_idx is not None:
+                 matches.append({'pred': pred, 'gt': gt_segments[best_gt_idx], 'iou': best_iou})
+                 used_gt_indices.add(best_gt_idx)
+             else:
+                 matches.append({'pred': pred, 'gt': None, 'iou': 0})
+
+         for gt_idx, gt in enumerate(gt_segments):
+             if gt_idx not in used_gt_indices:
+                 matches.append({'pred': None, 'gt': gt, 'iou': 0})
+
+         output_text += "{:<20} {:<30} {:<30} {:<15} {:<10}\n".format(
+             "Action Label", "Predicted Segment (s)", "Ground Truth Segment (s)", "Duration Diff (s)", "IoU")
+         output_text += "-" * 105 + "\n"
+         for match in matches:
+             pred = match['pred']
+             gt = match['gt']
+             iou = match['iou']
+             if pred and gt:
+                 label = pred['label'] if pred['label'] == gt['label'] else f"{pred['label']} (GT: {gt['label']})"
+                 pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
+                 gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
+                 duration_diff = pred['duration'] - gt['duration']
+                 output_text += "{:<20} {:<30} {:<30} {:<15.2f} {:<10.2f}\n".format(
+                     label, pred_str, gt_str, duration_diff, iou)
+             elif pred:
+                 pred_str = f"[{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s)"
+                 output_text += "{:<20} {:<30} {:<30} {:<15} {:<10.2f}\n".format(
+                     pred['label'], pred_str, "None", "N/A", iou)
+             elif gt:
+                 gt_str = f"[{gt['start']:.2f}, {gt['end']:.2f}] ({gt['duration']:.2f}s)"
+                 output_text += "{:<20} {:<30} {:<30} {:<15} {:<10.2f}\n".format(
+                     gt['label'], "None", gt_str, "N/A", iou)
+
+         matched_count = sum(1 for m in matches if m['pred'] and m['gt'])
+         avg_duration_diff = np.mean([m['pred']['duration'] - m['gt']['duration'] for m in matches if m['pred'] and m['gt']]) if matched_count > 0 else 0
+         avg_iou = np.mean([m['iou'] for m in matches if m['iou'] > 0]) if any(m['iou'] > 0 for m in matches) else 0
+         output_text += "\nSummary:\n"
+         output_text += f"- Total Predictions: {len(pred_segments)}\n"
+         output_text += f"- Total Ground Truth: {len(gt_segments)}\n"
+         output_text += f"- Matched Segments: {matched_count}\n"
+         output_text += f"- Average Duration Difference (Matched): {avg_duration_diff:.2f}s\n"
+         output_text += f"- Average IoU (Matched): {avg_iou:.2f}\n"
+     else:
+         output_text += "No ground truth annotations available.\nPredicted Segments:\n"
+         for pred in pred_segments:
+             output_text += f"- {pred['label']}: [{pred['start']:.2f}, {pred['end']:.2f}] ({pred['duration']:.2f}s), Score: {pred['score']:.2f}\n"
+
+     # Generate visualizations (only when the original video was uploaded)
+     viz_path = None
+     video_out_path = None
+     if video_file and os.path.exists(video_file):
+         duration = max([seg['end'] for seg in pred_segments + gt_segments], default=1.0)
+         viz_path = visualize_action_lengths(video_name, pred_segments, gt_segments, video_file, duration)
+         video_out_path = annotate_video_with_actions(video_name, pred_segments, gt_segments, video_file)
+
+     return output_text, viz_path, video_out_path
+
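+ # Note on expected inputs (as read in process_input above): the uploaded .npz
+ # must contain 'rgb' and 'flow' arrays, and the checkpoint a 'state_dict' entry.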
+ # Gradio Interface
+ iface = gr.Interface(
+     fn=process_input,
+     inputs=[
+         gr.Video(label="Upload Video (Optional, requires .npz for processing)"),
+         gr.File(label="Upload I3D .npz File"),
+         gr.File(label="Upload Model Checkpoint (.pth.tar)", file_types=[".pth.tar"]),
+         gr.Dropdown(label="Split Number", choices=["1", "2", "3"], value="1")
+     ],
+     outputs=[
+         gr.Textbox(label="Action Predictions"),
+         gr.Image(label="Action Visualization", type="filepath"),
+         gr.Video(label="Annotated Video")
+     ],
+     title="Temporal Action Localization",
+     description="Upload an I3D .npz file and a trained model checkpoint to predict actions. Optionally upload a video to generate visualizations. Select the annotation split number."
+ )
+
+ if __name__ == '__main__':
+     opt = opts.parse_opt()
+     opt = vars(opt)
+     opt['anchors'] = [int(item) for item in opt['anchors'].split(',')]
+     iface.launch()
short main.py DELETED
The diff for this file is too large to render. See raw diff