Spaces:
Runtime error
Runtime error
| #! /usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2023 Imperial College London (Pingchuan Ma) | |
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
| import torch | |
| import whisper | |
| import torchaudio | |
| import torchvision | |
| from .transforms import AudioTransform, VideoTransform | |
class AVSRDataLoader:
    """Load a media file and turn it into audio and/or video tensors.

    ``modality`` selects what :meth:`load_data` returns:

    * ``"audio"``       -> audio tensor
    * ``"video"``       -> video tensor
    * ``"audiovisual"`` -> ``(video, audio)`` tuple, trimmed to a common length
    """

    _MODALITIES = ("audio", "video", "audiovisual")
    _DETECTORS = ("mediapipe", "retinaface")
    # Number of audio samples that correspond to one video frame, used to align
    # the two streams. 640 is presumably 16 kHz audio / 25 fps video.
    # NOTE(review): confirm the input videos are actually 25 fps.
    _AUDIO_SAMPLES_PER_FRAME = 640

    def __init__(self, modality, speed_rate=1, transform=True,
                 detector="retinaface", convert_gray=True):
        """Build the loader and the per-modality processing objects.

        Args:
            modality: One of ``"audio"``, ``"video"`` or ``"audiovisual"``.
            speed_rate: Forwarded to ``VideoTransform``.
            transform: When falsy, :meth:`load_data` returns untransformed data.
            detector: ``"mediapipe"`` or ``"retinaface"``; selects which
                ``VideoProcess`` is used for the video stream. Only consulted
                when the modality includes video.
            convert_gray: Forwarded to ``VideoProcess``.

        Raises:
            ValueError: If ``modality`` is unknown, or if the modality needs
                video and ``detector`` is unknown.
        """
        if modality not in self._MODALITIES:
            raise ValueError(
                f"Unsupported modality {modality!r}; expected one of {self._MODALITIES}"
            )
        self.modality = modality
        self.transform = transform

        if modality in ("audio", "audiovisual"):
            self.audio_transform = AudioTransform()

        if modality in ("video", "audiovisual"):
            # Detector back-ends are imported lazily so that only the selected
            # one (and its dependencies) needs to be installed.
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.video_process import VideoProcess
            elif detector == "retinaface":
                from pipelines.detectors.retinaface.video_process import VideoProcess
            else:
                # Previously this fell through silently and crashed later with
                # AttributeError on self.video_process.
                raise ValueError(
                    f"Unsupported detector {detector!r}; expected one of {self._DETECTORS}"
                )
            self.video_process = VideoProcess(convert_gray=convert_gray)
            self.video_transform = VideoTransform(speed_rate=speed_rate)

    def load_data(self, data_filename, landmarks=None, transform=True):
        """Load ``data_filename`` according to ``self.modality``.

        Args:
            data_filename: Path to the media file.
            landmarks: Passed unchanged to ``self.video_process``; only used
                when the modality includes video.
            transform: Per-call switch. Passing ``False`` skips the transforms
                even if the loader was built with ``transform=True``. It cannot
                re-enable transforms on a loader built with ``transform=False``.

        Returns:
            The audio tensor (``"audio"``), the video tensor (``"video"``), or
            a ``(video, audio)`` tuple (``"audiovisual"``).

        Raises:
            ValueError: If ``self.modality`` is not a supported modality.
        """
        apply_transform = self.transform and transform

        if self.modality == "audio":
            audio = self.load_audio(data_filename)
            return self.audio_transform(audio) if apply_transform else audio

        if self.modality == "video":
            video = self._load_processed_video(data_filename, landmarks)
            return self.video_transform(video) if apply_transform else video

        if self.modality == "audiovisual":
            audio = self.load_audio(data_filename)
            video = self._load_processed_video(data_filename, landmarks)
            # Keep only as many frames as have a full block of audio samples,
            # then cut the audio to exactly that many blocks.
            step = self._AUDIO_SAMPLES_PER_FRAME
            num_frames = min(len(video), audio.size(1) // step)
            audio = audio[:, : num_frames * step]
            video = video[:num_frames]
            if apply_transform:
                audio = self.audio_transform(audio)
                video = self.video_transform(video)
            return video, audio

        raise ValueError(f"Unsupported modality {self.modality!r}")

    def _load_processed_video(self, data_filename, landmarks):
        """Read frames, run them through ``self.video_process``, return a tensor."""
        video = self.load_video(data_filename)
        video = self.video_process(video, landmarks)
        return torch.tensor(video)

    def load_audio(self, data_filename):
        """Decode ``data_filename`` with ``whisper.load_audio``.

        Returns:
            A waveform tensor of shape ``[1, T]`` (a leading channel axis is
            added with ``unsqueeze(0)``).
        """
        # rtype: [1, T]
        waveform = torch.tensor(whisper.load_audio(data_filename)).unsqueeze(0)
        return waveform

    def load_video(self, data_filename):
        """Return the decoded frames of ``data_filename`` as a NumPy array.

        This is the first element of ``torchvision.io.read_video``'s result,
        read with timestamps in seconds.
        """
        return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy()