| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import json |
| | import os |
| | import sys |
| | import re |
| |
|
| | from safetensors.torch import save_file |
| |
|
| | |
# Default checkpoint location; can be overridden by the first CLI argument.
model_path = './model.pt'

if len(sys.argv) > 1:
    model_path = sys.argv[1]

# Converted files are written next to the input checkpoint.
path_dst = os.path.dirname(model_path)

print(f"Loading model from {model_path}")

# NOTE(review): torch.load unpickles arbitrary objects — only run this script
# on checkpoints from a trusted source.
model = torch.load(model_path, map_location='cpu')

# Dump the top-level checkpoint keys; pretty-print the hyper-parameters blob
# when present so the conversion can be sanity-checked by eye.
# (assumes the checkpoint loads as a dict — an nn.Module would not iterate
#  its keys this way; the isinstance check below suggests dict is the norm)
for key in model:
    print(key)
    if key == 'hyper_parameters':
        print(json.dumps(model[key], indent=4))

# The checkpoint may be a full module or already a plain state dict.
if isinstance(model, torch.nn.Module):
    state_dict = model.state_dict()
else:
    state_dict = model

print("State dictionary keys:")
for key in state_dict:
    print(key)
| |
|
| | |
def flatten_state_dict(state_dict, parent_key='', sep='.'):
    """Flatten a nested checkpoint dict and remap tensor names for export.

    Recursively collects every tensor under a dotted key, keeps only the
    tensors needed for inference (quantizer codebook, backbone, output head),
    renames them to the target layout and adjusts a few tensor shapes.

    Args:
        state_dict: possibly nested dict whose leaves are torch.Tensor values.
        parent_key: dotted prefix prepended to all keys (used by the recursion).
        sep: separator between key components.

    Returns:
        dict mapping each remapped key to its (possibly reshaped) tensor.
    """
    def _flatten(d, prefix):
        # Recursion only flattens; non-tensor, non-dict leaves (ints, strings,
        # hyper-parameters) are silently dropped.
        pairs = []
        for k, v in d.items():
            dotted = f"{prefix}{sep}{k}" if prefix else k
            if isinstance(v, torch.Tensor):
                pairs.append((dotted, v))
            elif isinstance(v, dict):
                pairs.extend(_flatten(v, dotted))
        return pairs

    items = _flatten(state_dict, parent_key)
    items_new = []
    size_total_mb = 0

    for key, value in items:
        # Keep only what inference needs; everything else is reported and skipped.
        if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
           not key.startswith('state_dict.backbone.') and \
           not key.startswith('state_dict.head.out'):
            print('Skipping key: ', key)
            continue

        new_key = key
        new_key = new_key.replace('state_dict.', '')
        new_key = new_key.replace('pos_net', 'posnet')

        # posnet norm params: "backbone.posnet.N.{weight,bias}" ->
        # "backbone.posnet.N.norm.{weight,bias}"
        if new_key.startswith("backbone.posnet."):
            match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
            if match:
                new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"

        # The VQ codebook becomes the token embedding table.
        if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
            new_key = "backbone.embedding.weight"

        # scale/shift tables are collapsed to their first row and renamed to
        # plain norm weight/bias.
        if new_key.endswith("norm.scale.weight"):
            new_key = new_key.replace("norm.scale.weight", "norm.weight")
            value = value[0]

        if new_key.endswith("norm.shift.weight"):
            new_key = new_key.replace("norm.shift.weight", "norm.bias")
            value = value[0]

        if new_key.endswith("gamma"):
            new_key = new_key.replace("gamma", "gamma.weight")

        # 1D norm/bias params in the posnet get a trailing unit dim
        # (presumably to match a conv-style layout — verify against loader).
        if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
            value = value.unsqueeze(1)

        if new_key.endswith("dwconv.bias"):
            value = value.unsqueeze(1)

        size_mb = value.element_size() * value.nelement() / (1024 * 1024)
        print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")

        size_total_mb += size_mb
        items_new.append((new_key, value))

    print(f"Total size: {size_total_mb:8.2f} MB")

    return dict(items_new)
| |
|
flattened_state_dict = flatten_state_dict(state_dict)

# os.path.join handles an empty path_dst (checkpoint given as a bare
# filename); naive "path_dst + '/...'" concatenation would produce an
# absolute path like '/model.safetensors'.
output_path = os.path.join(path_dst, 'model.safetensors')
save_file(flattened_state_dict, output_path)

print(f"Model has been successfully converted and saved to {output_path}")

# Size of the safetensors file, recorded in the index metadata below.
total_size = os.path.getsize(output_path)

# Single-shard model: every tensor lives in model.safetensors.
weight_map = {
    "model.safetensors": ["*"]
}

metadata = {
    "total_size": total_size,
    "weight_map": weight_map
}

index_path = os.path.join(path_dst, 'index.json')
with open(index_path, 'w') as f:
    json.dump(metadata, f, indent=4)

print(f"Metadata has been saved to {index_path}")
| |
|
# Minimal HF-style config describing the exported WavTokenizerDec model.
config = {
    "architectures": [
        "WavTokenizerDec"
    ],
    "hidden_size": 1282,
    "n_embd_features": 512,
    "n_ff": 2304,
    "vocab_size": 4096,
    "n_head": 1,
    "layer_norm_epsilon": 1e-6,
    "group_norm_epsilon": 1e-6,
    "group_norm_groups": 32,
    "max_position_embeddings": 8192,
    "n_layer": 12,
    "posnet": {
        "n_embd": 768,
        "n_layer": 6
    },
    "convnext": {
        "n_embd": 768,
        "n_layer": 12
    },
}

# Build the path once so the written file and the log message cannot drift
# apart (the original message dropped the '/' separator), and so an empty
# path_dst does not yield an absolute '/config.json'.
config_path = os.path.join(path_dst, 'config.json')
with open(config_path, 'w') as f:
    json.dump(config, f, indent=4)

print(f"Config has been saved to {config_path}")
| |
|