import io
import os
import glob
import pickle
import random
import time
from pathlib import Path

import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageCms
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from controlnet_aux import CannyDetector, HEDdetector

from helpers import generate_1x_sequence, generate_2x_sequence, generate_large_blur_sequence, generate_test_case


def unpack_mm_params(p):
    if isinstance(p, (tuple, list)):
        return p[0], p[1]
    elif isinstance(p, (int, float)):
        return p, p
    raise TypeError(f'Unknown input parameter type. Parameter: {p}. Type: {type(p)}')
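
# Example behaviour of unpack_mm_params (a quick illustration, not part of the original API):
#   unpack_mm_params((320, 512)) -> (320, 512)
#   unpack_mm_params(512)        -> (512, 512)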


def resize_for_crop(image, min_h, min_w):
    img_h, img_w = image.shape[-2:]

    if img_h >= min_h and img_w >= min_w:
        # Scale down only as far as both dimensions still cover the target crop.
        coef = max(min_h / img_h, min_w / img_w)
    elif img_h <= min_h and img_w <= min_w:
        coef = max(min_h / img_h, min_w / img_w)
    else:
        coef = min_h / img_h if min_h > img_h else min_w / img_w

    out_h, out_w = int(img_h * coef), int(img_w * coef)
    resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
    return resized_image
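
# Worked example: a 720x1280 frame resized for a 320x512 crop uses
# coef = max(320/720, 512/1280) = 320/720, giving a 320x568 tensor that
# transforms.functional.center_crop can then trim to exactly 320x512.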


class BaseClass(Dataset):
    def __init__(
        self,
        data_dir,
        output_dir,
        image_size=(320, 512),
        hflip_p=0.5,
        controlnet_type='canny',
        split='train',
        *args,
        **kwargs
    ):
        self.split = split
        self.height, self.width = unpack_mm_params(image_size)
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.hflip_p = hflip_p
        self.image_size = image_size
        self.length = 0

    def __len__(self):
        return self.length

    def load_frames(self, frames):
        pixel_values = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().float()
        pixel_values = pixel_values / 127.5 - 1.0
        pixel_values = F.interpolate(
            pixel_values,
            size=(self.height, self.width),
            mode="bilinear",
            align_corners=False
        )
        return pixel_values
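
    # A quick arithmetic check of load_frames: pixel value 255 maps to
    # 255 / 127.5 - 1.0 = 1.0 and pixel value 0 maps to -1.0, so frames are
    # normalised to the usual [-1, 1] range before resizing.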

    def get_batch(self, idx):
        raise NotImplementedError('get_batch must be implemented by subclasses.')

    def __getitem__(self, idx):
        while True:
            try:
                video, caption, motion_blur = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length - 1)

        video = resize_for_crop(video, self.height, self.width)
        video = transforms.functional.center_crop(video, (self.height, self.width))
        data = {
            'video': video,
            'caption': caption,
        }
        return data


def load_as_srgb(path):
    img = Image.open(path)
    img = ImageOps.exif_transpose(img)

    if 'icc_profile' in img.info:
        icc = img.info['icc_profile']
        src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
        dst_profile = ImageCms.createProfile("sRGB")
        img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode='RGB')
    else:
        img = img.convert("RGB")
    return img


class GoProMotionBlurDataset(BaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.split == 'train':
            self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
        elif self.split in ['val', 'test']:
            self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
        else:
            raise ValueError(f"Unsupported split: {self.split}")

        pattern = os.path.join(self.blur_root, '*', '*.png')
        self.blur_paths = sorted(glob.glob(pattern))

        if self.split == 'val':
            self.blur_paths = self.blur_paths[:5]

        # Skip samples whose deblurred output has already been written.
        filtered_blur_paths = []
        for path in self.blur_paths:
            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
            full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
            if not os.path.exists(full_output_path):
                filtered_blur_paths.append(path)
        self.blur_paths = filtered_blur_paths

        self.window_size = 7
        self.pad = 2
        self.output_length = self.window_size + self.pad
        self.half_window = self.window_size // 2
        self.length = len(self.blur_paths)

        self.input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)

        # Per-frame output sub-intervals over [-0.5, 0.5]; the final interval
        # collapses to the right endpoint and is repeated for the padded frames.
        step = 1.0 / (self.window_size - 1)

        window_intervals = []
        for i in range(self.window_size):
            start = -0.5 + i * step
            if i < self.window_size - 1:
                end = -0.5 + (i + 1) * step
            else:
                end = 0.5
            window_intervals.append([start, end])

        intervals = window_intervals + [window_intervals[-1]] * self.pad
        self.output_interval = torch.tensor(intervals, dtype=torch.float)
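
        # For window_size=7 and pad=2 this yields a (9, 2) tensor roughly equal to
        # [[-0.5, -0.33], [-0.33, -0.17], [-0.17, 0.0], [0.0, 0.17], [0.17, 0.33],
        #  [0.33, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]] -- the last interval is
        # degenerate and is duplicated for the two padded frames.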

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        blur_path = self.blur_paths[idx]
        seq_name = os.path.basename(os.path.dirname(blur_path))
        frame_name = os.path.basename(blur_path)
        center_idx = int(os.path.splitext(frame_name)[0])

        start_idx = center_idx - self.half_window
        end_idx = center_idx + self.half_window

        # Load the sharp frames centred on the blurred frame.
        sharp_dir = os.path.join(self.sharp_root, seq_name)
        frames = []
        for i in range(start_idx, end_idx + 1):
            sharp_filename = f"{i:06d}.png"
            sharp_path = os.path.join(sharp_dir, sharp_filename)
            img = Image.open(sharp_path).convert('RGB')
            frames.append(img)

        # Pad with the last frame so every sample has output_length frames.
        while len(frames) < self.output_length:
            frames.append(frames[-1])

        blur_img = Image.open(blur_path).convert('RGB')

        video = self.load_frames(np.array(frames))
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))
        data = {
            'file_name': os.path.join(seq_name, frame_name),
            'blur_img': blur_input,
            'video': video,
            "caption": "",
            'motion_blur_amount': torch.tensor(self.half_window, dtype=torch.long),
            'input_interval': self.input_interval,
            'output_interval': self.output_interval,
            "num_frames": self.window_size,
            "mode": "1x",
        }
        return data


class OutsidePhotosDataset(BaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.image_paths = sorted(glob.glob(os.path.join(self.data_dir, '**', '*.*'), recursive=True))

        INTERVALS = [
            {"in_start": 0, "in_end": 16, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "1x", "fps": 240},
            {"in_start": 4, "in_end": 12, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "2x", "fps": 240},
        ]

        # Keep only the samples whose deblurred output has not been produced yet.
        self.cleaned_intervals = []
        for image_path in self.image_paths:
            for interval in INTERVALS:
                i = interval.copy()
                i['video_name'] = image_path
                video_name = i['video_name']
                mode = i['mode']

                vid_name_wo_extension = os.path.splitext(os.path.relpath(video_name, self.data_dir))[0]
                output_name = f"{vid_name_wo_extension}_{mode}.mp4"

                # NOTE: the output location is hard-coded for this dataset.
                full_output_path = os.path.join("/datasets/sai/gencam/cogvideox/training/cogvideox-outsidephotos/deblurred", output_name)

                if not os.path.exists(full_output_path):
                    self.cleaned_intervals.append(i)

        self.length = len(self.cleaned_intervals)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        interval = self.cleaned_intervals[idx]

        in_start = interval['in_start']
        in_end = interval['in_end']
        out_start = interval['out_start']
        out_end = interval['out_end']
        center = interval['center']
        window = interval['window_size']
        mode = interval['mode']
        fps = interval['fps']

        image_path = interval['video_name']
        blur_img_original = load_as_srgb(image_path)
        # PIL's Image.size is (width, height).
        W, H = blur_img_original.size

        # There is no ground-truth sequence for outside photos, so a dummy
        # frame list of the required length is passed to generate_test_case.
        frame_paths = ["../assets/dummy_image.png" for _ in range(window)]

        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end,
            out_start=out_start, out_end=out_end, center=center, mode=mode, fps=fps
        )
        file_name = image_path

        relative_file_name = os.path.relpath(file_name, self.data_dir)
        base_dir = os.path.dirname(relative_file_name)
        frame_stem = os.path.splitext(os.path.basename(file_name))[0]

        new_filename = f"{frame_stem}_{mode}.png"

        blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0]))

        relative_file_name = os.path.join(base_dir, new_filename)

        blur_input = self.load_frames(np.expand_dims(blur_img, 0).copy())
        video = self.load_frames(np.stack(seq_frames, axis=0))

        data = {
            'file_name': relative_file_name,
            "original_size": (H, W),
            'blur_img': blur_input,
            'video': video,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
        }
        return data


class FullMotionBlurDataset(BaseClass):
    """
    A dataset that randomly selects among 1x, 2x, or large-blur modes per sample.
    Uses category-specific <split>_list.txt files under each subfolder of FullDataset to assemble sequences.
    In the 'test' split it instead loads precomputed intervals from intervals_test.pkl and uses generate_test_case.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.seq_dirs = []

        if self.split == 'test':
            pkl_path = os.path.join(self.data_dir, 'intervals_test.pkl')
            with open(pkl_path, 'rb') as f:
                self.test_intervals = pickle.load(f)
            assert self.test_intervals, f"No test intervals found in {pkl_path}"

            # Drop intervals whose deblurred output has already been written.
            cleaned_intervals = []
            for interval in self.test_intervals:
                in_start = interval['in_start']
                in_end = interval['in_end']
                out_start = interval['out_start']
                out_end = interval['out_end']
                center = interval['center']
                window = interval['window_size']
                mode = interval['mode']
                fps = interval['fps']
                category, seq = interval['video_name'].split('/')
                seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
                frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
                rel_path = os.path.relpath(frame_paths[center], self.data_dir)
                rel_path = os.path.splitext(rel_path)[0]

                output_name = (
                    f"{rel_path}_"
                    f"in{in_start:04d}_ie{in_end:04d}_"
                    f"os{out_start:04d}_oe{out_end:04d}_"
                    f"ctr{center:04d}_win{window:04d}_"
                    f"fps{fps:04d}_{mode}.mp4"
                )
                output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
                full_output_path = os.path.join(output_deblurred_dir, output_name)

                if not os.path.exists(full_output_path):
                    cleaned_intervals.append(interval)
            print("Len of test intervals before cleaning: ", len(self.test_intervals))
            print("Len of test intervals after cleaning: ", len(cleaned_intervals))
            self.test_intervals = cleaned_intervals

        list_file = 'train_list.txt' if self.split == 'train' else 'test_list.txt'
        for category in sorted(os.listdir(self.data_dir)):
            cat_dir = os.path.join(self.data_dir, category)
            if not os.path.isdir(cat_dir):
                continue
            list_path = os.path.join(cat_dir, list_file)
            if os.path.isfile(list_path):
                with open(list_path, 'r') as f:
                    for line in f:
                        rel = line.strip()
                        if not rel:
                            continue
                        seq_dir = os.path.join(self.data_dir, rel)
                        if os.path.isdir(seq_dir):
                            self.seq_dirs.append(seq_dir)
            else:
                fps_root = os.path.join(cat_dir, 'lower_fps_frames')
                if os.path.isdir(fps_root):
                    for seq in sorted(os.listdir(fps_root)):
                        seq_path = os.path.join(fps_root, seq)
                        if os.path.isdir(seq_path):
                            self.seq_dirs.append(seq_path)

        if self.split == 'val':
            self.seq_dirs = self.seq_dirs[:5]
        if self.split == 'train':
            self.seq_dirs *= 10

        assert self.seq_dirs, \
            f"No sequences found for split '{self.split}' in {self.data_dir}"

    def __len__(self):
        return len(self.test_intervals) if self.split == 'test' else len(self.seq_dirs)

    def __getitem__(self, idx):
        if self.split == 'test':
            interval = self.test_intervals[idx]
            category, seq = interval['video_name'].split('/')
            seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))

            in_start = interval['in_start']
            in_end = interval['in_end']
            out_start = interval['out_start']
            out_end = interval['out_end']
            center = interval['center']
            window = interval['window_size']
            mode = interval['mode']
            fps = interval['fps']

            blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
                frame_paths=frame_paths, window_max=window, in_start=in_start, in_end=in_end,
                out_start=out_start, out_end=out_end, center=center, mode=mode, fps=fps
            )
            file_name = frame_paths[center]

        else:
            seq_dir = self.seq_dirs[idx]
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
            mode = random.choice(['1x', '2x', 'large_blur'])

            # Short sequences fall back to the 1x generator regardless of the sampled mode.
            if mode == '1x' or len(frame_paths) < 50:
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_1x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            elif mode == '2x':
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_2x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            else:
                max_base = min((len(frame_paths) - 1) // 17, 3)
                base_rate = random.randint(1, max_base)
                blur_img, seq_frames, inp_int, out_int, _ = generate_large_blur_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            file_name = frame_paths[0]
            num_frames = 16

        blur_input = self.load_frames(np.expand_dims(blur_img, 0))
        video = self.load_frames(np.stack(seq_frames, axis=0))

        relative_file_name = os.path.relpath(file_name, self.data_dir)

        if self.split == 'test':
            base_dir = os.path.dirname(relative_file_name)
            frame_stem = os.path.splitext(os.path.basename(relative_file_name))[0]

            new_filename = (
                f"{frame_stem}_"
                f"in{in_start:04d}_ie{in_end:04d}_"
                f"os{out_start:04d}_oe{out_end:04d}_"
                f"ctr{center:04d}_win{window:04d}_"
                f"fps{fps:04d}_{mode}.png"
            )

            relative_file_name = os.path.join(base_dir, new_filename)

        data = {
            'file_name': relative_file_name,
            'blur_img': blur_input,
            'num_frames': num_frames,
            'video': video,
            'caption': "",
            'mode': mode,
            'input_interval': inp_int,
            'output_interval': out_int,
        }
        if self.split == 'test':
            high_fps_video = self.load_frames(np.stack(high_fps_video, axis=0))
            data['high_fps_video'] = high_fps_video
        return data


class GoPro2xMotionBlurDataset(BaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.split == 'train':
            self.blur_root = os.path.join(self.data_dir, 'train', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'train', 'sharp')
        elif self.split in ['val', 'test']:
            self.blur_root = os.path.join(self.data_dir, 'test', 'blur')
            self.sharp_root = os.path.join(self.data_dir, 'test', 'sharp')
        else:
            raise ValueError(f"Unsupported split: {self.split}")

        pattern = os.path.join(self.blur_root, '*', '*.png')

        def get_sharp_paths(blur_paths):
            # For each blurred frame, collect the 13 surrounding sharp frames
            # (center +/- 6) and keep the sample only if all of them exist.
            sharp_paths = []
            for blur_path in blur_paths:
                base_dir = blur_path.replace('/blur/', '/sharp/')
                frame_num = int(os.path.basename(blur_path).split('.')[0])
                dir_path = os.path.dirname(base_dir)
                sequence = [
                    os.path.join(dir_path, f"{frame_num + offset:06d}.png")
                    for offset in range(-6, 7)
                ]
                if all(os.path.exists(path) for path in sequence):
                    sharp_paths.append(sequence)
            return sharp_paths
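
        # For example, .../test/blur/<seq>/000100.png maps to the sharp frames
        # .../test/sharp/<seq>/000094.png through 000106.png.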

        self.blur_paths = sorted(glob.glob(pattern))
        filtered_blur_paths = []
        for path in self.blur_paths:
            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
            full_output_path = Path(output_deblurred_dir, *path.split('/')[-2:]).with_suffix(".mp4")
            if not os.path.exists(full_output_path):
                filtered_blur_paths.append(path)
        self.blur_paths = filtered_blur_paths

        self.sharp_paths = get_sharp_paths(self.blur_paths)
        if self.split == 'val':
            self.sharp_paths = self.sharp_paths[:5]
        self.length = len(self.sharp_paths)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        sharp_path = self.sharp_paths[idx]

        blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=sharp_path, window_max=13, in_start=3, in_end=10, out_start=0,
            out_end=13, center=6, mode="2x", fps=240
        )

        video = self.load_frames(np.array(seq_frames))
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))
        last_two_parts_of_path = os.path.join(*sharp_path[6].split(os.sep)[-2:])

        data = {
            'file_name': last_two_parts_of_path,
            'blur_img': blur_input,
            'video': video,
            "caption": "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
            "mode": "2x",
        }
        return data


class BAISTDataset(BaseClass):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Sequences held out for the test/val splits.
        test_folders = {
            "gWA_sBM_c01_d26_mWA0_ch06_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch01_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch04_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch05_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch08_cropped_32X": None,
            "gWA_sBM_c01_d26_mWA0_ch02_cropped_32X": None,
            "gJS_sBM_c01_d02_mJS0_ch08_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch07_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch06_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch03_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch05_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch02_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch03_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch09_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch10_cropped_32X": None,
            "gWA_sBM_c01_d26_mWA0_ch10_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch06_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch08_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch06_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch10_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch09_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch02_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch04_cropped_32X": None,
            "gPO_sBM_c01_d10_mPO0_ch09_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch01_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch07_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch03_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch04_cropped_32X": None,
            "gBR_sBM_c01_d05_mBR0_ch02_cropped_32X": None,
            "gHO_sBM_c01_d20_mHO0_ch01_cropped_32X": None,
            "gMH_sBM_c01_d22_mMH0_ch05_cropped_32X": None,
            "gPO_sBM_c01_d10_mPO0_ch10_cropped_32X": None,
        }

        def collect_blur_images(root_dir, allowed_folders, skip_start=40, skip_end=40):
            blur_image_paths = []

            for dirpath, dirnames, filenames in os.walk(root_dir):
                if os.path.basename(dirpath) == "blur":
                    parent_folder = os.path.basename(os.path.dirname(dirpath))
                    if (self.split in ["test", "val"] and parent_folder in allowed_folders) or (self.split == "train" and parent_folder not in allowed_folders):
                        valid_files = [
                            f for f in filenames
                            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')) and os.path.splitext(f)[0].isdigit()
                        ]
                        valid_files.sort(key=lambda x: int(os.path.splitext(x)[0]))

                        # Drop the first skip_start and last skip_end frames of every sequence.
                        middle_files = valid_files[skip_start:len(valid_files) - skip_end]

                        for f in middle_files:
                            full_path = Path(os.path.join(dirpath, f))
                            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
                            full_output_path = Path(output_deblurred_dir, *full_path.parts[-3:]).with_suffix(".mp4")
                            if not os.path.exists(full_output_path) or self.split in ["train", "val"]:
                                blur_image_paths.append(os.path.join(dirpath, f))

            return blur_image_paths

        self.image_paths = collect_blur_images(self.data_dir, test_folders)

        # Keep only frames that have a bounding-box annotation and all 7 sharp frames.
        self.image_paths = [path for path in self.image_paths if os.path.exists(path.replace("blur", "blur_anno").replace(".png", ".pkl"))]

        filtered_image_paths = []
        for blur_path in self.image_paths:
            base_dir = blur_path.replace('/blur/', '/sharp/').replace('.png', '')
            sharp_paths = [f"{base_dir}_{i:03d}.png" for i in range(7)]
            if all(os.path.exists(p) for p in sharp_paths):
                filtered_image_paths.append(blur_path)

        self.image_paths = filtered_image_paths

        if self.split == 'val':
            self.image_paths = self.image_paths[:4]
        self.length = len(self.image_paths)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        blur_img_original = load_as_srgb(image_path)

        bbx_path = image_path.replace("blur", "blur_anno").replace(".png", ".pkl")

        bbx = np.load(bbx_path, allow_pickle=True)['bbox'][0:4]

        W, H = blur_img_original.size
        blur_img = blur_img_original.resize((self.image_size[1], self.image_size[0]), resample=Image.BILINEAR)

        blur_np = np.array([blur_img])

        base_dir = os.path.dirname(os.path.dirname(image_path))
        filename = os.path.splitext(os.path.basename(image_path))[0]
        sharp_dir = os.path.join(base_dir, "sharp")

        frame_paths = [
            os.path.join(sharp_dir, f"{filename}_{i:03d}.png")
            for i in range(7)
        ]

        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths, window_max=7, in_start=0, in_end=7, out_start=0,
            out_end=7, center=3, mode="1x", fps=240
        )

        pixel_values = self.load_frames(np.stack(seq_frames, axis=0))
        blur_pixel_values = self.load_frames(blur_np)

        relative_file_name = os.path.relpath(image_path, self.data_dir)

        # Rescale the bounding box from the original resolution to the model resolution.
        out_bbx = bbx.copy()

        scale_x = blur_pixel_values.shape[3] / W
        scale_y = blur_pixel_values.shape[2] / H

        out_bbx[0] = int(out_bbx[0] * scale_x)
        out_bbx[1] = int(out_bbx[1] * scale_y)
        out_bbx[2] = int(out_bbx[2] * scale_x)
        out_bbx[3] = int(out_bbx[3] * scale_y)

        out_bbx = torch.tensor(out_bbx, dtype=torch.uint32)

        # Crop of the resized blurred image inside the bounding box (not returned below).
        blur_img_npy = np.array(blur_img)
        out_bbx_npy = out_bbx.numpy().astype(np.uint32)
        blur_img_npy = blur_img_npy[out_bbx_npy[1]:out_bbx_npy[3], out_bbx_npy[0]:out_bbx_npy[2], :]

        data = {
            'file_name': relative_file_name,
            'blur_img': blur_pixel_values,
            'video': pixel_values,
            'bbx': out_bbx,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
            'mode': "1x",
        }
        return data
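

# A minimal smoke-test sketch (not part of the original module). It assumes a
# GoPro-style directory layout under the placeholder paths below; adjust them
# before running.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = GoProMotionBlurDataset(
        data_dir="/path/to/GoPro",        # assumed location of the GoPro dataset
        output_dir="/path/to/outputs",    # where deblurred .mp4 results are written
        image_size=(320, 512),
        split="val",
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    for batch in loader:
        # blur_img: (1, 1, 3, H, W), video: (1, output_length, 3, H, W)
        print(batch['file_name'], batch['blur_img'].shape, batch['video'].shape)
        break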