Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| from huggingface_hub import hf_hub_download | |
| from torch.nn.functional import pad, normalize, softmax | |
| from manipulate_model.model import Model | |
def get_config_and_model(
    model_root="manipulate_model/demo-model/audio",
    repo_id="arnabdas8901/manipulation_detection_transformer",
    filename="weights.pt",
):
    """Load the model config from disk and the model weights from the Hugging Face Hub.

    Parameters
    ----------
    model_root : str
        Directory containing ``config.yaml``.
    repo_id : str
        Hugging Face Hub repository holding the checkpoint.
    filename : str
        Checkpoint file name inside ``repo_id``.

    Returns
    -------
    tuple
        ``(config, model)``: the OmegaConf config (with string-valued
        ``model.encoder`` / ``model.decoder`` entries replaced by the YAML they
        point to) and a ``Model`` built from that config with the checkpoint's
        ``"model_state_dict"`` loaded (tensors mapped to the CPU).
    """
    config = OmegaConf.load(os.path.join(model_root, "config.yaml"))
    # Encoder/decoder may be given as paths to separate YAML files; inline them.
    # NOTE(review): these paths are passed to OmegaConf.load as-is, so they are
    # resolved against the working directory, not model_root — confirm intended.
    if isinstance(config.model.encoder, str):
        config.model.encoder = OmegaConf.load(config.model.encoder)
    if isinstance(config.model.decoder, str):
        config.model.decoder = OmegaConf.load(config.model.decoder)

    model = Model(config=config)
    model_file = hf_hub_download(repo_id, filename=filename)
    # SECURITY: torch.load may deserialize with pickle (depending on the
    # installed torch version's weights_only default); only load checkpoints
    # from a trusted repo_id.
    checkpoint = torch.load(model_file, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model_state_dict"])
    print("### Model loaded from :", model_file)
    return config, model
def load_audio(file_path, config):
    """Load an audio file, resample it to ``config.data.sr`` and window it.

    Parameters
    ----------
    file_path : str or os.PathLike
        Path to a ``.wav``, ``.flac`` or ``.mp3`` file (extension matched
        case-insensitively).
    config : OmegaConf
        Configuration; ``config.data.sr`` is the target sample rate, and the
        rest of ``config.data`` is consumed by ``preprocess_audio``.

    Returns
    -------
    torch.Tensor
        The windowed audio produced by ``preprocess_audio``.

    Raises
    ------
    NotImplementedError
        For ``.mp4`` input (audio-track extraction is not implemented).
    ValueError
        For any other unsupported file extension.
    """
    path = os.fspath(file_path)
    lowered = path.lower()
    if lowered.endswith((".wav", ".flac", ".mp3")):
        audio, sample_rate = torchaudio.load(path)
        if sample_rate != config.data.sr:
            print("requires resampling")
            audio = torchaudio.functional.resample(audio, sample_rate, config.data.sr)
        return preprocess_audio(audio, config)
    if lowered.endswith(".mp4"):
        # TODO: extract the audio track from the video before supporting this.
        raise NotImplementedError(f"Audio extraction from .mp4 is not supported yet: {path}")
    # Previously this fell through with audio=None and crashed later with an
    # opaque AttributeError inside preprocess_audio.
    raise ValueError(f"Unsupported audio file format: {path}")
def preprocess_audio(audio, config, step_size=1):
    """Slice a waveform into overlapping, per-window L2-normalized windows.

    Parameters
    ----------
    audio : torch.Tensor
        Waveform shaped ``(channels, samples)``; a 1-D ``(samples,)`` tensor is
        treated as mono. Only the first channel is used (with a warning).
    config : OmegaConf
        ``config.data.window_size`` (window length in frames),
        ``config.data.sr`` (sample rate) and ``config.data.fps`` (frames per
        second).
    step_size : int, optional
        Hop between consecutive windows, in frames. Defaults to 1.

    Returns
    -------
    torch.Tensor
        Shape ``(n_windows, window_size * (sr // fps))``. Each row is
        L2-normalized along the sample axis.

    Raises
    ------
    ValueError
        If ``sr // fps`` is not positive (windows would be empty).
    """
    if audio.dim() == 1:
        # Treat a bare waveform as a single channel instead of misreading
        # its first sample as a channel.
        audio = audio.unsqueeze(0)
    if audio.shape[0] > 1:
        print("Warning: multi channel audio")
        audio = audio[0].unsqueeze(0)

    samples_per_frame = config.data.sr // config.data.fps
    if samples_per_frame <= 0:
        raise ValueError(
            f"sample rate ({config.data.sr}) must be at least fps ({config.data.fps})"
        )

    audio_len = audio.shape[1]
    hop = step_size * samples_per_frame
    window = config.data.window_size * samples_per_frame
    # Zero-pad by one full window on both sides.
    padded = pad(audio, (window, window), "constant", 0)
    # Every start is < audio_len + window, so each slice ends before
    # audio_len + 2 * window == padded.shape[1]: all slices have full length.
    windows = [
        normalize(padded[:, start : start + window], dim=1)
        for start in range(0, audio_len + window, hop)
    ]
    # (n_windows, 1, window) -> (n_windows, window). Squeeze only the channel
    # axis so a single window keeps its leading batch dimension.
    return torch.stack(windows).squeeze(1)
def infere(model, x, config, device="cpu", bs=8):
    """Run the model over every window of an audio file.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier; presumably returns logits shaped ``(batch, n_classes)`` —
        confirm against ``Model``. It is moved to ``device`` and put in eval
        mode in place.
    x : str
        Path to the audio file, passed to ``load_audio``.
    config : OmegaConf
        Configuration forwarded to ``load_audio``.
    device : str or torch.device, optional
        Device for both the model and the input batches. Defaults to ``"cpu"``.
    bs : int, optional
        Number of windows per forward pass. Defaults to 8.

    Returns
    -------
    np.ndarray
        1-D array with one entry per window: the softmax probability of class
        index 0. NOTE(review): which label class 0 denotes is not visible here.
    """
    print(x)
    # Inputs are sent to `device`, so the model must live there too; this is
    # a no-op when it already does (previously any other device crashed).
    model.to(device)
    model.eval()
    windows = load_audio(x, config)

    batch_probs = []
    with torch.no_grad():
        for start in range(0, windows.shape[0], bs):
            batch = windows[start : start + bs].to(device)
            probs = softmax(model(batch), dim=1)
            batch_probs.append(probs.cpu().numpy())
    return np.concatenate(batch_probs, axis=0)[:, 0]
def convert_frame_predictions_to_timestamps(frame_predictions, fps, window_size):
    """Return the times (in seconds) of frames whose prediction rounds to 1.

    Parameters
    ----------
    frame_predictions : np.ndarray
        Per-frame scores. Either 1-D (the shape ``infere`` returns) or 2-D, in
        which case column 0 is used.
    fps : int
        Frames per second; frame index ``i`` maps to ``i / fps`` seconds.
    window_size : int
        Window length in frames. ``int(window_size / 2)`` entries are dropped
        from each end to remove the padding; step size is not taken into
        account.

    Returns
    -------
    list of float
        Timestamps, in seconds, of the frames (after trimming) whose rounded
        score equals 1.
    """
    predictions = np.asarray(frame_predictions)
    if predictions.ndim == 2:
        predictions = predictions[:, 0]
    half = int(window_size / 2)
    # Explicit end index: `[half:-half]` would be empty when half == 0.
    trimmed = predictions[half : len(predictions) - half]
    labels = trimmed.round().astype(int)
    return [i / fps for i, label in enumerate(labels) if label == 1]