Spaces:
Running
on
L4
Running
on
L4
| # -*- coding: utf-8 -*- | |
| # Time :2025/3/29 10:28 | |
| # Author :Hui Huang | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import yaml | |
| from .tokenizer_utils import load_config | |
| import os | |
| from safetensors.torch import load_file | |
class SparkBaseModel(nn.Module):
    """Base class for models restored from a directory holding ``config.yaml``
    and ``model.safetensors``.

    Subclasses are constructed with a single positional argument: the
    ``'audio_tokenizer'`` section of the loaded config.
    """

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Construct ``cls`` from ``model_path`` and load its weights.

        Reads ``config.yaml`` via ``load_config`` and takes its
        ``'audio_tokenizer'`` entry. It then instantiates ``cls`` with that
        entry, loads ``model.safetensors`` non-strictly, switches the model to
        eval mode and strips weight normalization.

        Args:
            model_path: Directory containing ``config.yaml`` and
                ``model.safetensors``.

        Returns:
            The constructed model, in eval mode, with weight norm removed.

        Raises:
            KeyError: If the loaded config has no ``'audio_tokenizer'`` entry.
        """
        config = load_config(os.path.join(model_path, "config.yaml"))['audio_tokenizer']
        model = cls(config)
        state_dict = load_file(os.path.join(model_path, "model.safetensors"))
        # strict=False: keys missing from, or unexpected in, the checkpoint are
        # ignored instead of raising. NOTE(review): consider logging the
        # missing/unexpected keys that load_state_dict returns.
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        model.remove_weight_norm()
        return model

    def remove_weight_norm(self):
        """Remove weight normalization from this module and all submodules.

        Uses ``nn.Module.apply`` to visit every module. Modules that carry no
        weight-norm hook are skipped.
        """

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:
                # Raised when the module has no weight norm; skipping such
                # modules is the intended best-effort behavior.
                pass

        self.apply(_remove_weight_norm)
| class SnacBaseModel(nn.Module): | |
| def from_config(cls, config_path): | |
| with open(config_path, "r") as f: | |
| config = json.load(f) | |
| model = cls(**config) | |
| return model | |
| def from_pretrained(cls, model_path: str): | |
| model = cls.from_config(os.path.join(model_path, "config.json")) | |
| state_dict = torch.load( | |
| os.path.join(model_path, "pytorch_model.bin"), | |
| map_location="cpu", weights_only=True) | |
| model.load_state_dict(state_dict, strict=False) | |
| model.eval() | |
| return model | |
| class MegaBaseModel(nn.Module): | |
| CKPT_NAME = "model" | |
| def from_pretrained(cls, model_path: str): | |
| config_file = None | |
| ckpt_path = None | |
| for file in os.listdir(model_path): | |
| if file.endswith(".ckpt"): | |
| ckpt_path = os.path.join(model_path, file) | |
| if file.endswith(".yaml"): | |
| config_file = os.path.join(model_path, file) | |
| if ckpt_path is None: | |
| raise FileNotFoundError(f"No checkpoint found at {model_path}") | |
| checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) | |
| state_dict_all = { | |
| k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items() | |
| } | |
| state_dict = state_dict_all[cls.CKPT_NAME] | |
| state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()} | |
| if config_file is not None: | |
| with open(config_file) as f: | |
| config = yaml.safe_load(f) | |
| model = cls(config) | |
| else: | |
| model = cls() | |
| model.load_state_dict(state_dict, strict=False) | |
| model.eval() | |
| return model |