|
|
|
|
|
|
|
|
|
|
| import torch
|
| import json
|
| import os
|
| import sys
|
| import re
|
|
|
| from safetensors.torch import save_file
|
|
|
|
|
| model_path = './model.pt'
|
|
|
|
|
| if len(sys.argv) > 1:
|
| model_path = sys.argv[1]
|
|
|
|
|
| path_dst = os.path.dirname(model_path)
|
|
|
| print(f"Loading model from {model_path}")
|
|
|
| model = torch.load(model_path, map_location='cpu')
|
|
|
|
|
|
|
|
|
| for key in model.keys():
|
| print(key)
|
| if key == 'hyper_parameters':
|
|
|
|
|
| print(json.dumps(model[key], indent=4))
|
|
|
|
|
|
|
|
|
| if isinstance(model, torch.nn.Module):
|
| state_dict = model.state_dict()
|
| else:
|
| state_dict = model
|
|
|
|
|
| print("State dictionary keys:")
|
| for key in state_dict.keys():
|
| print(key)
|
|
|
|
|
| def flatten_state_dict(state_dict, parent_key='', sep='.'):
|
| items = []
|
| items_new = []
|
|
|
| for k, v in state_dict.items():
|
| new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
| if isinstance(v, torch.Tensor):
|
| items.append((new_key, v))
|
| elif isinstance(v, dict):
|
| items.extend(flatten_state_dict(v, new_key, sep=sep).items())
|
| return dict(items)
|
|
|
| size_total_mb = 0
|
|
|
| for key, value in list(items):
|
|
|
| if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
|
| not key.startswith('state_dict.backbone.') and \
|
| not key.startswith('state_dict.head.out'):
|
| print('Skipping key: ', key)
|
| continue
|
|
|
| new_key = key
|
|
|
| new_key = new_key.replace('state_dict.', '')
|
| new_key = new_key.replace('pos_net', 'posnet')
|
|
|
|
|
| if new_key.startswith("backbone.posnet."):
|
| match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
|
| if match:
|
| new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"
|
|
|
|
|
| if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
|
| new_key = "backbone.embedding.weight"
|
|
|
|
|
|
|
| if new_key.endswith("norm.scale.weight"):
|
| new_key = new_key.replace("norm.scale.weight", "norm.weight")
|
| value = value[0]
|
|
|
| if new_key.endswith("norm.shift.weight"):
|
| new_key = new_key.replace("norm.shift.weight", "norm.bias")
|
| value = value[0]
|
|
|
| if new_key.endswith("gamma"):
|
| new_key = new_key.replace("gamma", "gamma.weight")
|
|
|
|
|
| if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
|
| value = value.unsqueeze(1)
|
|
|
| if new_key.endswith("dwconv.bias"):
|
| value = value.unsqueeze(1)
|
|
|
| size_mb = value.element_size() * value.nelement() / (1024 * 1024)
|
| print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
|
|
|
| size_total_mb += size_mb
|
|
|
|
|
|
|
|
|
| items_new.append((new_key, value))
|
|
|
| print(f"Total size: {size_total_mb:8.2f} MB")
|
|
|
| return dict(items_new)
|
|
|
| flattened_state_dict = flatten_state_dict(state_dict)
|
|
|
|
|
|
|
| output_path = path_dst + '/model.safetensors'
|
| save_file(flattened_state_dict, output_path)
|
|
|
| print(f"Model has been successfully converted and saved to {output_path}")
|
|
|
|
|
| total_size = os.path.getsize(output_path)
|
|
|
|
|
| weight_map = {
|
| "model.safetensors": ["*"]
|
| }
|
|
|
|
|
| metadata = {
|
| "total_size": total_size,
|
| "weight_map": weight_map
|
| }
|
|
|
|
|
| index_path = path_dst + '/index.json'
|
| with open(index_path, 'w') as f:
|
| json.dump(metadata, f, indent=4)
|
|
|
| print(f"Metadata has been saved to {index_path}")
|
|
|
| config = {
|
| "architectures": [
|
| "WavTokenizerDec"
|
| ],
|
| "hidden_size": 1282,
|
| "n_embd_features": 512,
|
| "n_ff": 2304,
|
| "vocab_size": 4096,
|
| "n_head": 1,
|
| "layer_norm_epsilon": 1e-6,
|
| "group_norm_epsilon": 1e-6,
|
| "group_norm_groups": 32,
|
| "max_position_embeddings": 8192,
|
| "n_layer": 12,
|
| "posnet": {
|
| "n_embd": 768,
|
| "n_layer": 6
|
| },
|
| "convnext": {
|
| "n_embd": 768,
|
| "n_layer": 12
|
| },
|
| }
|
|
|
| with open(path_dst + '/config.json', 'w') as f:
|
| json.dump(config, f, indent=4)
|
|
|
| print(f"Config has been saved to {path_dst + 'config.json'}")
|
|
|