import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import (KeywordsStoppingCriteria, get_model_name_from_path,
                            process_images, tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

| title_markdown = (""" | |
| <div style="display: flex; justify-content: flex-start; align-items: center; text-align: center;"> | |
| <div style="margin-right: 20px; display: flex; align-items: center;"> | |
| <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video" style="text-decoration: none; display: flex; align-items: center;"> | |
| <img src="https://raw.githubusercontent.com/ShareGPT4V/ShareGPT4V-Resources/master/images/share4video_tight.png" alt="ShareGPT4Video🚀" style="max-width: 120px; height: auto;"> | |
| </a> | |
| </div> | |
| <div> | |
| <h1>ShareGPT4Video: Improving Video Understanding and Generation with Better Captions</h1> | |
| <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5> | |
| <h5 style="margin: 0;"> <a href="https://sharegpt4video.github.io/">[Project Page]</a> <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video">[Code]</a> <a href="https://arxiv.org/abs/2406.04325v1">[Paper]</a> | |
| </div> | |
| </div> | |
| """) | |
| block_css = """ | |
| #buttons button { | |
| min-width: min(120px,100%); | |
| } | |
| """ | |
| learn_more_markdown = (""" | |
| ### License | |
| The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. | |
| """) | |
def create_frame_grid(img_array, interval_width=50):
    """Tile sampled frames into a square grid separated by white bands."""
    n, h, w, c = img_array.shape
    grid_size = int(np.ceil(np.sqrt(n)))
    horizontal_band = np.ones((h, interval_width, c),
                              dtype=img_array.dtype) * 255
    vertical_band = np.ones((interval_width, w + (grid_size - 1)
                             * (w + interval_width), c), dtype=img_array.dtype) * 255
    rows = []
    for i in range(grid_size):
        row_frames = []
        for j in range(grid_size):
            idx = i * grid_size + j
            if idx < n:
                frame = img_array[idx]
            else:
                # Pad missing cells with white frames so the grid stays square.
                frame = np.ones_like(img_array[0]) * 255
            if j > 0:
                row_frames.append(horizontal_band)
            row_frames.append(frame)
        combined_row = np.concatenate(row_frames, axis=1)
        if i > 0:
            rows.append(vertical_band)
        rows.append(combined_row)
    final_grid = np.concatenate(rows, axis=0)
    return final_grid
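
# A minimal sanity sketch, not part of the original app: five dummy 2x3 RGB
# frames tile into a 3x3 grid, with the four unused cells filled white.
def _demo_frame_grid():
    dummy = np.zeros((5, 2, 3, 3), dtype=np.uint8)
    grid = create_frame_grid(dummy, interval_width=10)
    # 3 rows of height 2 + 2 bands of 10 -> 26; 3 cols of width 3 + 2 bands -> 29
    assert grid.shape == (26, 29, 3)
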
def resize_image_grid(image, max_length=1920):
    # Downscale the grid so its longer side is at most max_length pixels.
    width, height = image.size
    if max(width, height) > max_length:
        if width > height:
            scale = max_length / width
        else:
            scale = max_length / height
        new_width = int(width * scale)
        new_height = int(height * scale)
        img_resized = image.resize((new_width, new_height), Image.BILINEAR)
    else:
        img_resized = image
    return img_resized
def get_index(num_frames, num_segments):
    # Split the video into num_segments equal spans and pick one frame
    # index near the center of each span.
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets
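
# Illustration only: for a 100-frame clip split into 8 segments, seg_size is
# 99/8 = 12.375, so one frame lands near the middle of each segment.
def _demo_get_index():
    assert get_index(100, 8).tolist() == [6, 18, 31, 43, 56, 68, 80, 93]
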
def load_video(video_path, num_segments=8, return_msg=False):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frames = len(vr)
    frame_indices = get_index(total_frames, num_segments)
    img_array = vr.get_batch(frame_indices).asnumpy()
    img_grid = create_frame_grid(img_array, 50)
    img_grid = Image.fromarray(img_grid).convert("RGB")
    img_grid = resize_image_grid(img_grid)
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # Callers should add surrounding spaces when splicing this into a prompt.
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return img_grid, msg
    else:
        return img_grid
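
# Illustration only: sampling 8 frames from a hypothetical local clip yields
# the grid image plus a message describing the sampled timestamps, e.g.:
#   grid, msg = load_video('example.mp4', num_segments=8, return_msg=True)
#   # msg -> "The video contains 8 frames sampled at 0.5, 1.9, ... seconds."
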
def video_answer(prompt, model, processor, tokenizer, img_grid, do_sample=True,
                 max_new_tokens=200, num_beams=1, top_p=0.9,
                 temperature=1.0, print_res=False, **kwargs):
    if not isinstance(img_grid, (list, tuple)):
        img_grid = [img_grid]
    image_size = img_grid[0].size
    image_tensor = process_images(img_grid, processor, model.config)[0]
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(
        device=model.device, non_blocking=True)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token is not None else tokenizer.eos_token_id
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.to(
                dtype=torch.float16, device=model.device, non_blocking=True),
            image_sizes=[image_size],
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            pad_token_id=pad_token_id,
            use_cache=True,
            **kwargs)
    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True)[0].strip()
    if print_res:  # debug usage
        print('### PROMPTING LM WITH: ', prompt)
        print('### LM OUTPUT TEXT: ', outputs)
    return outputs
class Chat:
    def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', num_frames=16):
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        self.tokenizer, self.model, self.processor, context_len = load_pretrained_model(
            model_path, model_base, model_name,
            load_8bit, load_4bit,
            device=device)
        self.model.eval()
        self.conv_mode = conv_mode
        self.device = self.model.device
        self.num_frames = num_frames
        self.pre_query_prompt = "The provided image arranges keyframes from a video in a grid view; keyframes are separated by white bands. Answer concisely with the overall content and context of the video, highlighting any significant events, characters, or objects that appear throughout the frames."

    def get_prompt(self, qs, state):
        state.append_message(state.roles[0], qs)
        state.append_message(state.roles[1], None)
        return state

    def generate(self, vid_path: str, prompt: str, first_run: bool, state):
        if self.num_frames != 0:
            vid, msg = load_video(
                vid_path, num_segments=self.num_frames, return_msg=True)
        else:
            vid, msg = None, 'num_frames is 0, not inputting image'
        img_grid = vid
        if self.pre_query_prompt is not None:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + self.pre_query_prompt + ' ' + prompt
        else:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        state = self.get_prompt(prompt, state)
        prompt = state.get_prompt()
        llm_response = video_answer(prompt, model=self.model, processor=self.processor, tokenizer=self.tokenizer,
                                    do_sample=True, temperature=0.1, img_grid=img_grid, max_new_tokens=1024, print_res=True)
        return llm_response, state
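
# A minimal usage sketch, not part of the original app. The checkpoint path
# and conversation-template name below are assumptions; substitute whatever
# model and conv_mode you actually serve, and point vid_path at a real file.
if __name__ == '__main__':
    chat = Chat(model_path='Lin-Chen/sharegpt4video-8b',  # assumed checkpoint
                conv_mode='llava_llama_3')                # assumed template name
    state = conv_templates[chat.conv_mode].copy()
    response, state = chat.generate(
        'example.mp4',  # hypothetical local video file
        'What happens in this video?', first_run=True, state=state)
    print(response)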