| import os |
| os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0" |
| from decord import VideoReader, cpu |
| import cv2 |
| import argparse |
| import numpy as np |
| from tqdm import tqdm |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| from module.read_frame_decord import sample_frames_uniform, collect_needed, cache_needed_frames, read_window_from_cache |
|
|
| |
| |
| |
def norm01(x, eps=1e-6):
    """Min-max normalise ``x`` to roughly the range [0, 1].

    Args:
        x: Array or array-like of numbers, any shape.
        eps: Small constant added to the value range so a constant input
            divides by ``eps`` rather than zero (and therefore maps to zeros).

    Returns:
        A new ``float32`` array with the same shape as ``x``. An empty input
        yields an empty ``float32`` array instead of raising.
    """
    x = np.asarray(x).astype(np.float32)
    if x.size == 0:
        # min()/max() raise ValueError on empty arrays; nothing to rescale.
        return x
    mn, mx = float(x.min()), float(x.max())
    return (x - mn) / (mx - mn + eps)
|
|
| |
def upsample_grid(grid, out_hw, interp=cv2.INTER_NEAREST):
    """Resize ``grid`` to the target size and return it as float32.

    Args:
        grid: Array to resize; it is cast to ``float32`` first.
        out_hw: Target size as ``(height, width)``.
        interp: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    out_h, out_w = out_hw
    grid_f32 = grid.astype(np.float32)
    # cv2.resize takes the destination size as (width, height).
    return cv2.resize(grid_f32, (out_w, out_h), interpolation=interp)
|
|
def gaussian_blur(x, sigma=1.0):
    """Gaussian-smooth ``x`` with a square kernel sized from ``sigma``.

    Args:
        x: Image passed to ``cv2.GaussianBlur``.
        sigma: Standard deviation of the kernel. A value ``<= 0`` disables
            the blur and ``x`` itself is returned unchanged.

    Returns:
        The blurred image, or ``x`` when ``sigma <= 0``.
    """
    if sigma <= 0:
        return x
    # Kernel spans about +/-3 sigma; OR-ing with 1 forces an odd size and
    # the max() enforces the 3-pixel minimum.
    side = max(int(6 * sigma + 1) | 1, 3)
    return cv2.GaussianBlur(
        x,
        (side, side),
        sigma,
        borderType=cv2.BORDER_REFLECT_101,
    )
|
|
| |
def gradient_mag(gray):
    """Return the per-pixel Sobel gradient magnitude of ``gray``.

    Horizontal and vertical first derivatives are computed with 3x3 Sobel
    kernels in 32-bit float and combined as ``sqrt(dx^2 + dy^2)``.
    """
    dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    return np.sqrt(np.square(dx) + np.square(dy))
|
|
| |
def edge_sobel(gray, thr=0.20, dilate_px=2):
    """Build a binary edge mask from the normalised Sobel magnitude.

    Args:
        gray: Grayscale image handed to ``gradient_mag``.
        thr: Threshold applied to the min-max normalised gradient magnitude.
        dilate_px: Dilation radius in pixels; ``<= 0`` skips the dilation.

    Returns:
        A ``uint8`` mask holding 1 on edge pixels and 0 elsewhere.
    """
    strength = norm01(gradient_mag(gray).astype(np.float32))
    mask = (strength >= thr).astype(np.uint8)
    if dilate_px <= 0:
        return mask
    # Elliptical structuring element with diameter 2 * radius + 1.
    side = 2 * dilate_px + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    return cv2.dilate(mask, kernel, iterations=1)
|
|
| |
def block_edge(edge_mask, block=16):
    """Return the fraction of edge pixels inside each ``block`` x ``block`` tile.

    Args:
        edge_mask: 2-D mask (e.g. 0/1 values) of shape ``(h, w)``.
        block: Tile side length in pixels. Rows/columns that do not fill a
            whole tile at the bottom/right are ignored.

    Returns:
        A ``float32`` array of shape ``(h // block, w // block)`` holding the
        mean of the mask over each tile.
    """
    h, w = edge_mask.shape
    gh, gw = h // block, w // block
    # View the cropped mask as (gh, block, gw, block) so one vectorised mean
    # over the two intra-tile axes replaces the per-tile Python loop.
    tiles = edge_mask[:gh * block, :gw * block].reshape(gh, block, gw, block)
    return tiles.mean(axis=(1, 3)).astype(np.float32)
|
|
| |
def block_mean(img, block=16):
    """Average a 2-D image over non-overlapping ``block`` x ``block`` tiles.

    Trailing rows/columns that do not fill a complete tile are dropped.

    Returns:
        A ``float32`` array of shape ``(h // block, w // block)``.
    """
    n_rows, n_cols = (dim // block for dim in img.shape)
    cropped = img[:n_rows * block, :n_cols * block]
    # Axes 1 and 3 index the pixels inside each tile.
    tiles = cropped.reshape(n_rows, block, n_cols, block)
    return np.mean(tiles, axis=(1, 3)).astype(np.float32)
|
|
| |
def blockiness_boundary_map(gray, block_size=16, blur_sigma=1.0):
    """Measure intensity jumps across the block grid boundaries.

    For every column/row that starts a new ``block_size`` block, the absolute
    difference to the preceding column/row is stored; all other pixels are 0.
    The vertical and horizontal boundary maps are summed.

    Args:
        gray: 2-D grayscale image of any numeric dtype.
        block_size: Spacing of the block grid in pixels.
        blur_sigma: Sigma of the Gaussian smoothing applied to the result;
            ``<= 0`` skips the smoothing.

    Returns:
        A ``float32`` array with the same shape as ``gray``.
    """
    # Difference in float: with unsigned integer frames (e.g. uint8) the
    # subtraction would wrap around (10 - 20 -> 246) and report a bogus jump.
    gray = np.asarray(gray).astype(np.float32)
    h, w = gray.shape
    v = np.zeros((h, w), dtype=np.float32)
    hmap = np.zeros((h, w), dtype=np.float32)

    for x in range(block_size, w, block_size):
        v[:, x] = np.abs(gray[:, x] - gray[:, x - 1])
    for y in range(block_size, h, block_size):
        hmap[y, :] = np.abs(gray[y, :] - gray[y - 1, :])

    b = v + hmap
    if blur_sigma > 0:
        b = gaussian_blur(b, blur_sigma)
    return b
|
|
| |
| |
| |
def dct_energy_ratios(gray, block=16, low_k=4, mid_k=8, eps=1e-6):
    """Compute per-block DCT energy ratios for three frequency bands.

    Each ``block`` x ``block`` tile is transformed with ``cv2.dct`` and the
    squared coefficients are split into:

    * low:  coefficients ``[0:low_k, 0:low_k]``
    * mid:  coefficients ``[0:mid_k, 0:mid_k]`` minus the low band
    * high: everything else (clamped at 0)

    Every band energy is divided by the tile's total energy plus ``eps``.

    Returns:
        ``(r_low, r_mid, r_high)``, three ``float32`` arrays of shape
        ``(H // block, W // block)``.
    """
    height, width = gray.shape
    n_rows, n_cols = height // block, width // block
    low_ratio, mid_ratio, high_ratio = (
        np.zeros((n_rows, n_cols), dtype=np.float32) for _ in range(3)
    )

    for row, col in np.ndindex(n_rows, n_cols):
        top, left = row * block, col * block
        tile = gray[top:top + block, left:left + block].astype(np.float32)
        energy = np.square(cv2.dct(tile))

        low = energy[:low_k, :low_k].sum()
        mid = energy[:mid_k, :mid_k].sum() - low
        total = energy.sum()
        # Guard against a tiny negative remainder from float rounding.
        high = max(total - (low + mid), 0.0)

        scale = total + eps
        low_ratio[row, col] = low / scale
        mid_ratio[row, col] = mid / scale
        high_ratio[row, col] = high / scale

    return low_ratio, mid_ratio, high_ratio
|
|
| |
def temporal_fft_map(gray_seq, *, block=16, hf_start_bin=2, eps=1e-6):
    """Derive per-block motion and flicker maps from a temporal FFT.

    Each frame is reduced to a grid of block means, then a real FFT is taken
    along the time axis for every grid cell.

    Args:
        gray_seq: Iterable of 2-D grayscale frames with identical shapes.
        block: Block size used for the per-frame block means.
        hf_start_bin: First FFT bin counted as "high frequency".
        eps: Stabiliser added to the denominators.

    Returns:
        ``(motion, flicker)`` as ``float32`` grids. ``motion`` is the non-DC
        energy divided by the DC energy; ``flicker`` is the energy in bins
        ``>= hf_start_bin`` divided by the non-DC energy. When the spectrum
        has only the DC bin, both outputs are the same all-zero array.
    """
    series = np.stack(
        [block_mean(frame.astype(np.float32), block=block) for frame in gray_seq],
        axis=0,
    )

    spectrum = np.fft.rfft(series, axis=0)
    power = (np.square(spectrum.real) + np.square(spectrum.imag)).astype(np.float32)

    dc_power = power[0]
    if power.shape[0] <= 1:
        zeros = np.zeros_like(dc_power, dtype=np.float32)
        return zeros, zeros

    ac_power = power[1:]
    ac_total = ac_power.sum(axis=0)
    motion = ac_total / (dc_power + eps)

    # ac_power drops the DC bin, so FFT bin k lives at index k - 1.
    first_hf = max(int(hf_start_bin) - 1, 0)
    if first_hf < ac_power.shape[0]:
        flicker = ac_power[first_hf:].sum(axis=0) / (ac_total + eps)
    else:
        flicker = np.zeros_like(ac_total, dtype=np.float32)

    return motion.astype(np.float32), flicker.astype(np.float32)
|
|
def fuse_temporal_maps(motion_grid, flicker_grid, *, beta=0.5):
    """Combine motion and flicker grids into one normalised temporal map.

    The min-max normalised motion is scaled by ``(1 - beta) + beta * flicker``
    (flicker clipped to [0, 1]) and the product is min-max normalised again.
    """
    motion01 = norm01(motion_grid)
    flicker01 = np.clip(flicker_grid, 0.0, 1.0)
    gain = (1.0 - beta) + beta * flicker01
    return norm01(motion01 * gain)
|
|
| |
| |
| |
def compute_twostream_dct(
    gray_seq,
    *,
    block=16,
):
    """Compute per-block artifact and structure weight grids for one window.

    The first frame of ``gray_seq`` is the anchor. Four cues are computed on
    the block grid and summed with equal weight:

    * ringing: edge fraction times mid+high DCT energy, on blocks whose edge
      fraction is at least 0.05, scaled by the 99th percentile;
    * blur: ``1 - (0.5 * mid + high)`` DCT energy, gated by the normalised
      block-mean Sobel magnitude;
    * blockiness: normalised block mean of the block-boundary jump map;
    * temporal: FFT-based motion/flicker map for windows of 4+ frames, the
      mean absolute change of normalised mid+high DCT energy between
      consecutive frames for 2-3 frames, and zeros for a single frame.

    Args:
        gray_seq: Non-empty sequence of 2-D grayscale frames, anchor first.
        block: Block size in pixels used for every grid.

    Returns:
        ``(w_art, w_str, debug)`` where ``w_art`` is the min-max normalised
        artifact grid, ``w_str`` is ``1 - w_art`` and ``debug`` is a dict of
        the intermediate grids plus the pixel-level edge mask.

    Raises:
        ValueError: If ``gray_seq`` is empty.
    """
    num_frames = len(gray_seq)
    if num_frames == 0:
        raise ValueError("gray_seq must contain at least one frame")
    gray_anchor = gray_seq[0]

    # DCT band ratios of the anchor frame.
    anchor_low_grid, anchor_mid_grid, anchor_high_grid = dct_energy_ratios(
        gray_anchor, block=block
    )
    anchor_mh = anchor_mid_grid + anchor_high_grid

    # Ringing: mid/high-frequency energy on blocks that contain enough edges.
    edge_mask = edge_sobel(gray_anchor)
    edge_frac = block_edge(edge_mask, block=block)
    ring_score = np.maximum(anchor_mh, 0.0)
    edge_min_frac = 0.05
    ringing_grid = np.where(
        edge_frac >= edge_min_frac, edge_frac * ring_score, 0.0
    ).astype(np.float32)
    s = np.percentile(ringing_grid, 99) + 1e-6
    ringing_grid01 = np.clip(ringing_grid / s, 0.0, 1.0)

    # Blur: little mid/high energy where the block still has gradient content.
    hf = 0.5 * anchor_mid_grid + 1.0 * anchor_high_grid
    blur_raw = np.clip(1.0 - hf, 0.0, 1.0)
    sobel_g = gradient_mag(gray_anchor).astype(np.float32)
    sobel_g_grid = norm01(block_mean(sobel_g, block=block))
    blur_grid = np.clip(blur_raw * sobel_g_grid, 0.0, 1.0)

    # Blockiness: intensity jumps along the block grid.
    boundary_pix = blockiness_boundary_map(gray_anchor, block_size=block)
    blockiness_grid = norm01(block_mean(boundary_pix, block=block))

    # Temporal cue. Every window length gets a branch so temporal_grid is
    # always bound (previously 1- and 3-frame windows raised
    # UnboundLocalError).
    if num_frames >= 4:
        motion_grid, flick_grid = temporal_fft_map(
            gray_seq, block=block, hf_start_bin=2
        )
        temporal_grid = fuse_temporal_maps(motion_grid, flick_grid, beta=0.5)
    elif num_frames >= 2:
        # Too few frames for a useful spectrum: use the mean absolute change
        # between consecutive frames (for 2 frames this is |E[1] - E[0]|).
        mh_frames = [anchor_mh]
        for g in gray_seq[1:]:
            _, r_mid, r_high = dct_energy_ratios(g, block=block)
            mh_frames.append(r_mid + r_high)
        E_stack = norm01(np.stack(mh_frames, axis=0))
        temporal_grid = norm01(np.abs(np.diff(E_stack, axis=0)).mean(axis=0))
    else:
        # A single frame carries no temporal information.
        temporal_grid = np.zeros_like(anchor_high_grid, dtype=np.float32)

    w_art = norm01(
        1.0 * ringing_grid01
        + 1.0 * blur_grid
        + 1.0 * blockiness_grid
        + 1.0 * temporal_grid
    )
    w_str = 1.0 - w_art

    debug = {
        "dct_low_grid": anchor_low_grid,
        "dct_mid_grid": anchor_mid_grid,
        "dct_high_grid": anchor_high_grid,
        "ringing_grid": ringing_grid01,
        "edge_px": edge_mask,
        "blur_grid": blur_grid,
        "blockiness_grid": blockiness_grid,
        "temporal_grid": temporal_grid,
    }
    return w_art, w_str, debug
|
|
| |
| |
| |
def save_panel(out_png, frame_rgb, w_art, w_str, debug):
    """Render the diagnostic maps of one anchor into a 3x4 panel image.

    Args:
        out_png: Destination image path. Its parent directory is created if
            the path has one; a bare file name is written to the current
            working directory.
        frame_rgb: Anchor frame shown as-is in the first cell.
        w_art: Artifact weight grid.
        w_str: Structure weight grid.
        debug: Dict with the keys ``dct_low_grid``, ``dct_mid_grid``,
            ``dct_high_grid``, ``edge_px``, ``ringing_grid``, ``blur_grid``,
            ``blockiness_grid`` and ``temporal_grid``.
    """
    fig = plt.figure(figsize=(16, 9), dpi=160)

    def add(ax_i, title, img, cmap=None, vmin=0, vmax=1):
        """Draw ``img`` into cell ``ax_i`` of the 3x4 grid without axes."""
        ax = fig.add_subplot(3, 4, ax_i)
        ax.set_title(title)
        if cmap is None:
            ax.imshow(img)
        else:
            ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.axis("off")

    try:
        add(1, "Frame_t (anchor)", frame_rgb)
        add(2, "DCT LOW (grid)", norm01(debug["dct_low_grid"]), cmap="viridis")
        add(3, "DCT MID (grid)", norm01(debug["dct_mid_grid"]), cmap="viridis")
        add(4, "DCT HIGH (grid)", norm01(debug["dct_high_grid"]), cmap="viridis")

        add(5, "EDGE (mask)", debug["edge_px"], cmap="viridis")
        add(6, "RINGING (mid/high, grid)", debug["ringing_grid"], cmap="viridis")
        add(7, "BLUR (lowpass, grid)", debug["blur_grid"], cmap="viridis")
        add(8, "BLOCKINESS (boundary, grid)", debug["blockiness_grid"], cmap="viridis")
        add(9, "TEMPORAL (grid)", debug["temporal_grid"], cmap="viridis")

        add(10, "W_art", w_art, cmap="viridis")
        add(11, "W_str", w_str, cmap="viridis")

        # os.makedirs("") raises, so only create a directory when the path
        # actually has a directory component.
        out_dir = os.path.dirname(out_png)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.tight_layout()
        fig.savefig(out_png)
    finally:
        # Always release the figure, even if rendering or saving failed.
        plt.close(fig)
|
|
|
|
| |
| |
| |
def main():
    """Command-line entry point: compute artifact/structure maps per anchor.

    Opens the video, asks ``sample_frames_uniform`` for anchor indices,
    caches the frames reported by ``collect_needed`` and, for every anchor
    whose window ``read_window_from_cache`` can provide, runs
    ``compute_twostream_dct`` and (unless ``--no_panel`` is given) writes a
    diagnostic panel PNG into ``--out_dir``.

    Raises:
        RuntimeError: If the video reports one frame or fewer.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--video", default="/home/xinyi/Project/FD-VQA/test_videos/SDR_Animal_5ngj.mp4")
    parser.add_argument("--out_dir", default="/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only")
    parser.add_argument("--size", type=int, default=224)

    parser.add_argument("--num_anchors", type=int, default=16)
    parser.add_argument("--win", type=int, default=6)
    parser.add_argument("--win_step", type=int, default=1)
    parser.add_argument("--block", type=int, default=16)

    parser.add_argument("--no_panel", action="store_true")
    args = parser.parse_args()
    if args.block <= 0:
        # A non-positive block size would only fail later with an obscure
        # ZeroDivisionError / empty-grid error inside the map computation.
        parser.error("--block must be a positive integer")
    os.makedirs(args.out_dir, exist_ok=True)

    vr = VideoReader(args.video, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames <= 1:
        raise RuntimeError(f"Video too short / frame count unavailable: total_frames={total_frames}")
    print("total_frames:", total_frames)

    size = args.size
    win = args.win
    win_step = args.win_step
    num_anchors = args.num_anchors

    anchor_idxs = sample_frames_uniform(total_frames, num_anchors, win=win, win_step=win_step)
    needed = collect_needed(anchor_idxs, total_frames, win, win_step)
    print("anchor_idxs:", anchor_idxs)
    cache = cache_needed_frames(vr, needed, size)
    print("cached:", len(cache), "needed:", len(needed))

    # Frame indices of every window that was actually processed; its length
    # doubles as the running output counter.
    anchors_kept = []

    for anchor in tqdm(anchor_idxs, desc="Processing anchors (DCT)"):
        out = read_window_from_cache(cache, anchor, total_frames, win, win_step)
        if out is None:
            # No window available for this anchor: skip it.
            continue
        anchor_frame, gray_seq, idxs = out

        w_art, w_str, dbg = compute_twostream_dct(
            gray_seq,
            block=args.block,
        )
        anchors_kept.append(idxs)

        if not args.no_panel:
            image_idx = len(anchors_kept)
            save_panel(
                os.path.join(args.out_dir, f"anchor_{anchor:03d}_{image_idx:02d}.png"),
                anchor_frame,
                w_art,
                w_str,
                dbg,
            )

    print(f"Done. Outputs saved to: {args.out_dir}")
    print(anchors_kept)
    print(f"total_frames={total_frames}, num_anchors_target={num_anchors}, anchors_produced={len(anchors_kept)}")


if __name__ == "__main__":
    main()