# tiny-bird-diffusion / pipeline.py
# Source: Hugging Face Hub repository file by sukriramli
# (commit bb29705 "Upload pipeline.py with huggingface_hub", 1.76 kB).
import os, sys, torch, joblib, importlib
import pandas as pd
import torchaudio.transforms as T
class BioacousticEngine:
    """Bundle the bird-audio embedding model with its preprocessing assets.

    On construction, loads from ``repo_dir``: a ``cvt13`` model with weights
    from ``protoclr.pth``, the ``MelSpectrogramProcessor`` from the repo's
    ``mel_spectrogram`` module, the object stored under ``'umap'`` in
    ``trained_cluster_brain.joblib``, and ``acoustic_atlas_metadata.csv``.
    """

    # Audio contract enforced by process_waveform().
    SAMPLE_RATE = 16000          # Hz; input audio is resampled to this rate
    CLIP_SECONDS = 3             # length of the window kept from longer clips
    SEARCH_STEP = 4000           # samples between candidate window starts
    NORMALIZE_MIN_PEAK = 0.02    # only peak-normalize clips louder than this

    def __init__(self, repo_dir="tiny-bird-diffusion"):
        """Load the model and its assets from ``repo_dir``.

        Args:
            repo_dir: Directory holding the ``cvt`` and ``mel_spectrogram``
                modules plus ``protoclr.pth``, ``trained_cluster_brain.joblib``
                and ``acoustic_atlas_metadata.csv``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.repo_dir = repo_dir

        # The model and preprocessor code live inside the repo checkout, so it
        # must be importable. Skip the append if the path is already present so
        # constructing several engines does not pile up duplicate entries.
        repo_path = os.path.abspath(repo_dir)
        if repo_path not in sys.path:
            sys.path.append(repo_path)
        from cvt import cvt13
        mel_module = importlib.import_module("mel_spectrogram")

        self.preprocessor = mel_module.MelSpectrogramProcessor(device=self.device)

        # NOTE(review): torch.load and joblib.load both unpickle their input and
        # can execute arbitrary code -- only point repo_dir at trusted files.
        # Consider torch.load(..., weights_only=True) where the installed torch
        # supports it.
        self.model = cvt13()
        state_dict = torch.load(os.path.join(repo_dir, "protoclr.pth"), map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device).eval()

        brain_data = joblib.load(os.path.join(repo_dir, "trained_cluster_brain.joblib"))
        self.reducer = brain_data['umap']  # presumably a fitted UMAP reducer -- confirm
        self.df = pd.read_csv(os.path.join(repo_dir, "acoustic_atlas_metadata.csv"))

    def process_waveform(self, waveform, sample_rate):
        """Turn raw audio into a fixed-rate, mono, peak-normalized clip.

        Steps:
          1. Resample to ``SAMPLE_RATE``.
          2. Average all channels down to one.
          3. If the clip is longer than ``CLIP_SECONDS``, keep the window with
             the highest mean absolute amplitude. Window starts lie on a
             ``SEARCH_STEP`` grid, and the earliest window wins ties.
          4. Scale so the peak is 1.0, but only when the peak exceeds
             ``NORMALIZE_MIN_PEAK``.

        Clips shorter than ``CLIP_SECONDS`` are returned unpadded.

        Args:
            waveform: Audio tensor shaped ``(channels, samples)`` (e.g. as
                returned by ``torchaudio.load``) or 1-D ``(samples,)``, which
                is treated as mono.
            sample_rate: Sampling rate of ``waveform`` in Hz.

        Returns:
            Tensor of shape ``(1, n)`` with ``n <= SAMPLE_RATE * CLIP_SECONDS``,
            moved to ``self.device``.
        """
        if waveform.dim() == 1:
            # A bare (samples,) signal would otherwise be averaged over time
            # by the channel downmix below, collapsing it to one sample.
            waveform = waveform.unsqueeze(0)

        sr = self.SAMPLE_RATE
        if sample_rate != sr:
            # Keep the resampling kernel on the same device as the audio.
            resampler = T.Resample(orig_freq=sample_rate, new_freq=sr).to(waveform.device)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        total_samples = waveform.shape[-1]
        target_samples = self.CLIP_SECONDS * sr
        if total_samples > target_samples:
            magnitude = waveform.abs()  # computed once instead of per window
            max_energy, best_start = -1, 0
            for start in range(0, total_samples - target_samples + 1, self.SEARCH_STEP):
                energy = magnitude[:, start:start + target_samples].mean().item()
                if energy > max_energy:  # strict '>' keeps the earliest maximum
                    max_energy, best_start = energy, start
            waveform = waveform[:, best_start:best_start + target_samples]

        if waveform.numel() > 0:  # .max() on an empty tensor raises
            peak = waveform.abs().max()
            if peak > self.NORMALIZE_MIN_PEAK:
                waveform = waveform / peak
        return waveform.to(self.device)