File size: 2,787 Bytes
528efee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# -*- coding: utf-8 -*-
# Time      :2025/3/29 10:28
# Author    :Hui Huang
import json

import torch
import torch.nn as nn
import yaml

from .tokenizer_utils import load_config
import os
from safetensors.torch import load_file


class SparkBaseModel(nn.Module):
    """Base module for Spark audio-tokenizer models loaded from a local directory.

    Subclasses are expected to accept the ``audio_tokenizer`` section of
    ``config.yaml`` as their single constructor argument.
    """

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Build an eval-mode model from ``config.yaml`` + ``model.safetensors``.

        Args:
            model_path: Directory containing ``config.yaml`` and
                ``model.safetensors``.

        Returns:
            The constructed model, in eval mode, with weight norm removed.
        """
        cfg = load_config(os.path.join(model_path, "config.yaml"))["audio_tokenizer"]
        instance = cls(cfg)
        weights = load_file(os.path.join(model_path, "model.safetensors"))
        # strict=False: checkpoint may carry extra / missing keys by design.
        instance.load_state_dict(weights, strict=False)
        instance.eval()
        instance.remove_weight_norm()
        return instance

    def remove_weight_norm(self):
        """Strip weight normalization from every submodule that carries it."""
        for module in self.modules():
            try:
                torch.nn.utils.remove_weight_norm(module)
            except ValueError:
                # This module never had weight norm applied — skip it.
                continue


class SnacBaseModel(nn.Module):
    """Base module for SNAC models restored from a JSON config + ``.bin`` weights.

    Subclasses are expected to accept the keys of ``config.json`` as
    constructor keyword arguments.
    """

    @classmethod
    def from_config(cls, config_path):
        """Instantiate the model from a JSON file of constructor kwargs.

        Args:
            config_path: Path to a JSON file whose top-level keys map to
                constructor keyword arguments.

        Returns:
            A freshly constructed (untrained-state) model instance.
        """
        with open(config_path, "r") as f:
            kwargs = json.load(f)
        return cls(**kwargs)

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Load ``config.json`` and ``pytorch_model.bin`` from a directory.

        Args:
            model_path: Directory containing ``config.json`` and
                ``pytorch_model.bin``.

        Returns:
            The constructed model, switched to eval mode.
        """
        instance = cls.from_config(os.path.join(model_path, "config.json"))
        # weights_only=True restricts unpickling to tensors/containers (safe load).
        weights = torch.load(
            os.path.join(model_path, "pytorch_model.bin"),
            map_location="cpu",
            weights_only=True,
        )
        # strict=False: tolerate extra / missing keys in the checkpoint.
        instance.load_state_dict(weights, strict=False)
        instance.eval()
        return instance


class MegaBaseModel(nn.Module):
    """Base module for Mega models restored from a ``.ckpt`` checkpoint directory.

    The checkpoint's ``state_dict`` entry is expected to be a mapping whose
    ``CKPT_NAME`` key holds this model's own state dict; DDP (``module.``) and
    ``torch.compile`` (``_orig_mod.``) prefixes are stripped from all keys.
    """

    # Key inside checkpoint["state_dict"] that holds this model's weights.
    CKPT_NAME = "model"

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Restore a model from the first directory entry ending in ``.ckpt``.

        Args:
            model_path: Directory holding a ``*.ckpt`` checkpoint and,
                optionally, a ``*.yaml`` config passed to the constructor.

        Returns:
            The constructed model, switched to eval mode.

        Raises:
            FileNotFoundError: If no ``.ckpt`` file exists in ``model_path``.
        """
        ckpt_path = None
        config_file = None
        # Last matching file wins, mirroring a plain directory scan.
        for entry in os.listdir(model_path):
            full = os.path.join(model_path, entry)
            if entry.endswith(".ckpt"):
                ckpt_path = full
            if entry.endswith(".yaml"):
                config_file = full
        if ckpt_path is None:
            raise FileNotFoundError(f"No checkpoint found at {model_path}")

        def _strip(key):
            # Drop DDP / torch.compile wrapper prefixes from a state-dict key.
            return key.replace('module.', '').replace('_orig_mod.', '')

        # weights_only=True restricts unpickling to tensors/containers (safe load).
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        grouped = {_strip(k): v for k, v in checkpoint["state_dict"].items()}
        weights = {_strip(k): v for k, v in grouped[cls.CKPT_NAME].items()}

        if config_file is None:
            instance = cls()
        else:
            with open(config_file) as f:
                instance = cls(yaml.safe_load(f))
        # strict=False: tolerate extra / missing keys in the checkpoint.
        instance.load_state_dict(weights, strict=False)
        instance.eval()
        return instance