Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
| import glob | |
| import os | |
| import tempfile | |
| from collections import defaultdict | |
| from io import BytesIO | |
| from typing import Any, Dict, List, Optional, Union | |
| import cv2 | |
| import numpy as np | |
| import PIL | |
| import PIL.Image | |
| import requests | |
| from transformers import PretrainedConfig | |
| from pydub import AudioSegment | |
| from llava.constants import MEDIA_TOKENS | |
| from llava.media import Image, Video, Speech, Sound | |
| from llava.utils import make_list | |
| from llava.utils.logging import logger | |
| import torch | |
| import whisper | |
| import soundfile as sf | |
| from librosa import resample as librosa_resample | |
| from transformers import AutoFeatureExtractor | |
| import math | |
| from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler | |
| import kaldiio | |
| # wav_processor = AutoFeatureExtractor.from_pretrained('pretrained_models/AF-Whisper') | |
| wav_processor = AutoFeatureExtractor.from_pretrained('Qwen/Qwen2-Audio-7B') | |
| __all__ = ["extract_media"] | |
| def int16_to_float32(x): | |
| return (x / 32767.0).astype(np.float32) | |
| def float32_to_int16(x): | |
| x = np.clip(x, a_min=-1., a_max=1.) | |
| return (x * 32767.).astype(np.int16) | |
| def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image: | |
| if isinstance(image, Image): | |
| if image.path.startswith("http://") or image.path.startswith("https://"): | |
| image = PIL.Image.open(requests.get(image.path, stream=True).raw) | |
| else: | |
| image = PIL.Image.open(image.path) | |
| return image | |
| def _load_video_bytesio(video_bytesio: BytesIO, *, num_frames: int) -> List[PIL.Image.Image]: | |
| with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video: | |
| temp_video.write(video_bytesio.read()) | |
| temp_video_name = temp_video.name | |
| return _load_video(temp_video_name, num_frames=num_frames) | |
| def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]: | |
| # Load video frames from a directory | |
| if os.path.isdir(video_path): | |
| frame_paths = sorted(glob.glob(os.path.join(video_path, "*"))) | |
| indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int) | |
| return [PIL.Image.open(frame_paths[index]) for index in indices] | |
| # Load video frames from a video file | |
| vidcap = cv2.VideoCapture(video_path) | |
| # Find the last frame as frame count might not be accurate | |
| frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| while frame_count > 0: | |
| vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) | |
| if vidcap.grab(): | |
| break | |
| frame_count -= 1 | |
| else: | |
| raise ValueError(f"Video '{video_path}' has no frames.") | |
| # Extract frames uniformly | |
| indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int) | |
| frames = {} | |
| for index in indices: | |
| if index in frames: | |
| continue | |
| vidcap.set(cv2.CAP_PROP_POS_FRAMES, index) | |
| success, frame = vidcap.read() | |
| if not success: | |
| logger.warning(f"Failed to read frame {index} from video '{video_path}'. Skipped.") | |
| continue | |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frames[index] = PIL.Image.fromarray(frame) | |
| return [frames[index] for index in indices if index in frames] | |
| def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]: | |
| num_frames = config.num_video_frames | |
| if getattr(config, "fps") != 0: | |
| logger.warning("Extracting frames from video with specified FPS is not supported yet. Ignored.") | |
| if isinstance(video.path, BytesIO): | |
| frames = _load_video_bytesio(video.path, num_frames=num_frames) | |
| else: | |
| frames = _load_video(video.path, num_frames=num_frames) | |
| return frames | |
| def _load_speech(speech_path: str): | |
| # Load video frames from a directory | |
| if speech_path is None: | |
| return None | |
| speech_outputs = [] | |
| speech = whisper.load_audio(speech_path) | |
| speech = whisper.pad_or_trim(speech) | |
| mel = whisper.log_mel_spectrogram(speech) | |
| speech_outputs.append(mel.unsqueeze(0)) | |
| speech_frames = torch.stack(speech_outputs, dim=0) | |
| return speech_frames.numpy().tolist() | |
| def _extract_speech(speech: Speech, config: PretrainedConfig): | |
| frames = _load_speech(speech.path) | |
| return frames | |
| def _get_num_windows(T, sr): | |
| window_length = int(30.0 * sr) | |
| window_overlap = int(0.0 * sr) | |
| max_num_window = 20 | |
| num_windows = 1 | |
| if T <= window_length: | |
| num_windows = 1 | |
| full_length = window_length | |
| elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap): | |
| num_windows = max_num_window | |
| full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap) | |
| else: | |
| num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap))) | |
| full_length = num_windows * window_length - (num_windows - 1) * window_overlap | |
| return num_windows, full_length | |
| def _load_audio(file_path, target_sr=16000, duration=30.0, start=0.0): | |
| if file_path.endswith('.mp3'): | |
| audio = AudioSegment.from_file(file_path) | |
| if len(audio) > (start + duration) * 1000: | |
| audio = audio[start * 1000:(start + duration) * 1000] | |
| if audio.frame_rate != target_sr: | |
| audio = audio.set_frame_rate(target_sr) | |
| if audio.channels > 1: | |
| audio = audio.set_channels(1) | |
| data = np.array(audio.get_array_of_samples()) | |
| if audio.sample_width == 2: | |
| data = data.astype(np.float32) / np.iinfo(np.int16).max | |
| elif audio.sample_width == 4: | |
| data = data.astype(np.float32) / np.iinfo(np.int32).max | |
| else: | |
| raise ValueError("Unsupported bit depth: {}".format(audio.sample_width)) | |
| else: | |
| with sf.SoundFile(file_path) as audio: | |
| original_sr = audio.samplerate | |
| channels = audio.channels | |
| max_frames = int((start + duration) * original_sr) | |
| audio.seek(int(start * original_sr)) | |
| frames_to_read = min(max_frames, len(audio)) | |
| data = audio.read(frames_to_read) | |
| if data.max() > 1 or data.min() < -1: | |
| data = data / max(abs(data.max()), abs(data.min())) | |
| if original_sr != target_sr: | |
| if channels == 1: | |
| data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr) | |
| else: | |
| data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0] | |
| else: | |
| if channels != 1: | |
| data = data.T[0] | |
| if data.min() >= 0: | |
| data = 2 * data / abs(data.max()) - 1.0 | |
| else: | |
| data = data / max(abs(data.max()), abs(data.min())) | |
| assert len(data.shape) == 1, data.shape | |
| return data | |
| def _load_sound_mask(sound_file, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=20, audio_start = 0.0): | |
| if sound_file is None: | |
| return None | |
| window_length = int(window_length * sample_rate) | |
| window_overlap = int(window_overlap * sample_rate) | |
| max_num_window = int(max_num_window) | |
| duration = max_num_window * (window_length - window_overlap) + window_overlap | |
| sound_outputs = [] | |
| audio_feature_masks = [] | |
| audio_embed_masks = [] | |
| try: | |
| audio_data = _load_audio(sound_file, sample_rate, duration, audio_start) # already cuts to max duration | |
| T = len(audio_data) | |
| audio_data = audio_data.reshape(1, -1) | |
| num_windows, full_length = _get_num_windows(T, sample_rate) | |
| audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() | |
| for i in range(num_windows): | |
| audio_embed_mask = torch.zeros(750) | |
| start = i * (window_length - window_overlap) | |
| audio_data_tensor_this = audio_data_tensor[:, start:start+window_length] | |
| orig_length = audio_data_tensor_this.shape[1] | |
| audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt") #.squeeze(0) text="dummy", audios=audio_data_tensor_this, return_tensors="pt") # | |
| sound_outputs.append(audio_data_tensor_this["input_features"]) | |
| # calculate the mask for the input melspec to Whisper | |
| melspec_frames_this_window = int(math.ceil(orig_length / 160)) | |
| feature_attention_mask = torch.zeros(3000, dtype=torch.int32) | |
| feature_attention_mask[:melspec_frames_this_window] = 1 | |
| audio_feature_masks.append(feature_attention_mask.unsqueeze(0)) | |
| # calculate the mask for the output embedding for use in AF3 | |
| conv_lengths = (melspec_frames_this_window - 1) // 2 + 1 | |
| output_embedding_lengths = (conv_lengths - 2) // 2 + 1 | |
| audio_embed_mask[:output_embedding_lengths] = 1 | |
| audio_embed_masks.append(audio_embed_mask) | |
| except: | |
| print("Error loading sound file: ", sound_file) | |
| sound_outputs.append(torch.zeros(1,128,3000)) | |
| audio_feature_masks.append(torch.zeros(1, 3000, dtype=torch.int32)) | |
| audio_embed_masks.append(torch.zeros(750)) | |
| sound_outputs = torch.stack(sound_outputs, dim=0) | |
| audio_feature_masks = torch.stack(audio_feature_masks, dim=0) | |
| audio_embed_masks = torch.stack(audio_embed_masks, dim=0) | |
| return sound_outputs.numpy().tolist(), audio_feature_masks ,audio_embed_masks | |
| def _extract_sound_mask(sound: Sound, config: PretrainedConfig): | |
| frames, audio_feature_masks, audio_embed_masks = _load_sound_mask(sound.path) | |
| return frames, audio_feature_masks, audio_embed_masks | |
| def extract_media( | |
| messages: List[Dict[str, Any]], | |
| config: Optional[PretrainedConfig] = None, | |
| draft: bool = False, | |
| ) -> Dict[str, List[Any]]: | |
| media = defaultdict(list) | |
| media_meta = defaultdict(list) | |
| for message in messages: | |
| text = "" | |
| print(message) | |
| for part in make_list(message["value"]): | |
| if isinstance(part, str): | |
| for token in MEDIA_TOKENS.values(): | |
| if token in part: | |
| logger.warning(f"Media token '{token}' found in text: '{part}'. Removed.") | |
| part = part.replace(token, "").strip() | |
| text += part | |
| if isinstance(part, (Image, PIL.Image.Image)): | |
| if draft: | |
| media["image"].append(part) | |
| else: | |
| media["image"].append(_extract_image(part)) | |
| text += MEDIA_TOKENS["image"] | |
| if isinstance(part, Video): | |
| if draft: | |
| media["video"].append(part) | |
| else: | |
| media["video"].append(_extract_video(part, config)) | |
| text += MEDIA_TOKENS["video"] | |
| if isinstance(part, Speech): | |
| if draft: | |
| media["speech"].append(part) | |
| else: | |
| media["speech"].append(_extract_speech(part, config)) | |
| text += MEDIA_TOKENS["speech"] | |
| if isinstance(part, Sound): | |
| if draft: | |
| media["sound"].append(part) | |
| else: | |
| sound, audio_feature_masks,audio_embed_masks = _extract_sound_mask(part, config) | |
| media["sound"].append(sound) | |
| media_meta["sound_feature_masks"].append(audio_feature_masks) | |
| media_meta["sound_embed_masks"].append(audio_embed_masks) | |
| text += MEDIA_TOKENS["sound"] * len(sound) | |
| message["value"] = text | |
| return media, media_meta | |