# GigaAM / Examples / ssl_inference.py
# niobures's picture
# GigaAM
# 74e8c79 verified
import argparse
import hydra
import soundfile
import torch
from omegaconf import OmegaConf
class SpecScaler(torch.nn.Module):
    """Log-compress a spectrogram after clamping it to ``[1e-9, 1e9]``.

    The clamp keeps zeros and negative values away from the logarithm, so
    they map to ``log(1e-9)`` instead of ``-inf`` / ``nan``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log(clamp(x, 1e-9, 1e9))`` element-wise.

        Args:
            x: Input tensor (e.g. a power spectrogram).

        Returns:
            A new tensor of the same shape holding the natural logarithm of
            the clamped input.
        """
        # Out-of-place clamp on purpose: the in-place variant (``clamp_``)
        # overwrites the caller's tensor and raises on leaf tensors that
        # require grad.
        return torch.log(x.clamp(1e-9, 1e9))
def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        Namespace with the attributes ``encoder_config``, ``model_weights``,
        ``audio_path`` and ``device``.

    Raises:
        SystemExit: Raised by argparse (exit code 2) when a required option
            is missing or an unknown option is given.
    """
    parser = argparse.ArgumentParser(
        description="Run inference using GigaAM checkpoint"
    )
    # The three paths are mandatory: without ``required=True`` a forgotten
    # flag silently becomes ``None`` and only fails later, deep inside the
    # loading code, with an unrelated-looking error.
    parser.add_argument(
        "--encoder_config",
        required=True,
        help="Path to GigaAM config file (.yaml)",
    )
    parser.add_argument(
        "--model_weights",
        required=True,
        help="Path to GigaAM checkpoint file (.ckpt)",
    )
    parser.add_argument(
        "--audio_path",
        required=True,
        help="Path to audio signal",
    )
    # CPU is always available, so it is the safe fallback when the flag is
    # omitted (previously the value was ``None``).
    parser.add_argument(
        "--device",
        default="cpu",
        help="Device: cpu / cuda (default: cpu)",
    )
    return parser.parse_args(argv)
def main(encoder_config: str, model_weights: str, device: str, audio_path: str) -> None:
    """Encode one audio file with a GigaAM encoder and print the output shape.

    Args:
        encoder_config: Path to a config readable by ``OmegaConf.load``. It
            must contain ``encoder`` and ``feature_extractor`` nodes that
            ``hydra.utils.instantiate`` can build.
        model_weights: Path to a checkpoint readable by ``torch.load``; its
            content is passed to ``encoder.load_state_dict`` with
            ``strict=True``.
        device: Device the encoder and the features are moved to.
        audio_path: Path to an audio file readable by ``soundfile.read``.
    """
    conf = OmegaConf.load(encoder_config)

    encoder = hydra.utils.instantiate(conf.encoder)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    ckpt = torch.load(model_weights, map_location="cpu")
    encoder.load_state_dict(ckpt, strict=True)
    encoder.to(device)
    # Inference only: evaluation mode makes any dropout / batch-norm layers
    # behave deterministically instead of as during training.
    encoder.eval()

    feature_extractor = hydra.utils.instantiate(conf.feature_extractor)

    # NOTE(review): the file's sample rate is discarded and the signal is
    # passed on as read; multi-channel files come back 2-D from soundfile.
    # Presumably the config expects mono audio at a fixed rate — confirm.
    audio_signal, _ = soundfile.read(audio_path, dtype="float32")

    # No gradients are needed to print a shape; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        features = feature_extractor(torch.tensor(audio_signal).float())
        features = features.to(device)

        # Add a leading batch dimension of size 1; the length is the size of
        # the features' last dimension.
        encoded, _ = encoder.forward(
            audio_signal=features.unsqueeze(0),
            length=torch.tensor([features.shape[-1]]).to(device),
        )
    print(f"encoded signal shape: {encoded.shape}")
if __name__ == "__main__":
    # Script entry point: read the CLI options and forward them to main()
    # positionally, in the order of its signature
    # (encoder_config, model_weights, device, audio_path).
    cli = _parse_args()
    main(cli.encoder_config, cli.model_weights, cli.device, cli.audio_path)