Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import glob | |
| import math | |
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image | |
# Lower-case file extensions (without the leading dot) that the loader treats as video files.
VID_FORMATS = (
    "asf",
    "avi",
    "gif",
    "m4v",
    "mkv",
    "mov",
    "mp4",
    "mpeg",
    "mpg",
    "ts",
    "wmv",
)
class LoadImagesAndVideos:
    """
    A data loader for handling both images and videos, providing batches of frames or images for processing.

    Supports various image formats, including HEIC, and handles text files with paths to images/videos.

    Iterating yields ``(paths, imgs, infos)`` tuples holding up to ``batch_size`` entries each. Each readable
    image contributes one entry. A video contributes one entry per sampled frame (every ``vid_stride``-th
    frame) until it is exhausted; iteration then moves on to the next file.

    Attributes:
        batch_size (int): Maximum number of images/frames per batch.
        vid_stride (int): Number of frames grabbed per sampled video frame.
        files (list[str]): Absolute paths of all discovered files.
        video_flag (list[bool]): True for entries of ``files`` whose extension is in ``VID_FORMATS``.
        nf (int): Total number of files.
        ni (int): Number of non-video files.
        mode (str): ``"image"`` or ``"video"``, the kind of file most recently read.
        cap: ``cv2.VideoCapture`` for the current video, or None before any video is opened.
        count (int): Index into ``files`` of the file currently being read.
    """

    def __init__(self, path, batch_size=1, vid_stride=1):
        """
        Discover the input files and open the first video, if any.

        Args:
            path (str | Path | list | tuple): A file, a directory, a glob pattern, a ``.txt`` file listing
                paths (one per line), or a list/tuple of any of these.
            batch_size (int): Maximum number of images/frames returned per iteration.
            vid_stride (int): Frame stride for videos (1 = every frame).

        Raises:
            FileNotFoundError: If a given path does not exist, no files are found, or the first video
                cannot be opened.
        """
        self.batch_size = batch_size
        self.vid_stride = vid_stride
        self.files = self._load_files(path)
        if not self.files:
            raise FileNotFoundError(f"No images or videos found in {path}.")
        self.video_flag = [self._is_video(f) for f in self.files]
        self.nf = len(self.files)
        self.ni = sum(not is_video for is_video in self.video_flag)
        self.mode = "image"
        self.cap = None
        self.count = 0  # lets next() work even without an explicit iter()
        if any(self.video_flag):
            # Open the first video up front; this also sets fps/frames/frame before iteration starts.
            self._start_video(self.files[self.video_flag.index(True)])

    def _load_files(self, path):
        """
        Collect absolute file paths from a file, directory, glob pattern, ``.txt`` list, or list/tuple of these.

        Directory and glob matches are sorted so iteration order is deterministic. Blank lines in a ``.txt``
        list are ignored.

        Raises:
            FileNotFoundError: If a listed path is not a file, directory, or glob pattern.
        """
        if isinstance(path, (str, Path)) and Path(path).suffix == ".txt":
            # Skip blank lines: Path("") would resolve to the current working directory.
            path = [line for line in Path(path).read_text().splitlines() if line.strip()]
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).absolute())
            if "*" in p:
                files.extend(sorted(glob.glob(p, recursive=True)))
            elif os.path.isdir(p):
                # Escape the directory so characters like "[" in its name are not treated as glob syntax.
                files.extend(sorted(glob.glob(os.path.join(glob.escape(p), "*.*"))))
            elif os.path.isfile(p):
                files.append(p)
            else:
                raise FileNotFoundError(f"{p} does not exist")
        return files

    def _is_video(self, file_path):
        """Return True if the text after the last "." in ``file_path`` (lower-cased) is in ``VID_FORMATS``."""
        return file_path.split(".")[-1].lower() in VID_FORMATS

    def __iter__(self):
        """Restart from the first file and return self as the iterator."""
        self.count = 0
        return self

    def __next__(self):
        """
        Return the next batch as ``(paths, imgs, infos)``.

        The last batch may be shorter than ``batch_size``. A batch may mix images and frames from several files.

        Raises:
            StopIteration: When every file has been consumed and nothing is pending.
        """
        paths, imgs, infos = [], [], []
        while len(imgs) < self.batch_size:
            if self.count >= self.nf:
                if imgs:
                    return paths, imgs, infos  # final, partial batch
                raise StopIteration
            path = self.files[self.count]
            if self.video_flag[self.count]:
                # One frame per call; stay on this video until _process_video reports it finished.
                if self._process_video(paths, imgs, infos, path):
                    self.count += 1
            else:
                self._process_image(paths, imgs, infos, path)
                self.count += 1
        return paths, imgs, infos

    def _process_image(self, paths, imgs, infos, path):
        """Read one image and append it to the batch lists; files whose read returns None are skipped."""
        self.mode = "image"
        img = self._read_image(path)
        if img is not None:
            paths.append(path)
            imgs.append(img)
            infos.append(f"image {self.count + 1}/{self.nf} {path}")

    def _process_video(self, paths, imgs, infos, path):
        """
        Read the next sampled frame of ``path`` and append it to the batch lists.

        Returns:
            bool: True when the video is finished (its reported frame count was reached, or grab/retrieve
            failed) and its capture has been released, so the caller should move to the next file.
            False when more frames may follow.
        """
        self.mode = "video"
        if self.cap is None or not self.cap.isOpened():
            self._start_video(path)
        success = False
        for _ in range(self.vid_stride):
            success = self.cap.grab()
            if not success:
                break
        if success:
            success, frame = self.cap.retrieve()
        if not success:
            # End of stream or read failure: close this video so the next one can be opened.
            self.cap.release()
            return True
        self.frame += 1
        paths.append(path)
        imgs.append(frame)
        infos.append(f"video {self.count + 1}/{self.nf} frame {self.frame}/{self.frames} {path}")
        if 0 < self.frames <= self.frame:
            # Only trust the reported frame count when positive; otherwise rely on grab() failing.
            self.cap.release()
            return True
        return False

    def _read_image(self, path):
        """
        Read an image as a BGR array.

        ``.heic`` files are decoded through Pillow with the ``pillow_heif`` opener. The import is local, so
        that package is only needed when such files are read. All other files go through ``cv2.imread``,
        whose None result (undecodable file) is passed back to the caller.
        """
        if path.lower().endswith(".heic"):
            from pillow_heif import register_heif_opener

            register_heif_opener()
            with Image.open(path) as img:
                # Normalise alpha/greyscale/palette images to 3-channel RGB before the RGB->BGR swap.
                return cv2.cvtColor(np.asarray(img.convert("RGB")), cv2.COLOR_RGB2BGR)
        return cv2.imread(path)

    def _start_video(self, path):
        """
        Open ``path`` with ``cv2.VideoCapture`` and reset the per-video counters.

        Sets ``fps``, ``frames`` (reported frame count divided by ``vid_stride``) and ``frame`` (0).

        Raises:
            FileNotFoundError: If the capture cannot be opened.
        """
        self.cap = cv2.VideoCapture(path)
        if not self.cap.isOpened():
            raise FileNotFoundError(f"Failed to open video {path}")
        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.frame = 0

    def __len__(self):
        """Number of batches, counting one item per file (individual video frames are not counted)."""
        return math.ceil(self.nf / self.batch_size)