Spaces:
Runtime error
Runtime error
| from typing import Dict, List | |
| import random | |
| from PIL import Image, ImageDraw, ImageFont | |
| from .utils import sample_video | |
| class VideoPermutationParser: | |
| def __init__( | |
| self, | |
| n_frames=8, | |
| is_training=True, | |
| frame_nums = list(range(8, 25)), | |
| video_sampling_strategy={}, | |
| ): | |
| self.n_frames = n_frames | |
| self.is_training = is_training | |
| self.frame_nums = frame_nums | |
| self.video_sampling_strategy = video_sampling_strategy | |
| # fmt: off | |
| self.data_temp = { | |
| "text": [{ | |
| "prompt": "<video>", | |
| "response": "" | |
| }], | |
| "video": [{ | |
| "video_file": { | |
| "yg": "/mnt/bn/videonasyg/videos/webvid_10M_download/011851_011900/1047443473.mp4", | |
| "lq": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4" | |
| }, | |
| "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598] | |
| }], | |
| } | |
| # fmt: on | |
| def check_format(self, data_dict: Dict): | |
| pass | |
| # for k in self.data_temp.keys(): | |
| # assert k in data_dict | |
| def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict: | |
| self.check_format(data_dict) | |
| frames = self.load_video_item(data_dict['video'][0]) | |
| # frames = self.add_text_to_frames(frames) # for debug | |
| idxs = list(range(1, len(frames) + 1)) | |
| random.shuffle(idxs) | |
| prefix_len = int(3/8*len(idxs)) | |
| shuffled_frames = [frames[i-1] for i in idxs] | |
| prompt = f'Output the correct chronological order of scrambled video frames. The order of the first {prefix_len} ones are:\n' | |
| prompt += '\n'.join([str(i) for i in idxs[: prefix_len]]) + '\nOutput the order of the following frames:' | |
| response = '\n'.join([str(i) for i in idxs[prefix_len: ]]) | |
| messages = [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| {"type": "video", "video": shuffled_frames}, | |
| {"type": "text", "text": prompt}, | |
| ] | |
| }, | |
| { | |
| "role": "assistant", | |
| "content": [ | |
| {"type": "text", "text": response} | |
| ] | |
| } | |
| ] | |
| return messages | |
| def load_video_item(self, video_item) -> List[Image.Image]: | |
| """ | |
| video_item: | |
| {"video_file": "/path/to/video", "n_frames": 8} | |
| {"video_file": "/path/to/video", "frame_indices": [0, 1, 2], "n_frames": 3} | |
| {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100, "n_frames": 8} | |
| {"video_file": "/path/to/video", "time_indices": [0, 1, 2], "n_frames": 3} | |
| {"video_file": "/path/to/video", "start_time": 0, "end_time": 100, "n_frames": 8} | |
| {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2], "n_frames": 3} | |
| """ | |
| # check format | |
| if ("image_file" not in video_item) and ("video_file" not in video_item): | |
| raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item") | |
| video_path = video_item.get('video_file', video_item.get('image_file')) | |
| n_frames = video_item.get('n_frames', None) | |
| frame_indices = video_item.get('frame_indices', None) | |
| start_frame = video_item.get('start_frame', None) | |
| end_frame = video_item.get('end_frame', None) | |
| time_indices = video_item.get('time_indices', None) | |
| start_time = video_item.get('start_time', None) | |
| end_time = video_item.get('end_time', None) | |
| mask_boxes = video_item.get('mask_boxes', None) | |
| n_frames = random.choice(self.frame_nums) | |
| n = self.video_sampling_strategy.get('force_frames_n_divisible', 1) | |
| if n > 1 and n_frames % n != 0: | |
| n_frames += n - n_frames % n | |
| frames, frame_indices = sample_video( | |
| video_path=video_path, | |
| frame_indices=frame_indices, | |
| start_frame=start_frame, | |
| end_frame=end_frame, | |
| n_frames=n_frames, | |
| time_indices=time_indices, | |
| start_time=start_time, | |
| end_time=end_time, | |
| mask_boxes=mask_boxes, | |
| is_training=self.is_training, | |
| video_sampling_strategy=self.video_sampling_strategy, | |
| return_frame_ids=True, | |
| ) | |
| return frames | |
| def add_text_to_frames(self, frames: List[Image.Image]): | |
| new_frames = [] | |
| for i, image in enumerate(frames): | |
| draw = ImageDraw.Draw(image) | |
| font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 100) | |
| text_position = (50, 50) | |
| text_content = f'{i+1}' | |
| text_color = (255, 0, 0) | |
| draw.text(text_position, text_content, font=font, fill=text_color) | |
| new_frames.append(image) | |
| return new_frames | |