| import torch |
| import math |
| import random |
| import numpy as np |
| from PIL import Image |
|
|
| def random_insert_latent_frame( |
| image_latent: torch.Tensor, |
| noisy_model_input: torch.Tensor, |
| target_latents: torch.Tensor, |
| input_intervals: torch.Tensor, |
| output_intervals: torch.Tensor, |
| special_info |
| ): |
| """ |
| Inserts latent frames into noisy input, pads targets, and builds flattened intervals with flags. |
| |
| Args: |
| image_latent: [B, latent_count, C, H, W] |
| noisy_model_input:[B, F, C, H, W] |
| target_latents: [B, F, C, H, W] |
| input_intervals: [B, N, frames_per_latent, L] |
| output_intervals: [B, M, frames_per_latent, L] |
| |
| For each sample randomly choose: |
| Mode A (50%): |
| - Insert two image_latent frames at start of noisy input and targets. |
| - Pad target_latents by prepending two zero-frames. |
| - Pad input_intervals by repeating its last group once. |
| Mode B (50%): |
| - Insert one image_latent frame at start and repeat last noisy frame at end. |
| - Pad target_latents by prepending one one-frame and appending last target frame. |
| - Pad output_intervals by repeating its last group once. |
| |
| After padding intervals, flatten each group from [frames_per_latent, L] to [frames_per_latent * L], |
| then append a 4-element flag (1 for input groups, 0 for output groups). |
| |
| Returns: |
| outputs: Tensor [B, F+2, C, H, W] |
| new_targets: Tensor [B, F+2, C, H, W] |
| masks: Tensor [B, F+2] bool mask of latent inserts |
| intervals: Tensor [B, N+M+1, fpl * L + 4] |
| """ |
| B, F, C, H, W = noisy_model_input.shape |
| _, N, fpl, L = input_intervals.shape |
| _, M, _, _ = output_intervals.shape |
| device = noisy_model_input.device |
|
|
    new_F = F + 1 if special_info == "just_one" else F + 2
    # Modes A and B pad one extra interval group; "just_one" does not.
    combined_groups = N + M if special_info == "just_one" else N + M + 1
    feature_len = fpl * L

    outputs = torch.empty((B, new_F, C, H, W), device=device,
                          dtype=noisy_model_input.dtype)
    masks = torch.zeros((B, new_F), dtype=torch.bool, device=device)

    # Each flattened group also carries a 4-element input/output flag.
    intervals = torch.empty((B, combined_groups, feature_len + 4), device=device,
                            dtype=input_intervals.dtype)
    new_targets = torch.empty((B, new_F, C, H, W), device=device,
                              dtype=target_latents.dtype)
|
|
| for b in range(B): |
| latent = image_latent[b, 0] |
| frames = noisy_model_input[b] |
| tgt = target_latents[b] |
|
|
        # "use_a" forces Mode A below; otherwise Modes A and B are a 50/50 coin flip.
        limit = 10 if special_info == "use_a" else 0.5
        if special_info == "just_one":
            # Insert a single conditioning latent frame at the start.
            outputs[b, 0] = latent
            masks[b, :1] = True
            outputs[b, 1:] = frames

            # Sentinel target value for the inserted frame.
            large_number = torch.ones_like(tgt[0]) * 10000
            new_targets[b, 0] = large_number
            new_targets[b, 1:] = tgt

            # No interval padding in this mode.
            in_groups = input_intervals[b]
            out_groups = output_intervals[b]
        elif random.random() < limit:
            # Mode A: insert two conditioning latent frames at the start.
            outputs[b, 0] = latent
            outputs[b, 1] = latent
            masks[b, :2] = True
            outputs[b, 2:] = frames

            # Sentinel targets for the two inserted frames.
            large_number = torch.ones_like(tgt[0]) * 10000
            new_targets[b, 0] = large_number
            new_targets[b, 1] = large_number
            new_targets[b, 2:] = tgt

            # Pad the input intervals by repeating the last group once.
            pad_group = input_intervals[b, -1:].clone()
            in_groups = torch.cat([input_intervals[b], pad_group], dim=0)
            out_groups = output_intervals[b]
        else:
            # Mode B: insert one latent frame at the start and repeat the last
            # noisy frame at the end.
            outputs[b, 0] = latent
            masks[b, 0] = True
            outputs[b, 1:new_F - 1] = frames
            outputs[b, new_F - 1] = frames[-1]

            # Zero target for the inserted frame; duplicate the last target.
            zero = torch.zeros_like(tgt[0])
            new_targets[b, 0] = zero
            new_targets[b, 1:new_F - 1] = tgt
            new_targets[b, new_F - 1] = tgt[-1]

            # Pad the output intervals by repeating the last group once.
            in_groups = input_intervals[b]
            pad_group = output_intervals[b, -1:].clone()
            out_groups = torch.cat([output_intervals[b], pad_group], dim=0)
|
|
| |
        # Flatten each group to [fpl * L] and append the 4-element flag:
        # ones mark input groups, zeros mark output groups.
        flat_in = in_groups.reshape(-1, feature_len)
        in_flags = torch.ones((flat_in.shape[0], 4), device=device,
                              dtype=intervals.dtype)
        proc_in = torch.cat([flat_in, in_flags], dim=1)

        flat_out = out_groups.reshape(-1, feature_len)
        out_flags = torch.zeros((flat_out.shape[0], 4), device=device,
                                dtype=intervals.dtype)
        proc_out = torch.cat([flat_out, out_flags], dim=1)

        intervals[b] = torch.cat([proc_in, proc_out], dim=0)
|
|
| return outputs, new_targets, masks, intervals |
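

# A minimal, self-contained sketch of how random_insert_latent_frame might be
# called. The shapes and the empty special_info below are illustrative
# assumptions, not values taken from a real training pipeline.
def _demo_random_insert_latent_frame():
    B, F, C, H, W = 2, 17, 16, 8, 8   # batch, frames, channels, spatial size
    N, M, fpl, L = 1, 17, 4, 2        # input/output groups, frames per latent, interval length
    image_latent = torch.randn(B, 1, C, H, W)
    noisy = torch.randn(B, F, C, H, W)
    targets = torch.randn(B, F, C, H, W)
    in_iv = torch.zeros(B, N, fpl, L)
    out_iv = torch.zeros(B, M, fpl, L)
    outs, tgts, masks, ivs = random_insert_latent_frame(
        image_latent, noisy, targets, in_iv, out_iv, special_info="")
    # Expected: outs/tgts are [B, F+2, C, H, W]; ivs is [B, N+M+1, fpl*L+4].
    print(outs.shape, tgts.shape, masks.shape, ivs.shape)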
|
|
|
|
|
|
|
|
| def transform_intervals( |
| intervals: torch.Tensor, |
| frames_per_latent: int = 4, |
| repeat_first: bool = True |
| ) -> torch.Tensor: |
| """ |
| Pad and reshape intervals into [B, num_latent_frames, frames_per_latent, L]. |
| |
| Args: |
| intervals: Tensor of shape [B, N, L] |
| frames_per_latent: number of frames per latent group (e.g., 4) |
| repeat_first: if True, pad at the beginning by repeating the first row; otherwise pad at the end by repeating the last row. |
| |
| Returns: |
| Tensor of shape [B, num_latent_frames, frames_per_latent, L] |
| """ |
| B, N, L = intervals.shape |
| num_latent = math.ceil(N / frames_per_latent) |
| target_N = num_latent * frames_per_latent |
| pad_count = target_N - N |
|
|
| if pad_count > 0: |
| |
| pad_row = intervals[:, :1, :] if repeat_first else intervals[:, -1:, :] |
| |
| pad = pad_row.repeat(1, pad_count, 1) |
| |
| if repeat_first: |
| expanded = torch.cat([pad, intervals], dim=1) |
| else: |
| expanded = torch.cat([intervals, pad], dim=1) |
| else: |
| expanded = intervals[:, :target_N, :] |
|
|
| |
| return expanded.view(B, num_latent, frames_per_latent, L) |
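

# Illustrative check of transform_intervals: 17 interval rows with
# frames_per_latent=4 are padded to 20 (repeating the first row when
# repeat_first=True) and regrouped into 5 latent groups of 4.
def _demo_transform_intervals():
    iv = torch.arange(17 * 2, dtype=torch.float).view(1, 17, 2)  # [B=1, N=17, L=2]
    grouped = transform_intervals(iv, frames_per_latent=4, repeat_first=True)
    print(grouped.shape)  # torch.Size([1, 5, 4, 2])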
|
|
| def build_blur(frame_paths, gamma=2.2): |
| """ |
| Simulate motion blur using inverse-gamma (linear-light) summation: |
| - Load each image, convert to float32 sRGB [0,255] |
| - Linearize via inverse gamma: linear = (img/255)^gamma |
| - Sum linear values, average, then re-encode via gamma: (linear_avg)^(1/gamma)*255 |
| Returns a uint8 numpy array. |
| """ |
| acc_lin = None |
| for p in frame_paths: |
| img = np.array(Image.open(p).convert('RGB'), dtype=np.float32) |
| |
| lin = np.power(img / 255.0, gamma) |
| acc_lin = lin if acc_lin is None else acc_lin + lin |
| |
| avg_lin = acc_lin / len(frame_paths) |
| |
| srgb = np.power(avg_lin, 1.0 / gamma) * 255.0 |
| return np.clip(srgb, 0, 255).astype(np.uint8) |
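

# Numeric sanity check of the linear-light averaging used in build_blur:
# averaging a black frame and a white frame in linear light gives a bright
# grey (~186) rather than the naive sRGB midpoint of 128. This sketch works on
# in-memory arrays instead of the file paths build_blur expects.
def _demo_linear_light_average(gamma=2.2):
    dark = np.zeros((1, 1, 3), dtype=np.float32)
    bright = np.full((1, 1, 3), 255.0, dtype=np.float32)
    lin_avg = (np.power(dark / 255.0, gamma) + np.power(bright / 255.0, gamma)) / 2
    srgb = np.power(lin_avg, 1.0 / gamma) * 255.0
    print(srgb.astype(np.uint8))  # [[[186 186 186]]]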
|
|
def generate_1x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1, start=None):
| """ |
| 1× mode at arbitrary base_rate (units of 1/240s): |
| - Treat each output step as the sum of `base_rate` consecutive raw frames. |
| - Pick window size W ∈ [1, output_len] |
| - Randomly choose start index so W*base_rate frames fit |
| - Group raw frames into W groups of length base_rate |
| - Build blur image over all W*base_rate frames for input |
| - For each group, build a blurred output frame by summing its base_rate frames |
| - Pad sequence of W blurred frames to output_len by repeating last blurred frame |
| - Input interval always [-0.5, 0.5] |
| - Output intervals reflect each group’s coverage within [-0.5,0.5] |
| """ |
| N = len(frame_paths) |
| max_w = min(output_len, N // base_rate) |
| max_w = min(max_w, window_max) |
| W = random.randint(1, max_w) |
| if start is not None: |
| |
| assert N >= W * base_rate, f"Not enough frames for base_rate={base_rate}, need {W * base_rate}, got {N}" |
| else: |
| start = random.randint(0, N - W * base_rate) |
| |
|
|
| |
| group_starts = [start + i * base_rate for i in range(W)] |
| |
| blur_paths = [] |
| for gs in group_starts: |
| blur_paths.extend(frame_paths[gs:gs + base_rate]) |
| blur_img = build_blur(blur_paths) |
|
|
| |
| seq = [] |
| for gs in group_starts: |
| group = frame_paths[gs:gs + base_rate] |
| seq.append(build_blur(group)) |
| |
| seq += [seq[-1]] * (output_len - len(seq)) |
|
|
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) |
| |
| step = 1.0 / W |
| intervals = [[-0.5 + i * step, -0.5 + (i + 1) * step] for i in range(W)] |
| num_frames = len(intervals) |
| intervals += [intervals[-1]] * (output_len - W) |
| output_intervals = torch.tensor(intervals, dtype=torch.float) |
|
|
| return blur_img, seq, input_interval, output_intervals, num_frames |
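

# Sketch of calling generate_1x_sequence. `frame_paths` is a hypothetical list
# of 240 fps frame files on disk (produced by some extraction step not shown in
# this file), so this is illustrative rather than directly runnable here.
def _demo_generate_1x(frame_paths):
    blur_img, seq, in_iv, out_iv, n = generate_1x_sequence(
        frame_paths, window_max=16, output_len=17, base_rate=2)
    # blur_img covers the whole window; seq holds n real blurred steps padded
    # to 17 frames; out_iv tiles [-0.5, 0.5] into n equal sub-intervals.
    print(blur_img.shape, len(seq), in_iv, out_iv[:n])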
|
|
def generate_2x_sequence(frame_paths, window_max=16, output_len=17, base_rate=1):
| """ |
| 2× mode: |
| - Logical window of W output-steps so that 2*W ≤ output_len |
| - Raw window spans W*base_rate frames |
| - Build blur only over that raw window (flattened) for input |
| - before_count = W//2, after_count = W - before_count |
    - Define before_count groups before the window, W inside it, and after_count after it, each of length base_rate
| - Build blurred frames for each group |
| - Pad sequence of 2*W blurred frames to output_len by repeating last |
| - Input interval always [-0.5,0.5] |
| - Output intervals relative to window: each group’s center |
| """ |
| N = len(frame_paths) |
| max_w = min(output_len // 2, N // base_rate) |
| max_w = min(max_w, window_max) |
| W = random.randint(1, max_w) |
| before_count = W // 2 |
| after_count = W - before_count |
| |
| min_start = before_count * base_rate |
| max_start = N - (W + after_count) * base_rate |
| |
| assert max_start >= min_start, f"Cannot satisfy before/after window for W={W}, base_rate={base_rate}, N={N}" |
| start = random.randint(min_start, max_start) |
|
|
|
|
| |
| window_starts = [start + i * base_rate for i in range(W)] |
| |
| blur_paths = [] |
| for gs in window_starts: |
| blur_paths.extend(frame_paths[gs:gs + base_rate]) |
|
|
|
|
| blur_img = build_blur(blur_paths) |
|
|
| |
    # Context groups strictly before and after the blur window (clamped to valid indices).
    before_starts = [max(0, start - (i + 1) * base_rate) for i in range(before_count)][::-1]
    after_starts = [min(N - base_rate, start + W * base_rate + i * base_rate) for i in range(after_count)]
|
|
| |
| group_starts = before_starts + window_starts + after_starts |
| |
| seq = [] |
| for gs in group_starts: |
| group = frame_paths[gs:gs + base_rate] |
| seq.append(build_blur(group)) |
| |
| seq += [seq[-1]] * (output_len - len(seq)) |
|
|
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) |
| |
| half = 0.5 / W |
| centers = [((gs - start) / (W * base_rate)) - 0.5 + half |
| for gs in group_starts] |
| intervals = [[c - half, c + half] for c in centers] |
| num_frames = len(intervals) |
| intervals += [intervals[-1]] * (output_len - len(intervals)) |
| output_intervals = torch.tensor(intervals, dtype=torch.float) |
|
|
| return blur_img, seq, input_interval, output_intervals, num_frames |
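

# Quick check of the 2x interval geometry: every group interval has width 1/W,
# and the W in-window intervals exactly tile [-0.5, 0.5] (the before/after
# context groups fall outside that range). Values below are illustrative.
def _demo_2x_intervals(W=4, base_rate=1, start=100):
    half = 0.5 / W
    group_starts = [start + i * base_rate for i in range(W)]
    centers = [((gs - start) / (W * base_rate)) - 0.5 + half for gs in group_starts]
    print([[round(c - half, 3), round(c + half, 3)] for c in centers])
    # -> [[-0.5, -0.25], [-0.25, 0.0], [0.0, 0.25], [0.25, 0.5]]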
|
|
|
|
| def generate_large_blur_sequence(frame_paths, window_max=16, output_len=17, base_rate=1): |
| """ |
| Large blur mode (fixed output_len=25) with instantaneous outputs: |
| - Raw window spans 25 * base_rate consecutive frames |
| - Build blur over that full raw window for input |
| - For output sequence: |
| • Pick 1 raw frame every `base_rate` (group_starts) |
| • Each output frame is the instantaneous frame at that raw index |
| - Input interval always [-0.5, 0.5] |
| - Output intervals reflect each 1-frame slice’s coverage within the blur window, |
| leaving gaps between. |
| """ |
| N = len(frame_paths) |
| total_raw = window_max * base_rate |
| assert N >= total_raw, f"Not enough frames for base_rate={base_rate}, need {total_raw}, got {N}" |
| start = random.randint(0, N - total_raw) |
|
|
| |
| raw_block = frame_paths[start:start + total_raw] |
| blur_img = build_blur(raw_block) |
|
|
| |
| seq = [] |
| group_starts = [start + i * base_rate for i in range(window_max)] |
| for gs in group_starts: |
| img = np.array(Image.open(frame_paths[gs]).convert('RGB'), dtype=np.uint8) |
| seq.append(img) |
| |
| seq += [seq[-1]] * (output_len - len(seq)) |
|
|
| |
| |
| intervals = [] |
| for gs in group_starts: |
| t0 = (gs - start) / total_raw - 0.5 |
| t1 = (gs + 1 - start) / total_raw - 0.5 |
| intervals.append([t0, t1]) |
| num_frames = len(intervals) |
| intervals += [intervals[-1]] * (output_len - len(intervals)) |
| output_intervals = torch.tensor(intervals, dtype=torch.float) |
|
|
| |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) |
| return blur_img, seq, input_interval, output_intervals, num_frames |
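

# Illustration of the gapped output intervals in large-blur mode: with
# window_max=4 and base_rate=4 each output frame covers only 1/16 of the blur
# window, leaving gaps between consecutive intervals. Numbers are illustrative.
def _demo_large_blur_intervals(window_max=4, base_rate=4, start=0):
    total_raw = window_max * base_rate
    group_starts = [start + i * base_rate for i in range(window_max)]
    print([[round((gs - start) / total_raw - 0.5, 3),
            round((gs + 1 - start) / total_raw - 0.5, 3)] for gs in group_starts])
    # e.g. [[-0.5, -0.438], [-0.25, -0.188], [0.0, 0.062], [0.25, 0.312]]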
|
|
| def generate_test_case(frame_paths, |
| window_max=16, |
| output_len=17, |
| in_start=None, |
| in_end=None, |
| out_start=None, |
                       out_end=None,
| center=None, |
| mode="1x", |
| fps=240): |
| """ |
| Generate blurred input + a target sequence + normalized intervals. |
| |
| Args: |
| frame_paths: list of all frame filepaths |
| window_max: number of groups/bins W |
| output_len: desired length of the output sequence |
| in_start, in_end: integer indices defining the raw window [in_start, in_end) |
| mode: one of "1x", "2x", or "lb" |
| fps: frames-per-second (only used to override mode=="2x" if fps==120) |
| |
| Returns: |
| blur_img: np.ndarray of the global blur over the window |
| seq: list of np.ndarray, length = output_len (blured groups or raw frames) |
| input_interval: torch.Tensor [[-0.5, 0.5]] |
| output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5,0.5] |
| """ |
| |
| raw_paths = frame_paths[in_start:in_end] |
|
|
| blur_img = build_blur(raw_paths) |
|
|
| |
| |
| seq = [ |
| np.array(Image.open(p).convert("RGB"), dtype=np.uint8) |
| for p in frame_paths[out_start:out_end] |
| ] |
|
|
| |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) |
|
|
| |
    base_rate = 240 // fps
| if mode == "1x": |
| assert in_start == out_start and in_end == out_end |
| |
| W = (out_end - out_start) // base_rate |
| |
| group_starts = [out_start + i * base_rate for i in range(W)] |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] |
|
|
| elif mode == "2x": |
| W = (out_end - out_start) // base_rate |
| |
| group_starts = [out_start + i * base_rate for i in range(W)] |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] |
|
|
| elif mode == "lb": |
| W = (out_end - out_start) // base_rate |
| |
| group_starts = [in_start + i * base_rate for i in range(W)] |
| group_ends = [s + 1 for s in group_starts] |
|
|
| else: |
| raise ValueError(f"Unsupported mode: {mode}") |
|
|
| |
| |
| summed_seq = [] |
| for s, e in zip(group_starts, group_ends): |
| |
        # Clamp each group to the raw input window so the slice passed to build_blur is never empty.
        s_clamped = max(in_start, min(s, in_end - 1))
        e_clamped = max(s_clamped + 1, min(e, in_end))
| |
| summed = build_blur(frame_paths[s_clamped:e_clamped]) |
| summed_seq.append(summed) |
|
|
| |
| if len(summed_seq) < output_len: |
| summed_seq += [summed_seq[-1]] * (output_len - len(summed_seq)) |
|
|
| |
| def normalize(x): |
| return (x - in_start) / (in_end - in_start) - 0.5 |
|
|
| intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)] |
| num_frames = len(intervals) |
| if len(intervals) < output_len: |
| intervals += [intervals[-1]] * (output_len - len(intervals)) |
| |
| output_intervals = torch.tensor(intervals, dtype=torch.float) |
|
|
| |
| return blur_img, summed_seq, input_interval, output_intervals, seq, num_frames |
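

# Sketch of a deterministic test case in "1x" mode: the input and output windows
# coincide, so the normalized output intervals tile [-0.5, 0.5]. `frame_paths`
# is a hypothetical list of 240 fps frames with at least 32 entries; nothing is
# loaded unless this is actually called.
def _demo_generate_test_case(frame_paths):
    blur_img, summed_seq, in_iv, out_iv, raw_seq, n = generate_test_case(
        frame_paths, output_len=17,
        in_start=0, in_end=32, out_start=0, out_end=32,
        mode="1x", fps=120)  # base_rate = 240 // 120 = 2, so n = 16 groups
    print(blur_img.shape, len(summed_seq), out_iv[:n])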
|
|
|
|
| def get_conditioning( |
| output_len=17, |
| in_start=None, |
| in_end=None, |
| out_start=None, |
| out_end=None, |
| mode="1x", |
| fps=240, |
| ): |
| """ |
| Generate normalized intervals conditioning singals. Just like the above function but without |
| loading any images (for inference only). |
| |
| Args: |
| output_len: desired length of the output sequence |
| in_start, in_end: integer indices defining the raw window [in_start, in_end) |
| mode: one of "1x", "2x", or "lb" |
| fps: frames-per-second (only used to override mode=="2x" if fps==120) |
| |
| Returns: |
| input_interval: torch.Tensor [[-0.5, 0.5]] |
| output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5,0.5] |
| """ |
|
|
| |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) |
|
|
| |
    base_rate = 240 // fps
| if mode == "1x": |
| assert in_start == out_start and in_end == out_end |
| |
| W = (out_end - out_start) // base_rate |
| |
| group_starts = [out_start + i * base_rate for i in range(W)] |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] |
|
|
| elif mode == "2x": |
| W = (out_end - out_start) // base_rate |
| |
| group_starts = [out_start + i * base_rate for i in range(W)] |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] |
|
|
| elif mode == "lb": |
| W = (out_end - out_start) // base_rate |
| |
| group_starts = [in_start + i * base_rate for i in range(W)] |
| group_ends = [s + 1 for s in group_starts] |
|
|
| else: |
| raise ValueError(f"Unsupported mode: {mode}") |
|
|
| |
| def normalize(x): |
| return (x - in_start) / (in_end - in_start) - 0.5 |
|
|
| intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)] |
| num_frames = len(intervals) |
| if len(intervals) < output_len: |
| intervals += [intervals[-1]] * (output_len - len(intervals)) |
| |
| output_intervals = torch.tensor(intervals, dtype=torch.float) |
|
|
| return input_interval, output_intervals, num_frames |
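

# Runnable example of get_conditioning for inference: a "2x"-style request
# where the 16-frame output window extends past the 8-frame input window on
# both sides, so the output intervals span [-1.0, 1.0] while the input interval
# stays [-0.5, 0.5]. The indices are arbitrary illustrative values.
def _demo_get_conditioning():
    in_iv, out_iv, n = get_conditioning(
        output_len=17, in_start=4, in_end=12, out_start=0, out_end=16,
        mode="2x", fps=240)
    print(in_iv)                     # tensor([[-0.5000,  0.5000]])
    print(n, out_iv[0], out_iv[-1])  # 16 tensor([-1.0000, -0.8750]) tensor([0.8750, 1.0000])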
|
|