|
|
|
|
|
""" |
|
|
Frame order reconstruction algorithm using MSE and greedy path construction. |
|
|
|
|
|
Reconstructs temporal frame order from shuffled videos using grayscale MSE matrix, |
|
|
MST diameter endpoints, and double-ended greedy path building with local refinement. |
|
|
|
|
|
Usage: |
|
|
# Process shuffled videos and CSVs from shuffled_artifacts |
|
|
python reorder_frames_algorithm.py --csv_dir ./shuffled_artifacts/shuffled_CSVs --videos_dir ./shuffled_artifacts/shuffled_videos --out_dir ./shuffled_artifacts/ordered_CSVs |
|
|
|
|
|
Note: To generate reordered videos from predictions, use generate_ordered_videos_from_predictions.py |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import glob |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run all tensor math on the GPU when one is available; frames are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Frames are downscaled to IMG_SIZE x IMG_SIZE grayscale before computing MSE.
IMG_SIZE = 64


# File extensions recognised when resolving a video_id to a video file.
VIDEO_EXTS = (".avi", ".mp4", ".mov", ".mkv")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_mse_matrix(frames: torch.Tensor) -> torch.Tensor:
    """
    Compute the pairwise mean-squared-error matrix for a stack of frames.

    Args:
        frames: [N, 1, H, W] float tensor (any device).

    Returns:
        [N, N] tensor where entry (i, j) is the MSE between frames i and j,
        with an exactly-zero diagonal.
    """
    n_frames = frames.shape[0]
    vecs = frames.view(n_frames, -1).float()
    n_pixels = vecs.shape[1]

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, via a single Gram matrix.
    norms = (vecs ** 2).sum(dim=1, keepdim=True)
    gram = vecs @ vecs.t()
    sq_dists = norms + norms.t() - 2.0 * gram
    # Guard against tiny negative values from floating-point cancellation.
    sq_dists = torch.clamp(sq_dists, min=0.0)

    result = sq_dists / n_pixels
    result.fill_diagonal_(0.0)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mst_endpoints_via_diameter(mse: torch.Tensor): |
|
|
""" |
|
|
Build an MST on the dense MSE matrix (edge weights = mse). |
|
|
Return (u, v) = endpoints of the MST diameter (longest weighted path). |
|
|
""" |
|
|
N = mse.shape[0] |
|
|
if N <= 1: |
|
|
return (0, 0) |
|
|
|
|
|
device = mse.device |
|
|
used = torch.zeros(N, dtype=torch.bool, device=device) |
|
|
dist = torch.full((N,), float('inf'), device=device) |
|
|
parent = torch.full((N,), -1, dtype=torch.long, device=device) |
|
|
|
|
|
|
|
|
used[0] = True |
|
|
dist = mse[0].clone() |
|
|
dist[0] = float('inf') |
|
|
|
|
|
for _ in range(N - 1): |
|
|
masked = dist.clone() |
|
|
masked[used] = float('inf') |
|
|
j = int(torch.argmin(masked).item()) |
|
|
used[j] = True |
|
|
|
|
|
|
|
|
w = mse[j] |
|
|
update_mask = (~used) & (w < dist) |
|
|
dist[update_mask] = w[update_mask] |
|
|
parent[update_mask] = j |
|
|
|
|
|
|
|
|
adj = [[] for _ in range(N)] |
|
|
for v in range(1, N): |
|
|
u = int(parent[v].item()) |
|
|
if u >= 0: |
|
|
w = float(mse[u, v].item()) |
|
|
adj[u].append((v, w)) |
|
|
adj[v].append((u, w)) |
|
|
|
|
|
def _farthest(src: int): |
|
|
|
|
|
distv = [-1.0] * N |
|
|
distv[src] = 0.0 |
|
|
stack = [src] |
|
|
while stack: |
|
|
x = stack.pop() |
|
|
for y, w in adj[x]: |
|
|
if distv[y] < 0.0: |
|
|
distv[y] = distv[x] + w |
|
|
stack.append(y) |
|
|
far = max(range(N), key=lambda k: distv[k]) |
|
|
return far, distv[far] |
|
|
|
|
|
a, _ = _farthest(0) |
|
|
b, _ = _farthest(a) |
|
|
return a, b |
|
|
|
|
|
def double_ended_greedy_from_pair(left: int, right: int, mse: torch.Tensor):
    """
    Grow a frame path seeded with [left, right] by repeatedly attaching
    the unused frame with the lowest MSE to either end of the path; when
    both ends offer the same cost, the left end wins.
    """
    n = mse.shape[0]
    taken = torch.zeros(n, dtype=torch.bool, device=mse.device)
    taken[left] = True
    taken[right] = True

    path = [left, right]
    inf = float('inf')

    def _cheapest_neighbor(end: int):
        # Best not-yet-used frame to attach at `end`, plus its cost.
        costs = mse[:, end].clone()
        costs[taken] = inf
        idx = int(torch.argmin(costs).item())
        return idx, float(costs[idx])

    for _ in range(n - 2):
        cand_left, cost_left = _cheapest_neighbor(left)
        cand_right, cost_right = _cheapest_neighbor(right)

        if cost_left <= cost_right:
            path.insert(0, cand_left)
            taken[cand_left] = True
            left = cand_left
        else:
            path.append(cand_right)
            taken[cand_right] = True
            right = cand_right

    return path
|
|
|
|
|
|
|
|
def parse_shuffled_list(s: str):
    """
    Parse a 'shuffled_frames_list' CSV cell into a list of ints.

    Example cell:
        "130,288,254,17,63,..."

    Blank tokens (e.g. from a trailing comma) are skipped; the input is
    coerced to str first so numeric cells also work.
    """
    tokens = str(s).split(",")
    return [int(token) for token in tokens if token.strip() != ""]
|
|
|
|
|
|
|
|
|
|
|
def find_video_path(video_id: str, videos_dir: str) -> str:
    """
    Resolve the video path for a given video_id.

    Tries, in order:
      - videos_dir / "<video_id>"       (id already includes the extension)
      - videos_dir / "<video_id>.avi"   (the common case)
      - videos_dir / "<video_id>.*"     for any extension in VIDEO_EXTS

    Returns:
        Path to the matching video file.

    Raises:
        FileNotFoundError: if no candidate file exists.
    """
    # Exact filename: the id already carries its extension.
    direct = os.path.join(videos_dir, video_id)
    if os.path.isfile(direct):
        return direct

    # Most common case: plain ".avi" next to the id.
    direct_avi = direct + ".avi"
    if os.path.isfile(direct_avi):
        return direct_avi

    # Fall back to a glob over known extensions. glob.escape guards
    # against ids containing glob metacharacters ('*', '?', '[').
    pattern = os.path.join(videos_dir, glob.escape(video_id) + "*")
    candidates = [
        p for p in glob.glob(pattern)
        if os.path.splitext(p)[1].lower() in VIDEO_EXTS
    ]

    if not candidates:
        raise FileNotFoundError(
            f"No video file found for video_id={video_id} in {videos_dir}"
        )

    # Prefer the shortest basename so "abc.avi" beats "abc_extra.avi" when
    # several files share the id as a prefix; ties break lexicographically.
    candidates.sort(key=lambda x: (len(os.path.basename(x)), x))
    return candidates[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_video_gray(video_path: str, expected_num_frames: int = None) -> torch.Tensor:
    """
    Read every frame of a video as grayscale, resized to IMG_SIZE.

    Args:
        video_path: path to the (shuffled) video file.
        expected_num_frames: optional frame count from the CSV; a warning
            is printed when the decoded count differs.

    Returns:
        frames: [N, 1, H, W] float32 tensor in [0, 1] on DEVICE.

    Raises:
        FileNotFoundError: if video_path does not exist.
        IOError: if OpenCV cannot open the file.
        ValueError: if no frames could be decoded.
    """
    if not os.path.isfile(video_path):
        raise FileNotFoundError(f"Video not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")

    gray_frames = []
    ok, frame = cap.read()
    while ok:
        # Grayscale then downscale; INTER_AREA is suited to shrinking.
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray_frames.append(
            cv2.resize(gray, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        )
        ok, frame = cap.read()
    cap.release()

    if not gray_frames:
        raise ValueError(f"No frames read from {video_path}")

    if expected_num_frames is not None and len(gray_frames) != expected_num_frames:
        print(
            f"[WARN] {os.path.basename(video_path)}: "
            f"expected_num_frames={expected_num_frames}, read={len(gray_frames)}"
        )

    stacked = np.stack(gray_frames, axis=0)
    tensor = torch.from_numpy(stacked).float().unsqueeze(1) / 255.0
    return tensor.to(DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_best_path(mse: torch.Tensor):
    """Order frames by greedy path growth seeded at the MST diameter ends."""
    num = mse.shape[0]
    # With at most two frames, every ordering is equivalent.
    if num <= 2:
        return list(range(num))

    # Seed with the two frames farthest apart in the MST, then fill in the
    # remaining frames greedily from both ends of the path.
    end_a, end_b = _mst_endpoints_via_diameter(mse)
    return double_ended_greedy_from_pair(end_a, end_b, mse)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_order_for_video(video_id: str,
                            shuffled_order,
                            videos_dir: str):
    """
    Reconstruct the original frame order for one video.

    Steps:
      1. locate and load the shuffled video's frames (grayscale, downscaled)
      2. compute the pairwise MSE matrix
      3. build a greedy temporal path over frame positions
      4. map path positions back to the original frame indices

    Args:
        video_id: identifier used to locate the video file.
        shuffled_order: original frame indices, in shuffled on-disk order.
        videos_dir: directory containing the shuffled videos.

    Returns:
        List of original frame indices in predicted temporal order.
    """
    shuffled_order = list(shuffled_order)
    expected_num_frames = len(shuffled_order)

    video_path = find_video_path(video_id, videos_dir)
    frames = load_video_gray(video_path, expected_num_frames=expected_num_frames)
    frames = frames[:, 0:1, :, :]
    n_frames = frames.shape[0]

    # CSV and decoded video can disagree on length; truncate both to agree.
    if n_frames != expected_num_frames:
        print(
            f"[WARN] {video_id}: csv_frames={expected_num_frames}, "
            f"video_frames={n_frames}. Using min of both."
        )
        common = min(expected_num_frames, n_frames)
        shuffled_order = shuffled_order[:common]
        frames = frames[:common]
        n_frames = common

    # Nothing to reorder with zero or one frame.
    if n_frames <= 1:
        return [int(x) for x in shuffled_order]

    pairwise_mse = compute_mse_matrix(frames)
    ordering = build_best_path(pairwise_mse)

    # `ordering` is over positions in the shuffled video; translate each
    # position back to the original frame index it holds.
    return [int(shuffled_order[pos]) for pos in ordering]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_all_csvs(csv_dir: str, videos_dir: str, out_dir: str):
    """
    Run the reordering pipeline over every CSV in csv_dir.

    Each input CSV must contain 'video_id' and 'shuffled_frames_list'
    columns. For every row a predicted order is computed, and a CSV with
    the same filename (columns: video_id, predicted_frames_list) is
    written into out_dir.

    Raises:
        FileNotFoundError: if csv_dir holds no CSV files.
        ValueError: if a CSV is missing a required column.
    """
    os.makedirs(out_dir, exist_ok=True)

    csv_paths = sorted(glob.glob(os.path.join(csv_dir, "*.csv")))
    if not csv_paths:
        raise FileNotFoundError(f"No CSV files found in {csv_dir}")

    for csv_path in csv_paths:
        df = pd.read_csv(csv_path)

        if "video_id" not in df.columns or "shuffled_frames_list" not in df.columns:
            raise ValueError(
                f"CSV {csv_path} must contain 'video_id' and 'shuffled_frames_list' columns."
            )

        out_rows = []
        for _, record in df.iterrows():
            vid = str(record["video_id"]).strip()
            order = parse_shuffled_list(record["shuffled_frames_list"])
            prediction = predict_order_for_video(vid, order, videos_dir)
            out_rows.append({
                "video_id": vid,
                "predicted_frames_list": ",".join(str(i) for i in prediction),
            })

        out_csv = os.path.join(out_dir, os.path.basename(csv_path))
        pd.DataFrame(out_rows).to_csv(out_csv, index=False)
        print(f"[OK] {os.path.basename(csv_path)} -> {os.path.basename(out_csv)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line arguments for the reordering pipeline."""
    parser = argparse.ArgumentParser(
        description="Reconstruct frame order from shuffled videos "
                    "using grayscale MSE and CSV metadata."
    )
    parser.add_argument(
        "--csv_dir", type=str, required=True,
        help="Directory with shuffled CSV files (e.g. shuffled_csvs).",
    )
    parser.add_argument(
        "--videos_dir", type=str, required=True,
        help="Directory with shuffled videos (e.g. UCF101_videos_shuffled).",
    )
    parser.add_argument(
        "--out_dir", type=str,
        default="./shuffled_artifacts/ordered_CSVs",
        help="Output directory for prediction CSVs.",
    )
    return parser.parse_args()
|
|
|
|
|
|
|
|
def main():
    """Entry point: parse CLI arguments and process every CSV in csv_dir."""
    args = parse_args()
    process_all_csvs(args.csv_dir, args.videos_dir, args.out_dir)
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|