Hamzah commited on
Commit
139a373
·
1 Parent(s): f297659

first commit

Browse files
Files changed (3) hide show
  1. main.py +600 -0
  2. outliers_removal_algorithm.py +206 -0
  3. reorder_frames_algorithm.py +380 -0
main.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Main script for video processing: outlier detection and/or frame reordering.
4
+
5
+ Place your videos in the './inference' folder and run this script to process them.
6
+ Processed videos will be saved with '_fixed' suffix.
7
+
8
+ This script can perform:
9
+ 1. Outlier detection only (--task outliers)
10
+ 2. Frame reordering only (--task reorder)
11
+ 3. Both operations (--task both): outlier detection first, then reordering
12
+
13
+ Uses DBSCAN for outlier detection.
14
+
15
+ Usage:
16
+ # Process all videos in ./inference folder
17
+ python main.py --input-dir ./inference --task both
18
+
19
+ # Process a single video from inference folder
20
+ python main.py --video ./inference/my_video.avi --task both
21
+
22
+ # Custom output directory (save to outlier_artifacts)
23
+ python main.py --input-dir ./inference --task outliers --output-dir ./outlier_artifacts/cleaned_videos
24
+
25
+ # Custom DBSCAN parameters
26
+ python main.py --input-dir ./inference --task both --eps 0.5 --min-samples 40
27
+
28
+ # Process videos from UCF101_videos with custom model (DINOv2)
29
+ python main.py --input-dir ./UCF101_videos --task outliers --model-type dinov2
30
+
31
+ # Process videos with ResNet18 model
32
+ python main.py --input-dir ./inference --task outliers --model-type resnet18
33
+
34
+ Output:
35
+ - Default: Videos saved in same directory as input with '_fixed' suffix
36
+ - With --output-dir: Videos saved in specified directory with '_fixed' suffix
37
+ - Outlier detection: video_fixed.avi (outliers removed)
38
+ - Frame reordering: video_fixed.avi (frames reordered)
39
+ - Both: video_fixed.avi (outliers removed AND frames reordered, no intermediate files)
40
+ """
41
+
42
+ import os
43
+ import argparse
44
+ import glob
45
+ from pathlib import Path
46
+ import cv2
47
+ import numpy as np
48
+ import torch
49
+ from PIL import Image
50
+ from tqdm import tqdm
51
+
52
+ from outliers_removal_algorithm import dbscan_outliers, USE_GPU
53
+ from reorder_frames_algorithm import load_video_gray, compute_blurred_mse_matrix, build_best_path
54
+
55
+ # Device configuration
56
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
+
58
+ # Supported video extensions
59
+ VIDEO_EXTS = ('.avi', '.mp4', '.mov', '.mkv')
60
+
61
+ # ==========================================
62
+ # EMBEDDING EXTRACTION (Outlier Detection)
63
+ # ==========================================
64
+
65
+ def load_embedding_model(model_type='clip', model_path=None, device='cuda'):
66
+ """Load CLIP, DINOv2, or ResNet18 model for embedding extraction."""
67
+ print(f"Loading {model_type.upper()} model...")
68
+
69
+ if model_type == 'clip':
70
+ import clip
71
+ model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
72
+ model.eval()
73
+ torch.set_grad_enabled(False)
74
+ embedding_dim = 512
75
+
76
+ def extract_fn(image_batch):
77
+ with torch.no_grad():
78
+ feats = model.encode_image(image_batch)
79
+ feats = torch.nn.functional.normalize(feats, dim=-1)
80
+ return feats
81
+
82
+ print(f"CLIP model loaded: ViT-B/32 ({embedding_dim}-dim)")
83
+ return extract_fn, preprocess, embedding_dim
84
+
85
+ elif model_type == 'dinov2':
86
+ from transformers import pipeline
87
+ from torchvision import transforms
88
+
89
+ if model_path is None:
90
+ model_path = "facebook/dinov2-base"
91
+
92
+ feature_extractor = pipeline(
93
+ model=model_path,
94
+ task="image-feature-extraction",
95
+ device=0 if (device == 'cuda' and torch.cuda.is_available()) else -1
96
+ )
97
+
98
+ test_img = Image.new('RGB', (224, 224))
99
+ test_emb = feature_extractor(test_img)
100
+ embedding_dim = len(test_emb[0])
101
+
102
+ preprocess = transforms.Compose([
103
+ transforms.Resize(256),
104
+ transforms.CenterCrop(224),
105
+ transforms.ToTensor(),
106
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
107
+ ])
108
+
109
+ def extract_fn(image_batch):
110
+ images = []
111
+ for i in range(image_batch.shape[0]):
112
+ img_tensor = image_batch[i]
113
+ img_np = img_tensor.cpu().permute(1, 2, 0).numpy()
114
+ img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
115
+ img_np = (img_np * 255).clip(0, 255).astype(np.uint8)
116
+ images.append(Image.fromarray(img_np))
117
+
118
+ features = feature_extractor(images)
119
+ feats = torch.tensor(features, device=device).squeeze(1)
120
+ feats = torch.nn.functional.normalize(feats, dim=-1)
121
+ return feats
122
+
123
+ print(f"DINOv2 model loaded: {model_path} ({embedding_dim}-dim)")
124
+ return extract_fn, preprocess, embedding_dim
125
+
126
+ elif model_type == 'resnet18':
127
+ from torchvision import models, transforms
128
+
129
+ # Load ResNet18 pretrained model
130
+ model = models.resnet18(pretrained=True)
131
+ # Remove the final classification layer to get embeddings
132
+ model = torch.nn.Sequential(*list(model.children())[:-1])
133
+ model = model.to(device)
134
+ model.eval()
135
+ torch.set_grad_enabled(False)
136
+
137
+ embedding_dim = 512 # ResNet18 final layer dimension
138
+
139
+ preprocess = transforms.Compose([
140
+ transforms.Resize(256),
141
+ transforms.CenterCrop(224),
142
+ transforms.ToTensor(),
143
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
144
+ ])
145
+
146
+ def extract_fn(image_batch):
147
+ with torch.no_grad():
148
+ feats = model(image_batch)
149
+ feats = feats.squeeze(-1).squeeze(-1) # Remove spatial dimensions
150
+ feats = torch.nn.functional.normalize(feats, dim=-1)
151
+ return feats
152
+
153
+ print(f"ResNet18 model loaded ({embedding_dim}-dim)")
154
+ return extract_fn, preprocess, embedding_dim
155
+
156
+ else:
157
+ raise ValueError(f"Unknown model type: {model_type}")
158
+
159
+
160
+ def extract_video_embeddings(video_path, extract_fn, preprocess, device='cuda', batch_size=128):
161
+ """Extract embeddings for all frames in a video."""
162
+ cap = cv2.VideoCapture(str(video_path))
163
+ if not cap.isOpened():
164
+ raise ValueError(f"Cannot open video: {video_path}")
165
+
166
+ fps = cap.get(cv2.CAP_PROP_FPS)
167
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
168
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
169
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
170
+
171
+ print(f"Video: {Path(video_path).name}")
172
+ print(f"Properties: {width}x{height}, {fps:.2f} fps, {total_frames} frames")
173
+ print(f"Extracting embeddings with batch_size={batch_size}...")
174
+
175
+ frame_batch = []
176
+ all_embeddings = []
177
+
178
+ with tqdm(total=total_frames, desc="Extracting", unit="frame") as pbar:
179
+ while True:
180
+ ret, frame = cap.read()
181
+ if not ret:
182
+ break
183
+
184
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
185
+ pil_image = Image.fromarray(frame_rgb)
186
+ frame_tensor = preprocess(pil_image)
187
+ frame_batch.append(frame_tensor)
188
+
189
+ if len(frame_batch) >= batch_size:
190
+ batch = torch.stack(frame_batch, dim=0)
191
+ if device == 'cuda':
192
+ batch = batch.pin_memory().to(device, non_blocking=True)
193
+ else:
194
+ batch = batch.to(device)
195
+
196
+ feats = extract_fn(batch)
197
+ all_embeddings.append(feats.cpu())
198
+ frame_batch.clear()
199
+ pbar.update(batch_size)
200
+
201
+ if frame_batch:
202
+ batch = torch.stack(frame_batch, dim=0)
203
+ if device == 'cuda':
204
+ batch = batch.pin_memory().to(device, non_blocking=True)
205
+ else:
206
+ batch = batch.to(device)
207
+
208
+ feats = extract_fn(batch)
209
+ all_embeddings.append(feats.cpu())
210
+ pbar.update(len(frame_batch))
211
+
212
+ cap.release()
213
+
214
+ embeddings = torch.cat(all_embeddings, dim=0)
215
+ print(f"Extracted {len(embeddings)} embeddings")
216
+
217
+ return embeddings, fps, width, height
218
+
219
+
220
+ # ==========================================
221
+ # VIDEO SAVING
222
+ # ==========================================
223
+
224
+ def save_cleaned_video(video_path, predictions, output_path, fps, width, height):
225
+ """Create cleaned video with outliers removed."""
226
+ num_outliers = predictions.sum()
227
+ num_inliers = len(predictions) - num_outliers
228
+
229
+ print(f"\nOutlier Detection Results:")
230
+ print(f" Total frames: {len(predictions)}")
231
+ print(f" Inliers: {num_inliers} ({100*num_inliers/len(predictions):.1f}%)")
232
+ print(f" Outliers: {num_outliers} ({100*num_outliers/len(predictions):.1f}%)")
233
+
234
+ cap = cv2.VideoCapture(str(video_path))
235
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
236
+ out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
237
+
238
+ frame_id = 0
239
+ kept = 0
240
+
241
+ print(f"\nGenerating cleaned video: {Path(output_path).name}")
242
+ with tqdm(total=len(predictions), desc="Writing", unit="frame") as pbar:
243
+ while True:
244
+ ret, frame = cap.read()
245
+ if not ret:
246
+ break
247
+
248
+ if frame_id < len(predictions) and not predictions[frame_id]:
249
+ out.write(frame)
250
+ kept += 1
251
+
252
+ frame_id += 1
253
+ pbar.update(1)
254
+
255
+ cap.release()
256
+ out.release()
257
+
258
+ print(f"Cleaned video saved: {output_path}")
259
+ return output_path
260
+
261
+
262
+ def save_reordered_video(video_path, frame_order, output_path):
263
+ """Create reordered video using predicted frame order."""
264
+ # Load all frames
265
+ cap = cv2.VideoCapture(str(video_path))
266
+ frames = []
267
+ while True:
268
+ ret, frame = cap.read()
269
+ if not ret:
270
+ break
271
+ frames.append(frame)
272
+
273
+ fps = cap.get(cv2.CAP_PROP_FPS)
274
+ height, width = frames[0].shape[:2]
275
+ cap.release()
276
+
277
+ print(f"\nFrame Reordering Results:")
278
+ print(f" Total frames: {len(frames)}")
279
+ print(f" Reconstructed order: {len(frame_order)} frames")
280
+
281
+ # Write reordered video
282
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
283
+ out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
284
+
285
+ print(f"\nGenerating reordered video: {Path(output_path).name}")
286
+ for idx in tqdm(frame_order, desc="Writing", unit="frame"):
287
+ if 0 <= idx < len(frames):
288
+ out.write(frames[idx])
289
+
290
+ out.release()
291
+
292
+ print(f"Reordered video saved: {output_path}")
293
+ return output_path
294
+
295
+
296
+ def save_cleaned_and_reordered_video(video_path, outlier_predictions, frame_order, output_path):
297
+ """Create video with outliers removed and frames reordered in one pass."""
298
+ # Load all frames
299
+ cap = cv2.VideoCapture(str(video_path))
300
+ all_frames = []
301
+ while True:
302
+ ret, frame = cap.read()
303
+ if not ret:
304
+ break
305
+ all_frames.append(frame)
306
+
307
+ fps = cap.get(cv2.CAP_PROP_FPS)
308
+ height, width = all_frames[0].shape[:2]
309
+ cap.release()
310
+
311
+ # Filter out outliers
312
+ inlier_frames = [all_frames[i] for i in range(len(all_frames))
313
+ if i < len(outlier_predictions) and not outlier_predictions[i]]
314
+
315
+ num_outliers = outlier_predictions.sum()
316
+ print(f"\nCombined Processing Results:")
317
+ print(f" Original frames: {len(all_frames)}")
318
+ print(f" Outliers removed: {num_outliers} ({100*num_outliers/len(all_frames):.1f}%)")
319
+ print(f" Inlier frames: {len(inlier_frames)} ({100*len(inlier_frames)/len(all_frames):.1f}%)")
320
+ print(f" Reordered frames: {len(frame_order)}")
321
+
322
+ # Write reordered video with only inlier frames
323
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
324
+ out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
325
+
326
+ print(f"\nGenerating final video: {Path(output_path).name}")
327
+ for idx in tqdm(frame_order, desc="Writing", unit="frame"):
328
+ if 0 <= idx < len(inlier_frames):
329
+ out.write(inlier_frames[idx])
330
+
331
+ out.release()
332
+
333
+ print(f"Final video saved: {output_path}")
334
+ return output_path
335
+
336
+
337
+ # ==========================================
338
+ # MAIN PIPELINE
339
+ # ==========================================
340
+
341
+ def run_outlier_detection(video_path, output_path, args):
342
+ """Run outlier detection pipeline using imported functions."""
343
+ print("OUTLIER DETECTION")
344
+ print(f"GPU Acceleration: {'Enabled (cuML)' if USE_GPU else 'Disabled (CPU/sklearn)'}")
345
+
346
+ # Load embedding model
347
+ extract_fn, preprocess, embedding_dim = load_embedding_model(
348
+ model_type=args.model_type,
349
+ model_path=args.model_path,
350
+ device=DEVICE
351
+ )
352
+
353
+ # Extract embeddings
354
+ embeddings, fps, width, height = extract_video_embeddings(
355
+ video_path, extract_fn, preprocess, DEVICE, args.batch_size
356
+ )
357
+
358
+ # Detect outliers using DBSCAN
359
+ print(f"\nRunning DBSCAN outlier detection...")
360
+ predictions = dbscan_outliers(
361
+ embeddings,
362
+ eps=args.eps,
363
+ min_samples=args.min_samples
364
+ )
365
+
366
+ # Save cleaned video
367
+ cleaned_path = save_cleaned_video(video_path, predictions, output_path, fps, width, height)
368
+ return cleaned_path
369
+
370
+
371
+ def run_frame_reordering(video_path, output_path):
372
+ """Run frame reordering pipeline."""
373
+ print("\n" + "=" * 80)
374
+ print("FRAME REORDERING")
375
+ print("=" * 80)
376
+
377
+ print(f"Loading video: {Path(video_path).name}")
378
+ frames = load_video_gray(str(video_path))
379
+ print(f"Loaded {len(frames)} frames")
380
+
381
+ print("Computing MSE matrix...")
382
+ mse = compute_blurred_mse_matrix(frames)
383
+
384
+ print("Building temporal path...")
385
+ path = build_best_path(mse)
386
+
387
+ # Save reordered video
388
+ reordered_path = save_reordered_video(video_path, path, output_path)
389
+ return reordered_path
390
+
391
+
392
+ def run_both_tasks(video_path, output_path, args):
393
+ """Run both outlier detection and frame reordering without saving intermediate video."""
394
+ print("\n" + "=" * 80)
395
+ print("STEP 1: OUTLIER DETECTION")
396
+ print("=" * 80)
397
+ print(f"GPU Acceleration: {'Enabled (cuML)' if USE_GPU else 'Disabled (CPU/sklearn)'}")
398
+
399
+ # Load embedding model and extract embeddings
400
+ extract_fn, preprocess, embedding_dim = load_embedding_model(
401
+ model_type=args.model_type,
402
+ model_path=args.model_path,
403
+ device=DEVICE
404
+ )
405
+
406
+ embeddings, fps, width, height = extract_video_embeddings(
407
+ video_path, extract_fn, preprocess, DEVICE, args.batch_size
408
+ )
409
+
410
+ # Detect outliers using DBSCAN
411
+ print(f"\nRunning DBSCAN outlier detection...")
412
+ outlier_predictions = dbscan_outliers(
413
+ embeddings,
414
+ eps=args.eps,
415
+ min_samples=args.min_samples
416
+ )
417
+
418
+ num_outliers = outlier_predictions.sum()
419
+ num_inliers = len(outlier_predictions) - num_outliers
420
+ print(f"\nOutlier Detection Results:")
421
+ print(f" Total frames: {len(outlier_predictions)}")
422
+ print(f" Inliers: {num_inliers} ({100*num_inliers/len(outlier_predictions):.1f}%)")
423
+ print(f" Outliers: {num_outliers} ({100*num_outliers/len(outlier_predictions):.1f}%)")
424
+
425
+ # Step 2: Frame reordering on inlier frames
426
+ print("\n" + "=" * 80)
427
+ print("STEP 2: FRAME REORDERING (on inlier frames)")
428
+ print("=" * 80)
429
+
430
+ all_frames = load_video_gray(str(video_path))
431
+
432
+ # Filter to only inlier frames
433
+ inlier_frames = []
434
+ for i in range(len(all_frames)):
435
+ if i < len(outlier_predictions) and not outlier_predictions[i]:
436
+ inlier_frames.append(all_frames[i])
437
+
438
+ inlier_frames = torch.stack(inlier_frames, dim=0)
439
+ mse = compute_blurred_mse_matrix(inlier_frames)
440
+ path = build_best_path(mse)
441
+
442
+ # Save final video (cleaned and reordered)
443
+ final_path = save_cleaned_and_reordered_video(video_path, outlier_predictions, path, output_path)
444
+ return final_path
445
+
446
+
447
+ def get_output_path(input_path, output_dir, suffix="_fixed"):
448
+ """Determine the output path based on input path and output directory."""
449
+ input_path = Path(input_path)
450
+
451
+ if output_dir:
452
+ # Use specified output directory
453
+ output_dir = Path(output_dir)
454
+ output_dir.mkdir(exist_ok=True, parents=True)
455
+ output_name = f"{input_path.stem}{suffix}{input_path.suffix}"
456
+ return output_dir / output_name
457
+ else:
458
+ # Save in same directory as input
459
+ output_name = f"{input_path.stem}{suffix}{input_path.suffix}"
460
+ return input_path.parent / output_name
461
+
462
+
463
+ def process_single_video(video_path, args):
464
+ """Process a single video file."""
465
+ video_path = Path(video_path)
466
+
467
+ if not video_path.exists():
468
+ print(f"Error: Video not found: {video_path}")
469
+ return
470
+
471
+ print("=" * 80)
472
+ print(f"Processing: {video_path.name}")
473
+ print("=" * 80)
474
+ print(f"Task: {args.task.upper()}")
475
+ print("=" * 80)
476
+
477
+ # Determine output path
478
+ output_path = get_output_path(video_path, args.output_dir)
479
+
480
+ # Execute tasks
481
+ if args.task == "outliers":
482
+ run_outlier_detection(str(video_path), str(output_path), args)
483
+
484
+ elif args.task == "reorder":
485
+ run_frame_reordering(str(video_path), str(output_path))
486
+
487
+ elif args.task == "both":
488
+ # Run both tasks without saving intermediate video
489
+ run_both_tasks(str(video_path), str(output_path), args)
490
+
491
+ print("\n" + "=" * 80)
492
+ print("PROCESSING COMPLETE")
493
+ print("=" * 80)
494
+ print(f"Output: {output_path}")
495
+
496
+
497
+ def process_directory(input_dir, args):
498
+ """Process all videos in a directory."""
499
+ input_dir = Path(input_dir)
500
+
501
+ if not input_dir.exists():
502
+ print(f"Error: Directory not found: {input_dir}")
503
+ return
504
+
505
+ # Find all video files
506
+ video_files = []
507
+ for ext in VIDEO_EXTS:
508
+ video_files.extend(input_dir.glob(f"*{ext}"))
509
+
510
+ video_files = sorted(video_files)
511
+
512
+ if not video_files:
513
+ print(f"No video files found in {input_dir}")
514
+ print(f"Supported extensions: {VIDEO_EXTS}")
515
+ return
516
+
517
+ print("=" * 80)
518
+ print(f"Found {len(video_files)} video(s) in {input_dir}")
519
+ print("=" * 80)
520
+
521
+ # Process each video
522
+ for i, video_path in enumerate(video_files, 1):
523
+ print(f"\n[{i}/{len(video_files)}] Processing: {video_path.name}")
524
+
525
+ # Determine output path
526
+ output_path = get_output_path(video_path, args.output_dir)
527
+
528
+ try:
529
+ # Execute tasks
530
+ if args.task == "outliers":
531
+ run_outlier_detection(str(video_path), str(output_path), args)
532
+
533
+ elif args.task == "reorder":
534
+ run_frame_reordering(str(video_path), str(output_path))
535
+
536
+ elif args.task == "both":
537
+ # Run both tasks without saving intermediate video
538
+ run_both_tasks(str(video_path), str(output_path), args)
539
+
540
+ print(f" ✓ Saved: {output_path}")
541
+
542
+ except Exception as e:
543
+ print(f" ✗ Error processing {video_path.name}: {e}")
544
+ continue
545
+
546
+ print("\n" + "=" * 80)
547
+ print("BATCH PROCESSING COMPLETE")
548
+ print("=" * 80)
549
+
550
+
551
+ def main():
552
+ parser = argparse.ArgumentParser(
553
+ description="Main script for video processing: outlier detection (DBSCAN) and/or frame reordering"
554
+ )
555
+
556
+ # Input arguments (mutually exclusive)
557
+ input_group = parser.add_mutually_exclusive_group(required=True)
558
+ input_group.add_argument("--video",
559
+ help="Process a single video file")
560
+ input_group.add_argument("--input-dir",
561
+ help="Process all videos in a directory (default: ./inference)")
562
+
563
+ # Task selection
564
+ parser.add_argument("--task", required=True, choices=["outliers", "reorder", "both"],
565
+ help="Task to perform: outliers, reorder, or both")
566
+
567
+ # Output directory (optional)
568
+ parser.add_argument("--output-dir",
569
+ help="Output directory (default: same as input directory)")
570
+
571
+ # Outlier detection parameters
572
+ parser.add_argument("--model-type", default="clip", choices=["clip", "dinov2", "resnet18"],
573
+ help="Embedding model type for outlier detection")
574
+ parser.add_argument("--model-path", help="Path to DINOv2 model (optional)")
575
+ parser.add_argument("--batch-size", type=int, default=128,
576
+ help="Batch size for embedding extraction")
577
+
578
+ # DBSCAN parameters
579
+ parser.add_argument("--eps", type=float, default=0.5,
580
+ help="DBSCAN: Epsilon parameter")
581
+ parser.add_argument("--min-samples", type=int, default=40,
582
+ help="DBSCAN: Minimum samples parameter")
583
+
584
+ args = parser.parse_args()
585
+
586
+ # Default to ./inference if neither --video nor --input-dir specified
587
+ # (This won't happen due to required=True, but keeping for clarity)
588
+
589
+ if args.task in ["outliers", "both"]:
590
+ print(f"DBSCAN parameters: eps={args.eps}, min_samples={args.min_samples}")
591
+
592
+ # Process based on input mode
593
+ if args.video:
594
+ process_single_video(args.video, args)
595
+ elif args.input_dir:
596
+ process_directory(args.input_dir, args)
597
+
598
+
599
+ if __name__ == "__main__":
600
+ main()
outliers_removal_algorithm.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Outlier removal algorithm for video frame embeddings using DBSCAN.
4
+
5
+ Reads embeddings, detects outliers, and exports predictions to CSV files.
6
+ GPU acceleration is automatically detected and used if available.
7
+
8
+ Usage:
9
+ # Process CLIP embeddings from outlier_artifacts
10
+ python outliers_removal_algorithm.py --embeddings-dir ./outlier_artifacts/embeddings --output-dir ./outlier_artifacts/cleaned_CSVs --model-type clip
11
+
12
+ # Process DINOv2 embeddings
13
+ python outliers_removal_algorithm.py --embeddings-dir ./outlier_artifacts/embeddings --output-dir ./outlier_artifacts/cleaned_CSVs --model-type dinov2
14
+
15
+ # Process ResNet18 embeddings
16
+ python outliers_removal_algorithm.py --embeddings-dir ./outlier_artifacts/embeddings --output-dir ./outlier_artifacts/cleaned_CSVs --model-type resnet18
17
+
18
+ # Custom DBSCAN parameters with CLIP embeddings
19
+ python outliers_removal_algorithm.py --embeddings-dir ./outlier_artifacts/embeddings --output-dir ./outlier_artifacts/cleaned_CSVs --model-type clip --eps 0.45 --min-samples 50
20
+
21
+ # Filter to specific action category
22
+ python outliers_removal_algorithm.py --embeddings-dir ./outlier_artifacts/embeddings --output-dir ./outlier_artifacts/cleaned_CSVs --model-type clip --action-filter Crawling
23
+
24
+ # Limit processing to first 10 videos
25
+ python outliers_removal_algorithm.py --embeddings-dir ./outlier_artifacts/embeddings --output-dir ./outlier_artifacts/cleaned_CSVs --model-type clip --max-videos 10
26
+
27
+ Note: To generate cleaned videos from predictions, use generate_cleaned_videos_from_predictions.py
28
+ """
29
+
30
+ import os
31
+ import glob
32
+ import csv
33
+ import argparse
34
+ import numpy as np
35
+ import torch
36
+ from pathlib import Path
37
+
38
+ try:
39
+ import cupy as cp
40
+ from cuml.cluster import DBSCAN as cuDBSCAN
41
+ CUML_AVAILABLE = True
42
+ except ImportError:
43
+ CUML_AVAILABLE = False
44
+
45
+ from sklearn.cluster import DBSCAN as skDBSCAN
46
+
47
+ # Automatically detect GPU availability
48
+ USE_GPU = CUML_AVAILABLE and torch.cuda.is_available()
49
+
50
+
51
+ def to_numpy(x):
52
+ """Convert tensor or array to numpy float32."""
53
+ if isinstance(x, torch.Tensor):
54
+ x = x.detach().cpu().numpy()
55
+ return np.asarray(x, dtype=np.float32)
56
+
57
+
58
+ def dbscan_outliers(X, eps=0.55, min_samples=10):
59
+ """
60
+ Detect outliers using DBSCAN (noise points).
61
+
62
+ Args:
63
+ X: Feature matrix (N, D)
64
+ eps: DBSCAN epsilon parameter
65
+ min_samples: DBSCAN minimum samples parameter
66
+
67
+ Returns:
68
+ Boolean array of shape (N,) where True = outlier
69
+ """
70
+ X = to_numpy(X)
71
+ if USE_GPU:
72
+ labels = cuDBSCAN(eps=eps, min_samples=min_samples).fit_predict(cp.asarray(X)).get()
73
+ else:
74
+ labels = skDBSCAN(eps=eps, min_samples=min_samples, n_jobs=-1).fit_predict(X)
75
+ return labels == -1
76
+
77
+
78
+ def extract_action_name(filename, model_type):
79
+ """Extract action category from embedding filename based on model type."""
80
+ name = os.path.basename(filename)
81
+ suffix = f'_{model_type}_embeddings'
82
+ name = name.replace(suffix + '.pt', '').replace(suffix + '.pth', '')
83
+ return name
84
+
85
+
86
+ def process_all_embeddings(emb_dir, eps, min_samples, output_dir, model_type='clip',
87
+ max_videos=None, action_filter=None):
88
+ """
89
+ Process all embeddings and export predictions to CSV files.
90
+
91
+ Args:
92
+ emb_dir: Directory containing embedding .pt files
93
+ eps: DBSCAN epsilon parameter
94
+ min_samples: DBSCAN minimum samples parameter
95
+ output_dir: Directory to save CSV predictions
96
+ model_type: Model type to load ('clip', 'dinov2', or 'resnet18')
97
+ max_videos: Limit processing to first N videos
98
+ action_filter: Filter to specific action category
99
+ """
100
+ # Filter files by model type (e.g., *_clip_embeddings.pt, *_dinov2_embeddings.pt, or *_resnet18_embeddings.pt)
101
+ pattern = f"*_{model_type}_embeddings.pt"
102
+ pt_files = sorted(glob.glob(os.path.join(emb_dir, pattern)))
103
+
104
+ if action_filter:
105
+ pt_files = [f for f in pt_files if action_filter.lower() in os.path.basename(f).lower()]
106
+ print(f"Filtering to action: {action_filter}")
107
+ print(f"Found {len(pt_files)} matching file(s)")
108
+
109
+ # Create output directory
110
+ output_path = Path(output_dir)
111
+ output_path.mkdir(exist_ok=True, parents=True)
112
+
113
+ print("=" * 80)
114
+ print("OUTLIER REMOVAL ALGORITHM - DBSCAN")
115
+ print("=" * 80)
116
+ print(f"Model type: {model_type.upper()}")
117
+ print(f"GPU Acceleration: {'Enabled (cuML)' if USE_GPU else 'Disabled (CPU/sklearn)'}")
118
+ print(f"Embeddings dir: {emb_dir}")
119
+ print(f"Output dir: {output_dir}")
120
+ print(f"DBSCAN parameters: eps={eps}, min_samples={min_samples}")
121
+ print(f"Total embedding files: {len(pt_files)}")
122
+ print("=" * 80)
123
+
124
+ total_videos = 0
125
+
126
+ for pt_path in pt_files:
127
+ data = torch.load(pt_path, map_location="cpu")
128
+ action_name = extract_action_name(pt_path, model_type)
129
+ print(f"Processing action: {action_name}")
130
+
131
+ # Create CSV for this action
132
+ csv_path = output_path / f"{action_name}.csv"
133
+
134
+ with open(csv_path, 'w', newline='') as csvfile:
135
+ writer = csv.writer(csvfile)
136
+ writer.writerow(['video_id', 'predicted_outliers_list'])
137
+
138
+ for video_name, video_data in data.items():
139
+ if max_videos and total_videos >= max_videos:
140
+ break
141
+
142
+ total_videos += 1
143
+ embeddings = video_data["embeddings"]
144
+
145
+ # Run DBSCAN outlier detection
146
+ predictions = dbscan_outliers(embeddings, eps=eps, min_samples=min_samples)
147
+
148
+ # Convert boolean array to list of outlier indices
149
+ outlier_indices = np.where(predictions)[0].tolist()
150
+ outliers_str = ",".join(map(str, outlier_indices))
151
+
152
+ # Write to CSV
153
+ writer.writerow([video_name, outliers_str])
154
+
155
+ num_outliers = predictions.sum()
156
+ num_frames = len(embeddings)
157
+
158
+ if max_videos and total_videos >= max_videos:
159
+ break
160
+
161
+ print("\n" + "=" * 80)
162
+ print("PROCESSING COMPLETE")
163
+ print("=" * 80)
164
+ print(f"Total videos processed: {total_videos}")
165
+ print(f"CSV files saved to: {output_path.absolute()}")
166
+ print("\nNext step: Use generate_cleaned_videos_from_predictions.py to create cleaned videos")
167
+ print("=" * 80)
168
+
169
+
170
+ def main():
171
+ parser = argparse.ArgumentParser(
172
+ description="Outlier removal algorithm using DBSCAN: detect outliers and export predictions to CSV"
173
+ )
174
+
175
+ parser.add_argument("--embeddings-dir", required=True,
176
+ help="Directory containing embedding .pt files")
177
+ parser.add_argument("--output-dir", default="./outlier_artifacts/cleaned_CSVs",
178
+ help="Directory to save prediction CSV files")
179
+ parser.add_argument("--model-type", type=str, choices=['clip', 'dinov2', 'resnet18'], default='clip',
180
+ help="Model type to load: 'clip', 'dinov2', or 'resnet18' (default: clip)")
181
+ parser.add_argument("--max-videos", type=int,
182
+ help="Limit processing to first N videos")
183
+ parser.add_argument("--action-filter",
184
+ help="Filter to specific action category (e.g., 'Crawling')")
185
+
186
+ # DBSCAN parameters
187
+ parser.add_argument("--eps", type=float, default=0.5,
188
+ help="DBSCAN: Epsilon parameter")
189
+ parser.add_argument("--min-samples", type=int, default=40,
190
+ help="DBSCAN: Minimum samples parameter")
191
+
192
+ args = parser.parse_args()
193
+
194
+ process_all_embeddings(
195
+ emb_dir=args.embeddings_dir,
196
+ eps=args.eps,
197
+ min_samples=args.min_samples,
198
+ output_dir=args.output_dir,
199
+ model_type=args.model_type,
200
+ max_videos=args.max_videos,
201
+ action_filter=args.action_filter
202
+ )
203
+
204
+
205
+ if __name__ == "__main__":
206
+ main()
reorder_frames_algorithm.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Frame order reconstruction algorithm using MSE and greedy path construction.
4
+
5
+ Reconstructs temporal frame order from shuffled videos using grayscale MSE matrix,
6
+ MST diameter endpoints, and double-ended greedy path building with local refinement.
7
+
8
+ Usage:
9
+ # Process shuffled videos and CSVs from shuffled_artifacts
10
+ python reorder_frames_algorithm.py --csv_dir ./shuffled_artifacts/shuffled_CSVs --videos_dir ./shuffled_artifacts/shuffled_videos --out_dir ./shuffled_artifacts/ordered_CSVs
11
+
12
+ Note: To generate reordered videos from predictions, use generate_ordered_videos_from_predictions.py
13
+ """
14
+
15
+ import argparse
16
+ import os
17
+ import glob
18
+
19
+ import cv2
20
+ import numpy as np
21
+ import pandas as pd
22
+ import torch
23
+
24
+
25
+ # =========================
26
+ # Config
27
+ # =========================
28
+
29
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
+ IMG_SIZE = 64
31
+ VIDEO_EXTS = (".avi", ".mp4", ".mov", ".mkv")
32
+
33
+ # =========================
34
+ # Pairwise MSE on GPU
35
+ # =========================
36
+
37
+ def compute_mse_matrix(frames: torch.Tensor) -> torch.Tensor:
38
+ """
39
+ frames: [N, 1, H, W] on DEVICE
40
+ Returns:
41
+ mse[i,j]: mean squared error between frame i and j
42
+ """
43
+ N = frames.shape[0]
44
+ flat = frames.view(N, -1).float() # [N, D]
45
+
46
+ sq = (flat ** 2).sum(dim=1, keepdim=True) # [N,1]
47
+ dist2 = sq + sq.t() - 2.0 * (flat @ flat.t())
48
+ dist2 = torch.clamp(dist2, min=0.0)
49
+
50
+ D = flat.shape[1]
51
+ mse = dist2 / D
52
+ mse.fill_diagonal_(0.0)
53
+ return mse
54
+
55
+ # =========================
56
+ # Utils
57
+ # =========================
58
+
59
+ def _mst_endpoints_via_diameter(mse: torch.Tensor):
60
+ """
61
+ Build an MST on the dense MSE matrix (edge weights = mse).
62
+ Return (u, v) = endpoints of the MST diameter (longest weighted path).
63
+ """
64
+ N = mse.shape[0]
65
+ if N <= 1:
66
+ return (0, 0)
67
+
68
+ device = mse.device
69
+ used = torch.zeros(N, dtype=torch.bool, device=device)
70
+ dist = torch.full((N,), float('inf'), device=device)
71
+ parent = torch.full((N,), -1, dtype=torch.long, device=device)
72
+
73
+ # start Prim from node 0
74
+ used[0] = True
75
+ dist = mse[0].clone()
76
+ dist[0] = float('inf')
77
+
78
+ for _ in range(N - 1):
79
+ masked = dist.clone()
80
+ masked[used] = float('inf')
81
+ j = int(torch.argmin(masked).item())
82
+ used[j] = True
83
+
84
+ # relax edges to unused nodes
85
+ w = mse[j]
86
+ update_mask = (~used) & (w < dist)
87
+ dist[update_mask] = w[update_mask]
88
+ parent[update_mask] = j
89
+
90
+ # build adjacency list of the MST
91
+ adj = [[] for _ in range(N)]
92
+ for v in range(1, N):
93
+ u = int(parent[v].item())
94
+ if u >= 0:
95
+ w = float(mse[u, v].item())
96
+ adj[u].append((v, w))
97
+ adj[v].append((u, w))
98
+
99
+ def _farthest(src: int):
100
+ # single-source longest distances on a tree via DFS
101
+ distv = [-1.0] * N
102
+ distv[src] = 0.0
103
+ stack = [src]
104
+ while stack:
105
+ x = stack.pop()
106
+ for y, w in adj[x]:
107
+ if distv[y] < 0.0:
108
+ distv[y] = distv[x] + w
109
+ stack.append(y)
110
+ far = max(range(N), key=lambda k: distv[k])
111
+ return far, distv[far]
112
+
113
+ a, _ = _farthest(0)
114
+ b, _ = _farthest(a)
115
+ return a, b
116
+
117
+ def double_ended_greedy_from_pair(left: int, right: int, mse: torch.Tensor):
118
+ """
119
+ Maintain a path [left ... right]. At each step, attach the unused frame
120
+ with minimal MSE to either end (choose the cheaper side).
121
+ """
122
+ N = mse.shape[0]
123
+ used = torch.zeros(N, dtype=torch.bool, device=mse.device)
124
+ used[left] = True
125
+ used[right] = True
126
+
127
+ path = [left, right]
128
+ inf = float('inf')
129
+
130
+ for _ in range(N - 2):
131
+ # best to left
132
+ candL = mse[:, left].clone()
133
+ candL[used] = inf
134
+ kL = int(torch.argmin(candL).item())
135
+ dL = float(candL[kL])
136
+
137
+ # best to right
138
+ candR = mse[:, right].clone()
139
+ candR[used] = inf
140
+ kR = int(torch.argmin(candR).item())
141
+ dR = float(candR[kR])
142
+
143
+ if dL <= dR:
144
+ path.insert(0, kL)
145
+ used[kL] = True
146
+ left = kL
147
+ else:
148
+ path.append(kR)
149
+ used[kR] = True
150
+ right = kR
151
+
152
+ return path
153
+
154
+
155
+ def parse_shuffled_list(s: str):
156
+ """
157
+ Parse 'shuffled_frames_list' column.
158
+ Example cell:
159
+ "130,288,254,17,63,..."
160
+ """
161
+ return [int(x) for x in str(s).split(",") if x.strip() != ""]
162
+
163
+
164
+
165
+ def find_video_path(video_id: str, videos_dir: str) -> str:
166
+ """
167
+ Resolve the video path for a given video_id.
168
+
169
+ Tries:
170
+ - videos_dir / "<video_id>"
171
+ - videos_dir / "<video_id>.avi"
172
+ - videos_dir / "<video_id>.*" where extension in VIDEO_EXTS
173
+ """
174
+ # direct exact path (some CSVs store full filename)
175
+ direct = os.path.join(videos_dir, video_id)
176
+ if os.path.isfile(direct):
177
+ return direct
178
+
179
+ # try with .avi extension
180
+ direct_avi = direct + ".avi"
181
+ if os.path.isfile(direct_avi):
182
+ return direct_avi
183
+
184
+ # fallback: any file that starts with video_id
185
+ pattern = os.path.join(videos_dir, f"{video_id}*")
186
+ candidates = [
187
+ p for p in glob.glob(pattern)
188
+ if os.path.splitext(p)[1].lower() in VIDEO_EXTS
189
+ ]
190
+
191
+ if not candidates:
192
+ raise FileNotFoundError(
193
+ f"No video file found for video_id={video_id} in {videos_dir}"
194
+ )
195
+
196
+ # deterministic choice
197
+ candidates.sort(key=lambda x: (len(os.path.basename(x)), x))
198
+ return candidates[0]
199
+
200
+
201
+ # =========================
202
+ # Video loading (grayscale)
203
+ # =========================
204
+
205
+ def load_video_gray(video_path: str, expected_num_frames: int = None) -> torch.Tensor:
206
+ """
207
+ Load frames from a shuffled video as grayscale,
208
+ resize to IMG_SIZE, and send to DEVICE.
209
+
210
+ Returns:
211
+ frames: [N, 1, H, W] float32 in [0,1] on DEVICE
212
+ """
213
+ if not os.path.isfile(video_path):
214
+ raise FileNotFoundError(f"Video not found: {video_path}")
215
+
216
+ cap = cv2.VideoCapture(video_path)
217
+ if not cap.isOpened():
218
+ raise IOError(f"Cannot open video: {video_path}")
219
+
220
+ frames = []
221
+ while True:
222
+ ok, frame = cap.read()
223
+ if not ok:
224
+ break
225
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
226
+ gray = cv2.resize(gray, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
227
+ frames.append(gray)
228
+ cap.release()
229
+
230
+ if len(frames) == 0:
231
+ raise ValueError(f"No frames read from {video_path}")
232
+
233
+ if expected_num_frames is not None and len(frames) != expected_num_frames:
234
+ print(
235
+ f"[WARN] {os.path.basename(video_path)}: "
236
+ f"expected_num_frames={expected_num_frames}, read={len(frames)}"
237
+ )
238
+
239
+ arr = np.stack(frames, axis=0) # [N, H, W]
240
+ t = torch.from_numpy(arr).float() # [N, H, W]
241
+ t = t.unsqueeze(1) / 255.0 # [N, 1, H, W] in [0,1]
242
+ return t.to(DEVICE)
243
+
244
+
245
+ # =========================
246
+ # Path construction
247
+ # =========================
248
+
249
+ def build_best_path(mse: torch.Tensor):
250
+ """Build temporal path using MST diameter endpoints and double-ended greedy growth."""
251
+ N = mse.shape[0]
252
+ if N <= 2:
253
+ return list(range(N))
254
+
255
+ # smart seed via MST diameter
256
+ a, b = _mst_endpoints_via_diameter(mse)
257
+
258
+ # grow from both ends
259
+ path = double_ended_greedy_from_pair(a, b, mse)
260
+
261
+ return path
262
+
263
+
264
+ # =========================
265
+ # Per-video prediction
266
+ # =========================
267
+
268
+ def predict_order_for_video(video_id: str,
269
+ shuffled_order,
270
+ videos_dir: str):
271
+ """
272
+ Pipeline for a single video_id:
273
+ - load shuffled video frames
274
+ - compute MSE matrix
275
+ - build best greedy path
276
+ - refine path
277
+ - map positions to original frame indices
278
+ """
279
+ shuffled_order = list(shuffled_order)
280
+ expected_num_frames = len(shuffled_order)
281
+
282
+ video_path = find_video_path(video_id, videos_dir)
283
+ frames = load_video_gray(video_path, expected_num_frames=expected_num_frames)
284
+ frames = frames[:, 0:1, :, :] # use only Y channel for MSE
285
+ N = frames.shape[0]
286
+
287
+ if N != expected_num_frames:
288
+ print(
289
+ f"[WARN] {video_id}: csv_frames={expected_num_frames}, "
290
+ f"video_frames={N}. Using min of both."
291
+ )
292
+ m = min(expected_num_frames, N)
293
+ shuffled_order = shuffled_order[:m]
294
+ frames = frames[:m]
295
+ N = m
296
+
297
+ if N <= 1:
298
+ return [int(x) for x in shuffled_order]
299
+
300
+ mse = compute_mse_matrix(frames)
301
+ path = build_best_path(mse)
302
+
303
+ predicted = [int(shuffled_order[idx]) for idx in path]
304
+ return predicted
305
+
306
+ # =========================
307
+ # Process all CSVs
308
+ # =========================
309
+
310
+ def process_all_csvs(csv_dir: str, videos_dir: str, out_dir: str):
311
+ """
312
+ For each CSV in csv_dir:
313
+ - read video_id, shuffled_frames_list
314
+ - compute predicted order for each video
315
+ - write a prediction CSV with same filename into out_dir
316
+ """
317
+ os.makedirs(out_dir, exist_ok=True)
318
+
319
+ csv_paths = sorted(glob.glob(os.path.join(csv_dir, "*.csv")))
320
+ if not csv_paths:
321
+ raise FileNotFoundError(f"No CSV files found in {csv_dir}")
322
+
323
+ for csv_path in csv_paths:
324
+ df = pd.read_csv(csv_path)
325
+ rows = []
326
+
327
+ if "video_id" not in df.columns or "shuffled_frames_list" not in df.columns:
328
+ raise ValueError(
329
+ f"CSV {csv_path} must contain 'video_id' and 'shuffled_frames_list' columns."
330
+ )
331
+
332
+ for _, row in df.iterrows():
333
+ video_id = str(row["video_id"]).strip()
334
+ shuffled_order = parse_shuffled_list(row["shuffled_frames_list"])
335
+ pred = predict_order_for_video(video_id, shuffled_order, videos_dir)
336
+ pred_str = ",".join(str(x) for x in pred)
337
+ rows.append({"video_id": video_id, "predicted_frames_list": pred_str})
338
+
339
+ out_csv = os.path.join(out_dir, os.path.basename(csv_path))
340
+ pd.DataFrame(rows).to_csv(out_csv, index=False)
341
+ print(f"[OK] {os.path.basename(csv_path)} -> {os.path.basename(out_csv)}")
342
+
343
+
344
+ # =========================
345
+ # CLI
346
+ # =========================
347
+
348
+ def parse_args():
349
+ parser = argparse.ArgumentParser(
350
+ description="Reconstruct frame order from shuffled videos "
351
+ "using grayscale MSE and CSV metadata."
352
+ )
353
+ parser.add_argument(
354
+ "--csv_dir",
355
+ type=str,
356
+ required=True,
357
+ help="Directory with shuffled CSV files (e.g. shuffled_csvs).",
358
+ )
359
+ parser.add_argument(
360
+ "--videos_dir",
361
+ type=str,
362
+ required=True,
363
+ help="Directory with shuffled videos (e.g. UCF101_videos_shuffled).",
364
+ )
365
+ parser.add_argument(
366
+ "--out_dir",
367
+ type=str,
368
+ default="./shuffled_artifacts/ordered_CSVs",
369
+ help="Output directory for prediction CSVs.",
370
+ )
371
+ return parser.parse_args()
372
+
373
+
374
+ def main():
375
+ args = parse_args()
376
+ process_all_csvs(args.csv_dir, args.videos_dir, args.out_dir)
377
+
378
+
379
+ if __name__ == "__main__":
380
+ main()