File size: 4,752 Bytes
db25ead
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import numpy as np
import csv
from pathlib import Path
import os
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
from decord import VideoReader, cpu

from module.frequency_dct import compute_twostream_dct
from module.read_frame_decord import (
    sample_frames_uniform,
    collect_needed,
    cache_needed_frames,
    read_window_from_cache,
)

def process_video(
    vr,
    # resize
    size=224,
    # anchors + window
    num_anchors=16,
    win=6,
    win_step=1,
    # DCT / weights
    block=16,
):
    """Compute two-stream DCT weight maps for uniformly sampled anchors of a video.

    Args:
        vr: frame source supporting ``len()``; it is handed to
            ``cache_needed_frames`` for decoding (e.g. a decord ``VideoReader``).
        size: resize target forwarded to ``cache_needed_frames``.
        num_anchors: number of anchors requested from ``sample_frames_uniform``.
        win: temporal window length around each anchor.
        win_step: stride between frames within a window.
        block: DCT block size forwarded to ``compute_twostream_dct``.

    Returns:
        Tuple ``(frame_all, w_art_all, w_str_all, anchors_kept)``:
        frame_all: anchor frames (RGB uint8, HxWx3).
        w_art_all / w_str_all: per-anchor maps cast to float32 (HxW).
        anchors_kept: frame indices used per anchor, as ``list[list[int]]``.
        Anchors for which ``read_window_from_cache`` returns None are omitted
        from all four lists, so the lists stay index-aligned.

    Raises:
        RuntimeError: if the video has one frame or fewer.
    """
    n_frames = len(vr)
    if n_frames <= 1:
        raise RuntimeError(f"Video too short / invalid frame count: {n_frames}")

    anchors = sample_frames_uniform(n_frames, num_anchors, win=win, win_step=win_step)
    # Decode only the frames some window will touch, once, up front.
    frame_cache = cache_needed_frames(
        vr, collect_needed(anchors, n_frames, win, win_step), size
    )

    frames, art_maps, str_maps, kept = [], [], [], []
    for anchor in anchors:
        window = read_window_from_cache(frame_cache, anchor, n_frames, win, win_step)
        if window is None:
            # The helper signals an unusable window with None; drop this anchor.
            continue

        anchor_frame, gray_seq, idxs = window
        w_art, w_str, _ = compute_twostream_dct(gray_seq, block=block)

        frames.append(anchor_frame)
        art_maps.append(w_art.astype(np.float32, copy=False))
        str_maps.append(w_str.astype(np.float32, copy=False))
        kept.append(list(map(int, idxs)))

    return frames, art_maps, str_maps, kept


def compute_weight_map(frame_all, w_art_all, w_str_all):
    """Stack per-anchor frames and weight maps into batched arrays.

    Args:
        frame_all: per-anchor frames (presumably HxWx3 uint8 from ``process_video``).
        w_art_all: per-anchor artifact-stream maps, same length as ``frame_all``.
        w_str_all: per-anchor structure-stream maps, same length as ``frame_all``.

    Returns:
        Tuple ``(frames_np, w_art_np, w_str_np)``, each the result of
        ``np.stack(..., axis=0)`` on the corresponding input list.

    Raises:
        ValueError: if ``frame_all`` is empty or the three lengths differ.
    """
    if len(frame_all) == 0:
        raise ValueError("No frames produced.")

    n_frames, n_art, n_str = len(frame_all), len(w_art_all), len(w_str_all)
    if n_frames != n_art or n_art != n_str:
        raise ValueError(
            f"Length mismatch: frames={n_frames}, w_art_all={n_art}, w_str_all={n_str}"
        )

    # Leading axis indexes the anchor, e.g. (N,H,W,3) frames and (N,H,W) maps.
    return tuple(np.stack(seq, axis=0) for seq in (frame_all, w_art_all, w_str_all))

def read_vid_mos_csv(csv_path):
    """Read ``(vid, mos)`` pairs from a header-bearing CSV file.

    Args:
        csv_path: path to a CSV whose header row includes ``vid`` and ``mos``
            columns (other columns are ignored).

    Returns:
        List of ``(vid, mos)`` tuples in file order. ``vid`` is a
        whitespace-stripped ``str``; ``mos`` is a ``float``.

    Raises:
        RuntimeError: if the file has no header row, or the header lacks a
            ``vid`` or ``mos`` column.
        ValueError: if a ``mos`` cell is not parseable as a float.
    """
    required = ("vid", "mos")
    rows = []
    # "utf-8-sig" strips a leading BOM (common in Excel exports) that would
    # otherwise turn the first header into "\ufeffvid"; it reads BOM-less UTF-8
    # identically. newline="" is what the csv module requires for correct
    # handling of line endings and quoted newlines.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames:
            raise RuntimeError("CSV has no header")
        # Tolerate headers written as "vid, mos".
        reader.fieldnames = [name.strip() for name in reader.fieldnames]
        missing = [col for col in required if col not in reader.fieldnames]
        if missing:
            raise RuntimeError(
                f"CSV is missing required column(s) {missing}; found {reader.fieldnames}"
            )
        for r in reader:
            vid = str(r["vid"]).strip()
            mos = float(r["mos"])
            rows.append((vid, mos))
    return rows

def rng_rows(rows, seed=0):
    """Return all of ``rows`` in a reproducible pseudo-random order.

    Args:
        rows: an indexable sequence (e.g. the list from ``read_vid_mos_csv``).
        seed: value converted with ``int()`` and used to seed
            ``np.random.default_rng``; the same seed yields the same order.

    Returns:
        A new list containing every element of ``rows`` exactly once. No
        train/test split is performed despite the caller's ``train`` naming.
    """
    order = np.arange(len(rows))
    np.random.default_rng(int(seed)).shuffle(order)
    return [rows[k] for k in order]

def find_video_path(db_path, vid, exts=("mp4", "avi", "mkv")):
    """Return the first existing ``<db_path>/<vid>.<ext>`` as a str, else None.

    Extensions are tried in the order given. The extension is appended to the
    full path string rather than applied with ``Path.with_suffix`` because
    video ids may themselves contain dots (e.g. ``...In4108.37ByTool23``),
    which ``with_suffix`` would treat as an existing suffix and replace.
    """
    base = Path(db_path) / vid
    for ext in exts:
        candidate = Path(f"{base}.{ext}")
        if candidate.exists():
            return str(candidate)
    return None


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Smoke-test DCT weight-map extraction over videos listed in a vid/mos CSV."
    )
    parser.add_argument(
        "--csv",
        default="/home/xinyi/Project/FD-VQA/metadata/TEST_metadata.csv",
        help="CSV with 'vid' and 'mos' columns",
    )
    parser.add_argument(
        "--db",
        default="/home/xinyi/Project/FD-VQA/test_videos/",
        help="directory containing <vid>.<ext> video files",
    )
    parser.add_argument("--seed", type=int, default=0, help="shuffle seed passed to rng_rows")
    args = parser.parse_args()

    rows = read_vid_mos_csv(args.csv)
    train = rng_rows(rows, seed=args.seed)
    for i, (vid, mos) in enumerate(train):
        print(vid, mos)
        video_path = find_video_path(args.db, vid)
        if video_path is None:
            raise FileNotFoundError(f"Cannot find {vid} video")

        # Bind before `try` so the `finally` clause never sees an unbound or
        # stale reader left over from a previous iteration.
        vr = None
        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            frame_all, w_art_all, w_str_all, anchors_kept = process_video(
                vr,
                size=224,
                num_anchors=16,
                win=6,
                win_step=1,
                block=16,
            )
            frames_np, w_art_np, w_str_np = compute_weight_map(frame_all, w_art_all, w_str_all)
            print("frames_np:", frames_np.shape, frames_np.dtype)
            print("w_art_np:", w_art_np.shape, w_art_np.dtype)
            print("w_str_np:", w_str_np.shape, w_str_np.dtype)
            # compute_weight_map raised if nothing was produced, so [0] exists.
            print("anchors_kept:", len(anchors_kept), "example:", anchors_kept[0])
        except Exception:
            # Report which item failed, then let the original traceback propagate.
            print("\n[DATA ERROR]")
            print("idx:", i)
            print("vid:", vid)
            print("path:", video_path)
            raise
        finally:
            # Drop our reference to the decord reader before the next video.
            vr = None