#!/usr/bin/env python3
"""
Frame order reconstruction algorithm using MSE and greedy path construction.

Reconstructs temporal frame order from shuffled videos using a grayscale MSE
matrix, MST diameter endpoints, and double-ended greedy path building.

Usage:
    # Process shuffled videos and CSVs from shuffled_artifacts
    python reorder_frames_algorithm.py \
        --csv_dir ./shuffled_artifacts/shuffled_CSVs \
        --videos_dir ./shuffled_artifacts/shuffled_videos \
        --out_dir ./shuffled_artifacts/ordered_CSVs

Note:
    To generate reordered videos from predictions, use
    generate_ordered_videos_from_predictions.py
"""

import argparse
import glob
import os

import cv2
import numpy as np
import pandas as pd
import torch

# =========================
# Config
# =========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 64
VIDEO_EXTS = (".avi", ".mp4", ".mov", ".mkv")


# =========================
# Pairwise MSE on GPU
# =========================
def compute_mse_matrix(frames: torch.Tensor) -> torch.Tensor:
    """
    Compute all pairwise mean squared errors between frames.

    Args:
        frames: [N, 1, H, W] tensor on DEVICE.

    Returns:
        mse: [N, N] tensor where mse[i, j] is the mean squared error between
        frame i and frame j; the diagonal is forced to exactly 0.
    """
    N = frames.shape[0]
    flat = frames.view(N, -1).float()          # [N, D]
    sq = (flat ** 2).sum(dim=1, keepdim=True)  # [N, 1]
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, vectorized over all pairs
    dist2 = sq + sq.t() - 2.0 * (flat @ flat.t())
    # clamp tiny negative values caused by floating-point cancellation
    dist2 = torch.clamp(dist2, min=0.0)
    D = flat.shape[1]
    mse = dist2 / D
    mse.fill_diagonal_(0.0)
    return mse


# =========================
# Utils
# =========================
def _mst_endpoints_via_diameter(mse: torch.Tensor):
    """
    Build an MST on the dense MSE matrix (edge weights = mse).

    Return (u, v) = endpoints of the MST diameter (longest weighted path),
    found with the standard double farthest-node search on the tree.
    """
    N = mse.shape[0]
    if N <= 1:
        return (0, 0)

    device = mse.device
    used = torch.zeros(N, dtype=torch.bool, device=device)
    parent = torch.full((N,), -1, dtype=torch.long, device=device)

    # start Prim from node 0
    used[0] = True
    dist = mse[0].clone()
    dist[0] = float('inf')
    # BUG FIX: the candidate edges seeded from node 0 must record node 0 as
    # their parent. Previously parent stayed -1 for any node whose best edge
    # remained (0, v); those edges were dropped by the `u >= 0` guard below,
    # leaving node 0 isolated and the diameter search degenerating to (0, 0).
    parent[1:] = 0

    for _ in range(N - 1):
        masked = dist.clone()
        masked[used] = float('inf')
        j = int(torch.argmin(masked).item())
        used[j] = True
        # relax edges from the newly added node to all unused nodes
        w = mse[j]
        update_mask = (~used) & (w < dist)
        dist[update_mask] = w[update_mask]
        parent[update_mask] = j

    # build adjacency list of the MST
    adj = [[] for _ in range(N)]
    for v in range(1, N):
        u = int(parent[v].item())
        if u >= 0:
            w = float(mse[u, v].item())
            adj[u].append((v, w))
            adj[v].append((u, w))

    def _farthest(src: int):
        # single-source longest distances on a tree via iterative DFS
        # (on a tree, the unique path distance is also the longest one)
        distv = [-1.0] * N
        distv[src] = 0.0
        stack = [src]
        while stack:
            x = stack.pop()
            for y, w in adj[x]:
                if distv[y] < 0.0:
                    distv[y] = distv[x] + w
                    stack.append(y)
        far = max(range(N), key=lambda k: distv[k])
        return far, distv[far]

    # double-sweep: farthest node from an arbitrary node is a diameter
    # endpoint; farthest from that endpoint is the other endpoint
    a, _ = _farthest(0)
    b, _ = _farthest(a)
    return a, b


def double_ended_greedy_from_pair(left: int, right: int, mse: torch.Tensor):
    """
    Maintain a path [left ... right]. At each step, attach the unused frame
    with minimal MSE to either end (choose the cheaper side).
    """
    N = mse.shape[0]
    used = torch.zeros(N, dtype=torch.bool, device=mse.device)
    used[left] = True
    used[right] = True
    path = [left, right]
    inf = float('inf')

    for _ in range(N - 2):
        # best candidate to attach at the left end
        candL = mse[:, left].clone()
        candL[used] = inf
        kL = int(torch.argmin(candL).item())
        dL = float(candL[kL])

        # best candidate to attach at the right end
        candR = mse[:, right].clone()
        candR[used] = inf
        kR = int(torch.argmin(candR).item())
        dR = float(candR[kR])

        if dL <= dR:
            path.insert(0, kL)
            used[kL] = True
            left = kL
        else:
            path.append(kR)
            used[kR] = True
            right = kR

    return path


def parse_shuffled_list(s: str):
    """
    Parse 'shuffled_frames_list' column.

    Example cell: "130,288,254,17,63,..."
    """
    return [int(x) for x in str(s).split(",") if x.strip() != ""]


def find_video_path(video_id: str, videos_dir: str) -> str:
    """
    Resolve the video path for a given video_id.

    Tries, in order:
      - videos_dir / "{video_id}"          (CSV stored a full filename)
      - videos_dir / "{video_id}.avi"
      - videos_dir / "{video_id}*"         where extension in VIDEO_EXTS

    Raises:
        FileNotFoundError: if no matching video file is found.
    """
    # direct exact path (some CSVs store full filename)
    direct = os.path.join(videos_dir, video_id)
    if os.path.isfile(direct):
        return direct

    # try with .avi extension
    direct_avi = direct + ".avi"
    if os.path.isfile(direct_avi):
        return direct_avi

    # fallback: any file that starts with video_id
    pattern = os.path.join(videos_dir, f"{video_id}*")
    candidates = [
        p for p in glob.glob(pattern)
        if os.path.splitext(p)[1].lower() in VIDEO_EXTS
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No video file found for video_id={video_id} in {videos_dir}"
        )
    # deterministic choice: shortest basename first, then lexicographic
    candidates.sort(key=lambda x: (len(os.path.basename(x)), x))
    return candidates[0]


# =========================
# Video loading (grayscale)
# =========================
def load_video_gray(video_path: str, expected_num_frames: int = None) -> torch.Tensor:
    """
    Load frames from a shuffled video as grayscale, resize to IMG_SIZE,
    and send to DEVICE.

    Args:
        video_path: path to the video file.
        expected_num_frames: optional frame count from the CSV; a mismatch
            only triggers a warning (the caller reconciles the counts).

    Returns:
        frames: [N, 1, H, W] float32 in [0, 1] on DEVICE.

    Raises:
        FileNotFoundError: if the file does not exist.
        IOError: if OpenCV cannot open the video.
        ValueError: if zero frames are decoded.
    """
    if not os.path.isfile(video_path):
        raise FileNotFoundError(f"Video not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")

    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray = cv2.resize(gray, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        frames.append(gray)
    cap.release()

    if len(frames) == 0:
        raise ValueError(f"No frames read from {video_path}")

    if expected_num_frames is not None and len(frames) != expected_num_frames:
        print(
            f"[WARN] {os.path.basename(video_path)}: "
            f"expected_num_frames={expected_num_frames}, read={len(frames)}"
        )

    arr = np.stack(frames, axis=0)      # [N, H, W]
    t = torch.from_numpy(arr).float()   # [N, H, W]
    t = t.unsqueeze(1) / 255.0          # [N, 1, H, W] in [0, 1]
    return t.to(DEVICE)


# =========================
# Path construction
# =========================
def build_best_path(mse: torch.Tensor):
    """Build temporal path using MST diameter endpoints and double-ended greedy growth."""
    N = mse.shape[0]
    if N <= 2:
        return list(range(N))
    # smart seed via MST diameter
    a, b = _mst_endpoints_via_diameter(mse)
    # grow from both ends
    path = double_ended_greedy_from_pair(a, b, mse)
    return path


# =========================
# Per-video prediction
# =========================
def predict_order_for_video(video_id: str, shuffled_order, videos_dir: str):
    """
    Pipeline for a single video_id:
      - load shuffled video frames
      - compute MSE matrix
      - build best greedy path
      - map path positions back to original frame indices

    Returns:
        List of original frame indices in predicted temporal order.
    """
    shuffled_order = list(shuffled_order)
    expected_num_frames = len(shuffled_order)

    video_path = find_video_path(video_id, videos_dir)
    frames = load_video_gray(video_path, expected_num_frames=expected_num_frames)
    frames = frames[:, 0:1, :, :]  # use only Y channel for MSE

    N = frames.shape[0]
    if N != expected_num_frames:
        print(
            f"[WARN] {video_id}: csv_frames={expected_num_frames}, "
            f"video_frames={N}. Using min of both."
        )
        # reconcile CSV and decoded frame counts by truncating both
        m = min(expected_num_frames, N)
        shuffled_order = shuffled_order[:m]
        frames = frames[:m]
        N = m

    if N <= 1:
        return [int(x) for x in shuffled_order]

    mse = compute_mse_matrix(frames)
    path = build_best_path(mse)
    predicted = [int(shuffled_order[idx]) for idx in path]
    return predicted


# =========================
# Process all CSVs
# =========================
def process_all_csvs(csv_dir: str, videos_dir: str, out_dir: str):
    """
    For each CSV in csv_dir:
      - read video_id, shuffled_frames_list
      - compute predicted order for each video
      - write a prediction CSV with same filename into out_dir

    Raises:
        FileNotFoundError: if csv_dir contains no CSV files.
        ValueError: if a CSV is missing the required columns.
    """
    os.makedirs(out_dir, exist_ok=True)
    csv_paths = sorted(glob.glob(os.path.join(csv_dir, "*.csv")))
    if not csv_paths:
        raise FileNotFoundError(f"No CSV files found in {csv_dir}")

    for csv_path in csv_paths:
        df = pd.read_csv(csv_path)
        rows = []
        if "video_id" not in df.columns or "shuffled_frames_list" not in df.columns:
            raise ValueError(
                f"CSV {csv_path} must contain 'video_id' and 'shuffled_frames_list' columns."
            )

        for _, row in df.iterrows():
            video_id = str(row["video_id"]).strip()
            shuffled_order = parse_shuffled_list(row["shuffled_frames_list"])
            pred = predict_order_for_video(video_id, shuffled_order, videos_dir)
            pred_str = ",".join(str(x) for x in pred)
            rows.append({"video_id": video_id, "predicted_frames_list": pred_str})

        out_csv = os.path.join(out_dir, os.path.basename(csv_path))
        pd.DataFrame(rows).to_csv(out_csv, index=False)
        print(f"[OK] {os.path.basename(csv_path)} -> {os.path.basename(out_csv)}")


# =========================
# CLI
# =========================
def parse_args():
    """Parse command-line arguments for the reordering pipeline."""
    parser = argparse.ArgumentParser(
        description="Reconstruct frame order from shuffled videos "
                    "using grayscale MSE and CSV metadata."
    )
    parser.add_argument(
        "--csv_dir",
        type=str,
        required=True,
        help="Directory with shuffled CSV files (e.g. shuffled_csvs).",
    )
    parser.add_argument(
        "--videos_dir",
        type=str,
        required=True,
        help="Directory with shuffled videos (e.g. UCF101_videos_shuffled).",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./shuffled_artifacts/ordered_CSVs",
        help="Output directory for prediction CSVs.",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    process_all_csvs(args.csv_dir, args.videos_dir, args.out_dir)


if __name__ == "__main__":
    main()