| | import av |
| | import os |
| | import pims |
| | import numpy as np |
| | from torch.utils.data import Dataset |
| | from torchvision.transforms.functional import to_pil_image |
| | from PIL import Image |
| |
|
| |
|
class VideoReader(Dataset):
    """Dataset that serves the frames of a video file as PIL images.

    Decoding is delegated to ``pims.PyAVVideoReader``. Each requested frame
    is converted to a ``PIL.Image`` and, when a ``transform`` callable was
    supplied, passed through it before being returned.
    """

    def __init__(self, path, transform=None):
        """Open the video at ``path`` and remember the optional transform."""
        reader = pims.PyAVVideoReader(path)
        self.video = reader
        # Cached once here; exposed read-only through ``frame_rate``.
        self.rate = reader.frame_rate
        self.transform = transform

    @property
    def frame_rate(self):
        """Frame rate reported by the underlying pims reader."""
        return self.rate

    def __len__(self):
        """Number of frames, as reported by the underlying reader."""
        return len(self.video)

    def __getitem__(self, idx):
        """Return frame ``idx`` as a PIL image (transformed when configured)."""
        image = Image.fromarray(np.asarray(self.video[idx]))
        if self.transform is None:
            return image
        return self.transform(image)
| |
|
| |
|
class VideoWriter:
    """Encode batches of frames into an H.264 video file through PyAV.

    The writer can be used as a context manager so the container is always
    finalised::

        with VideoWriter('out.mp4', frame_rate=30) as writer:
            writer.write(frames)
    """

    def __init__(self, path, frame_rate, bit_rate=1000000):
        """Open ``path`` for writing and configure a single video stream.

        Args:
            path: Output file path handed to ``av.open``.
            frame_rate: Stream frame rate; formatted with four decimals
                before being passed to PyAV.
            bit_rate: Target bit rate assigned to the stream.
        """
        self.container = av.open(path, mode='w')
        self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}')
        self.stream.pix_fmt = 'yuv420p'
        self.stream.bit_rate = bit_rate

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    @staticmethod
    def _to_rgb24(frames):
        """Convert a ``[T, C, H, W]`` tensor to a ``[T, H, W, 3]`` uint8 array.

        Single-channel input is replicated to three channels. Floating-point
        input is clamped to ``[0, 1]`` before scaling by 255, because casting
        an out-of-range float to uint8 wraps around or is undefined.
        Non-float tensors are scaled exactly as before.
        """
        if frames.is_floating_point():
            frames = frames.clamp(0, 1)
        if frames.size(1) == 1:
            # Grayscale -> RGB by repeating the single channel.
            frames = frames.repeat(1, 3, 1, 1)
        return frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()

    def write(self, frames):
        """Encode and mux every frame of a ``[T, C, H, W]`` tensor.

        Args:
            frames: Tensor indexed as (time, channel, height, width) with one
                or three channels. Values are presumably expected in
                ``[0, 1]`` since they are scaled by 255; float values outside
                that range are clamped.
        """
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        for image in self._to_rgb24(frames):
            packet_source = av.VideoFrame.from_ndarray(image, format='rgb24')
            self.container.mux(self.stream.encode(packet_source))

    def close(self):
        """Flush the encoder and close the container.

        The container is closed even when flushing raises, so the file
        handle is not leaked.
        """
        try:
            # ``encode()`` with no frame drains the packets still buffered.
            self.container.mux(self.stream.encode())
        finally:
            self.container.close()
| |
|
| |
|
class ImageSequenceReader(Dataset):
    """Dataset that serves the images stored in a directory, ordered by name.

    Every entry returned by ``os.listdir`` is treated as a frame; entries are
    sorted lexicographically, so zero-padded file names keep their numeric
    order.
    """

    def __init__(self, path, transform=None):
        """Index the directory ``path`` and remember the optional transform."""
        self.path = path
        names = os.listdir(path)
        names.sort()
        self.files = names
        self.transform = transform

    def __len__(self):
        """Number of directory entries found at construction time."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load entry ``idx`` as a PIL image (transformed when configured)."""
        location = os.path.join(self.path, self.files[idx])
        with Image.open(location) as picture:
            # Force the pixel data into memory before the file is closed.
            picture.load()
        if self.transform is None:
            return picture
        return self.transform(picture)
| |
|
| |
|
class ImageSequenceWriter:
    """Write frames to a directory as consecutively numbered image files."""

    def __init__(self, path, extension='jpg'):
        """Create the output directory ``path`` if it does not exist yet.

        Args:
            path: Directory that receives the image files.
            extension: File extension (without the dot) used for each frame.
        """
        self.path = path
        self.extension = extension
        # Running index of the next frame; persists across ``write`` calls.
        self.counter = 0
        os.makedirs(path, exist_ok=True)

    def write(self, frames):
        """Save every frame along the first axis of ``frames`` as an image.

        Each ``frames[t]`` is converted with ``to_pil_image`` and stored as
        ``<counter zero-padded to 4 digits>.<extension>``.
        """
        total = frames.shape[0]
        for index in range(total):
            filename = '{}.{}'.format(str(self.counter).zfill(4), self.extension)
            destination = os.path.join(self.path, filename)
            to_pil_image(frames[index]).save(destination)
            self.counter += 1

    def close(self):
        """No-op; nothing is held open between writes."""
        pass
| | |
| |
|