Spaces:
Runtime error
Runtime error
| #! /usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2023 Imperial College London (Pingchuan Ma) | |
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
| import torch | |
| import torchaudio | |
| import torchvision | |
class FunctionalModule(torch.nn.Module):
    """Adapter that exposes a plain callable as a ``torch.nn.Module``.

    Wrapping a function this way lets it be used as one stage of a
    ``torch.nn.Sequential`` container next to ordinary modules.

    Args:
        functional: Callable taking a single argument and returning the
            transformed value.
    """

    def __init__(self, functional):
        super().__init__()
        # Kept under the same attribute name so external code can still
        # reach the wrapped callable via ``module.functional``.
        self.functional = functional

    def forward(self, input):
        """Apply the wrapped callable to ``input`` and return its result."""
        wrapped = self.functional
        return wrapped(input)
class VideoTransform:
    """Preprocessing pipeline applied to a video tensor.

    The pipeline appends a trailing axis, optionally resamples along
    dim 0, moves the new trailing axis to the front, divides by 255,
    centre-crops the last two axes to 88x88 and normalises with
    mean 0.421 / std 0.165.

    The ``permute(3, 0, 1, 2)`` step requires a 4-D tensor after the
    ``unsqueeze``, so the input must be 3-D.
    NOTE(review): presumably the input is (frames, height, width) with
    8-bit pixel values -- confirm against the caller.

    Args:
        speed_rate: Resampling factor along dim 0. ``1`` leaves the input
            untouched; otherwise ``int(n / speed_rate)`` evenly spaced
            indices between ``0`` and ``n - 1`` are selected, where ``n``
            is the size of dim 0.
    """

    def __init__(self, speed_rate):
        def resample_frames(frames):
            """Select evenly spaced entries along dim 0 per ``speed_rate``."""
            if speed_rate == 1:
                return frames
            num_frames = frames.shape[0]
            # Build the index on the same device as the input:
            # ``index_select`` rejects an index tensor that lives on a
            # different device than the tensor being indexed.
            index = torch.linspace(
                0,
                num_frames - 1,
                int(num_frames / speed_rate),
                dtype=torch.int64,
                device=frames.device,
            )
            return torch.index_select(frames, dim=0, index=index)

        self.video_pipeline = torch.nn.Sequential(
            # Append a trailing singleton axis.
            FunctionalModule(lambda x: x.unsqueeze(-1)),
            FunctionalModule(resample_frames),
            # Move the trailing axis to the front.
            FunctionalModule(lambda x: x.permute(3, 0, 1, 2)),
            FunctionalModule(lambda x: x / 255.0),
            torchvision.transforms.CenterCrop(88),
            torchvision.transforms.Normalize(0.421, 0.165),
        )

    def __call__(self, sample):
        """Run ``sample`` through the video pipeline and return the result."""
        return self.video_pipeline(sample)
class AudioTransform:
    """Preprocessing pipeline applied to an audio tensor.

    The input is layer-normalised over its full shape and then its first
    two axes are swapped.
    """

    def __init__(self):
        def normalize(waveform):
            # ``normalized_shape`` is the whole tensor shape, so the
            # statistics are computed over every element at once.
            # NOTE(review): with eps=0 a constant-valued input has zero
            # variance and would likely yield NaN -- confirm this is intended.
            return torch.nn.functional.layer_norm(waveform, waveform.shape, eps=0)

        def swap_leading_axes(waveform):
            return waveform.transpose(0, 1)

        stages = [
            FunctionalModule(normalize),
            FunctionalModule(swap_leading_axes),
        ]
        self.audio_pipeline = torch.nn.Sequential(*stages)

    def __call__(self, sample):
        """Run ``sample`` through the audio pipeline and return the result."""
        processed = self.audio_pipeline(sample)
        return processed