Spaces:
Runtime error
Runtime error
| from typing import Dict, List | |
| from PIL import Image | |
| import random | |
| from .utils import sample_video, read_image, adjust_bbox, filter_ocr_polygon | |
class VisionParser:
    """Parses chat-style `messages`, loading image/video entries into PIL frames."""

    def __init__(
        self,
        n_frames=8,
        max_n_frames=256,
        is_training=True,
        video_sampling_strategy=None,
    ):
        """
        Args:
            n_frames: target total frame count per sample; either an int or a
                sequence to randomly choose from per sample (see set_n_frames).
            max_n_frames: hard upper bound on total frames in one sample.
            is_training: forwarded to sample_video (sampling may differ at eval).
            video_sampling_strategy: optional dict of sampling options, e.g.
                'force_frames_n_divisible', 'use_multi_images_for_video'.
                Defaults to an empty dict.
        """
        self.n_frames = n_frames
        self.max_n_frames = max_n_frames
        self.is_training = is_training
        # Was a mutable default argument (`={}`): one shared dict across every
        # instance constructed without the argument. Use a None sentinel.
        self.video_sampling_strategy = (
            {} if video_sampling_strategy is None else video_sampling_strategy
        )
        # fmt: off
        # Reference example of the expected input schema (not used at runtime).
        self.data_temp = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image and the video."},
                        # Supported "image" formats:
                        {"type": "image", "image": {"image_file": "/path/to/image"}},
                        {"type": "image", "image": {"video_file": "/path/to/video", "frame_indices": 0}},
                        # Supported "video" formats:
                        {"type": "video", "video": {"video_file": "/path/to/video"}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "frame_indices": [0, 1, 2]}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "time_indices": [0, 1, 2]}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "start_time": 0, "end_time": 100}},
                        # "frame_indices" belongs inside the "video" dict, matching
                        # load_video_item's contract (it was outside in the original).
                        {"type": "video", "video": {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2]}},
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": "xxx"}
                    ]
                }
            ],
            "dataset": "LSMDC",
            "task": "video/caption"
        }
        # fmt: on
| def check_format(self, data_dict: Dict, image_processing_config: Dict): | |
| if image_processing_config.get('do_crop', False) and image_processing_config.get('has_coordinates', False): | |
| raise ValueError(f'do_crop and has_coordinates cannot be True at the same time!') | |
| """ | |
| 1. 将 messages 中的 image/video 替换成相应的 PIL.Image/List[PIL.Image] | |
| 2. text 的特殊处理:调整 box;过滤面积太小的OCR | |
| """ | |
| def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict: | |
| self.check_format(data_dict, image_processing_config) | |
| self.set_n_frames(data_dict) | |
| first_image = None # ugly! 需要调整box/过滤面积太小的OCR的数据只有图片任务 | |
| for msg in data_dict['messages']: | |
| if isinstance(msg['content'], dict): | |
| msg['content'] = [msg['content']] | |
| for content in msg['content']: | |
| if content['type'] == 'image': | |
| content['image'] = self.load_image_item(content['image']) | |
| if first_image is None: | |
| first_image = content['image'] | |
| elif content['type'] == 'video': | |
| video = self.load_video_item(content['video']) | |
| content['video'] = video.pop('frames') | |
| if video: | |
| data_dict['extra_info']['frame_disturb_info'] = video.pop('video_info', {}) | |
| elif content['type'] == 'text': | |
| pass | |
| else: | |
| raise ValueError(f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']") | |
| for msg in data_dict['messages']: | |
| for content in msg['content']: | |
| if content['type'] == 'text': | |
| self.postprocess_text(content, data_dict, image_processing_config, first_image) | |
| return data_dict['messages'] | |
    # set n_frames for each vision item.
    def set_n_frames(self, data_dict):
        """Decide how many frames each image/video entry contributes.

        Writes content['video']['n_frames'] in place for every video entry and
        wraps any bare dict `msg['content']` into a one-element list.

        Budgeting: entries with explicit 'frame_indices'/'time_indices' are
        fixed; every other video entry is "dynamic" and is grown/shrunk one
        frame at a time so the sample total lands between the per-sample
        target and self.max_n_frames.
        """
        # Per-sample target: a fixed int, or randomly chosen from a sequence.
        if isinstance(self.n_frames, int):
            n_frames = self.n_frames
        else:
            n_frames = random.choice(self.n_frames)
        assert n_frames <= self.max_n_frames
        curr_n_frames = 0      # running total of frames claimed so far
        has_dynamic = False    # True if at least one video can still be resized
        # Pass 1: count fixed contributions and seed n_frames on each video.
        for msg in data_dict['messages']:
            if isinstance(msg['content'], dict):
                msg['content'] = [msg['content']]
            for content in msg['content']:
                if content['type'] == 'image':
                    # Each still image counts as one frame of the budget.
                    curr_n_frames += 1
                elif content['type'] == 'video':
                    if 'frame_indices' in content['video']:
                        # Fixed: exactly the listed frame indices.
                        curr_n_frames += len(content['video']['frame_indices'])
                        content['video']['n_frames'] = len(content['video']['frame_indices'])
                    elif 'time_indices' in content['video']:
                        # Fixed: one frame per listed timestamp.
                        curr_n_frames += len(content['video']['time_indices'])
                        content['video']['n_frames'] = len(content['video']['time_indices'])
                    elif 'min_n_frames' in content['video']:
                        # Dynamic with a floor: start at min_n_frames.
                        content['video']['min_n_frames'] = int(content['video']['min_n_frames'])
                        curr_n_frames += content['video']['min_n_frames']
                        content['video']['n_frames'] = content['video']['min_n_frames']
                        has_dynamic = True
                    elif 'fps' in content['video']:
                        # fps-driven sampling: provisionally claim the maximum;
                        # the shrink pass below trims the total back down.
                        content['video']['n_frames'] = self.max_n_frames
                        curr_n_frames += self.max_n_frames
                        has_dynamic = True
                    else:
                        # Fully dynamic: starts at 0 frames, grown below.
                        content['video']['n_frames'] = 0
                        has_dynamic = True
        # Pass 2: round-robin, +1 frame per dynamic video per sweep, until the
        # target is met. Terminates because each sweep increments curr_n_frames
        # at least once while a dynamic video exists.
        while curr_n_frames < n_frames and has_dynamic:
            for msg in data_dict['messages']:
                for content in msg['content']:
                    if content['type'] == 'video':
                        if 'frame_indices' in content['video']:
                            pass
                        elif 'time_indices' in content['video']:
                            pass
                        else:
                            if curr_n_frames < n_frames:
                                content['video']['n_frames'] += 1
                                curr_n_frames += 1
        # Pass 3: symmetric shrink when fixed entries push past max_n_frames.
        # NOTE(review): this can drive a dynamic video's n_frames to 0 (or
        # below, when fixed entries alone exceed max_n_frames) — confirm
        # downstream tolerates n_frames <= 0.
        while curr_n_frames > self.max_n_frames and has_dynamic:
            for msg in data_dict['messages']:
                for content in msg['content']:
                    if content['type'] == 'video':
                        if 'frame_indices' in content['video']:
                            pass
                        elif 'time_indices' in content['video']:
                            pass
                        else:
                            if curr_n_frames > self.max_n_frames:
                                content['video']['n_frames'] -= 1
                                curr_n_frames -= 1
        # Pass 4: round each dynamic video's count up to a multiple of
        # 'force_frames_n_divisible' (e.g. for temporal patch packing).
        # NOTE(review): runs after the max clamp, so the final total may
        # slightly exceed max_n_frames.
        for msg in data_dict['messages']:
            for content in msg['content']:
                if content['type'] == 'video':
                    if 'frame_indices' in content['video']:
                        pass
                    elif 'time_indices' in content['video']:
                        pass
                    else:
                        n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
                        if n > 1 and content['video']['n_frames'] % n != 0:
                            content['video']['n_frames'] += n - content['video']['n_frames'] % n
| def load_image_item(self, image_item) -> Image.Image: | |
| """ | |
| image_item: | |
| {"image_file": {"lq": "/path/to/image"}} | |
| {"video_file": {"lq": "/path/to/video"}, "frame_indices": 0} | |
| """ | |
| # check format | |
| if ("image_file" not in image_item) and ("video_file" not in image_item): | |
| raise KeyError(f"Key 'image_file' or 'video_file' not found in image_item") | |
| if 'image_file' in image_item: | |
| if not isinstance(image_item['image_file'], str): | |
| raise ValueError(f"{image_item['image_file']} is not a str!") | |
| if 'video_file' in image_item: | |
| if not isinstance(image_item['frame_indices'], int): | |
| raise ValueError(f"{image_item['frame_indices']} is not a int!") | |
| if 'image_file' in image_item: | |
| image = read_image(image_item['image_file']) | |
| else: | |
| frame_indices = [image_item['frame_indices']] | |
| image = sample_video(image_item['video_file'], frame_indices = frame_indices)[0] | |
| return image | |
| def load_video_item(self, video_item) -> List[Image.Image]: | |
| """ | |
| video_item: | |
| {"video_file": {"lq": "/path/to/video"}, "n_frames": 8} | |
| {"video_file": {"lq": "/path/to/video"}, "frame_indices": [0, 1, 2], "n_frames": 3} | |
| {"video_file": {"lq": "/path/to/video"}, "start_frame": 0, "end_frame": 100, "n_frames": 8} | |
| {"video_file": {"lq": "/path/to/video"}, "time_indices": [0, 1, 2], "n_frames": 3} | |
| {"video_file": {"lq": "/path/to/video"}, "start_time": 0, "end_time": 100, "n_frames": 8} | |
| {"image_file": {"lq": ["/path/to/image"]}, "frame_indices": [0, 1, 2], "n_frames": 3} | |
| """ | |
| # check format | |
| if ("image_file" not in video_item) and ("video_file" not in video_item): | |
| raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item") | |
| video_path = video_item.get('video_file', video_item.get('image_file')) | |
| n_frames = video_item.get('n_frames', None) | |
| frame_indices = video_item.get('frame_indices', None) | |
| start_frame = video_item.get('start_frame', None) | |
| end_frame = video_item.get('end_frame', None) | |
| time_indices = video_item.get('time_indices', None) | |
| start_time = video_item.get('start_time', None) | |
| end_time = video_item.get('end_time', None) | |
| mask_boxes = video_item.get('mask_boxes', None) | |
| fps = video_item.get('fps', None) | |
| frames, frame_indices = sample_video( | |
| video_path=video_path, | |
| frame_indices=frame_indices, | |
| start_frame=start_frame, | |
| end_frame=end_frame, | |
| n_frames=n_frames, | |
| time_indices=time_indices, | |
| start_time=start_time, | |
| end_time=end_time, | |
| sampling_fps=fps, | |
| mask_boxes=mask_boxes, | |
| is_training=self.is_training, | |
| video_sampling_strategy=self.video_sampling_strategy, | |
| return_frame_ids=True, | |
| ) | |
| if self.video_sampling_strategy.get('use_multi_images_for_video', False): | |
| new_frames = [] | |
| for f in frames: | |
| new_frames.extend([f, f]) | |
| frames = new_frames | |
| if isinstance(frame_indices, dict): | |
| return { | |
| 'frames': frames, | |
| 'video_info': frame_indices | |
| } | |
| return {'frames': frames} | |
| def postprocess_text(self, content, data_dict, image_processing_config, first_image): | |
| if image_processing_config.get('has_coordinates') and image_processing_config.get('do_padding'): | |
| content['text'] = adjust_bbox(content['text'], frame=first_image) | |
| if data_dict.get('task') == 'image/OCR' and image_processing_config.get('has_coordinates'): | |
| content['text'] = filter_ocr_polygon(content['text']) | |