Spaces:
Runtime error
Runtime error
| import torch | |
| import torchaudio | |
| from PIL import Image | |
| import numpy as np | |
def load_image(image, image_processor):
    """Turn an image input into a tensor with a leading batch dimension.

    Args:
        image: a path to an image file, a ``PIL.Image.Image``, or a
            ``torch.Tensor``.
        image_processor: callable applied to PIL images (and to images loaded
            from a path); its result must support ``.unsqueeze(0)``. It is not
            called for tensor inputs.

    Returns:
        * path input: ``image_processor(rgb_image).unsqueeze(0)``, where the
          file is decoded and converted to RGB first;
        * PIL input: ``image_processor(image).unsqueeze(0)`` (no RGB
          conversion is applied to PIL inputs);
        * 3-D tensor: the tensor with a new leading dimension of size 1;
        * tensor of any other rank, or any other input type: returned
          unchanged.
    """
    if isinstance(image, torch.Tensor):
        # Only 3-D tensors get a batch axis; presumably (C, H, W) — confirm
        # with callers. Other ranks are assumed to be already batched.
        return image.unsqueeze(0) if len(image.shape) == 3 else image

    if isinstance(image, str):  # path to an image file
        # Close the file handle deterministically. convert() loads the pixel
        # data and returns a new image, so it remains valid after the close.
        with Image.open(image) as opened:
            rgb_image = opened.convert('RGB')
        return image_processor(rgb_image).unsqueeze(0)

    if isinstance(image, Image.Image):
        # NOTE(review): unlike the path branch, no RGB conversion happens here.
        return image_processor(image).unsqueeze(0)

    # NOTE(review): unsupported input types are passed through untouched
    # (kept for backward compatibility with existing callers).
    return image
def load_audio(audio, audio_processor):
    """Load an audio clip, run it through ``audio_processor`` and add a batch axis.

    Args:
        audio: either a path to an audio file, decoded with
            ``torchaudio.load``, or a ``(sample_rate, waveform)`` tuple whose
            ``waveform`` is array-like with 1 or 2 dimensions. Integer
            waveforms are divided by the maximum value of their dtype.
            Floating-point waveforms are used as they are, on the assumption
            that they are already normalized.
        audio_processor: callable that receives a ``(waveform, sample_rate)``
            tuple; its result must support ``.unsqueeze(0)``.

    Returns:
        ``audio_processor(...).unsqueeze(0)``.

    Raises:
        NotImplementedError: if ``audio`` is neither a ``str`` nor a
            ``tuple``, or if the tuple's waveform has more than 2 dimensions.
    """
    if isinstance(audio, str):  # path to an audio file
        raw_audio = torchaudio.load(audio)
        audio = audio_processor(raw_audio)
    elif isinstance(audio, tuple):
        sample_rate, raw_waveform = audio
        raw_waveform = np.asarray(raw_waveform)
        if np.issubdtype(raw_waveform.dtype, np.integer):
            # Scale integer PCM by the dtype's max value.
            waveform = raw_waveform / np.iinfo(raw_waveform.dtype).max
        else:
            # np.iinfo rejects float dtypes; float audio is used unscaled.
            waveform = raw_waveform
        # Fresh float32 C-contiguous copy: matches the dtype torchaudio.load
        # returns and never aliases the caller's buffer.
        waveform = torch.from_numpy(np.array(waveform, dtype=np.float32, order='C'))
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.ndim == 2:
            # Averages over axis 1, i.e. treats the array as
            # (samples, channels) and downmixes to mono — confirm this layout
            # for every caller.
            waveform = waveform.mean(1).unsqueeze(0)
        else:
            raise NotImplementedError(
                f"Unsupported waveform with {waveform.ndim} dimensions; "
                "expected 1 or 2."
            )
        audio = audio_processor((waveform, sample_rate))
    else:
        raise NotImplementedError(
            f"Unsupported audio input of type {type(audio).__name__}; "
            "expected a file path or a (sample_rate, waveform) tuple."
        )
    return audio.unsqueeze(0)