# FireRedLID-vllm / convert_fireredlid_checkpoint.py
# (Hugging Face upload-page residue, kept for provenance:
#  uploaded by PatchyTisa via upload-large-folder tool, commit 1d67dac, verified)
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Convert a FireRedLID .pth.tar checkpoint into a vLLM / HF-compatible model
directory containing:
config.json
generation_config.json
preprocessor_config.json
tokenizer.json
tokenizer_config.json
special_tokens_map.json
model.safetensors
dict.txt (copied verbatim)
cmvn.ark (copied verbatim)
Usage
-----
python3 vllm-abo/tools/convert_fireredlid_checkpoint.py \
--src pretrained_models/FireRedLID \
--dst converted_models/FireRedLID-vllm
"""
from __future__ import annotations
import argparse
import json
import math
import os
import shutil
from pathlib import Path
import numpy as np
import torch
from safetensors.torch import save_file
# ---------------------------------------------------------------------------
# CMVN parsing (replicates the logic from fireredlid/data/feat.py)
# ---------------------------------------------------------------------------
def read_kaldi_cmvn(kaldi_cmvn_file: str):
    """Read a Kaldi CMVN stats file and return ``(dim, means, inv_std)``.

    The stats matrix has shape (2, dim + 1): row 0 holds per-dim sums with
    the frame count in the last column, row 1 holds per-dim sums of squares.

    Parameters
    ----------
    kaldi_cmvn_file:
        Path to the Kaldi CMVN ark file (read via ``kaldiio.load_mat``).

    Returns
    -------
    tuple[int, list[float], list[float]]
        Feature dimension, per-dim means, per-dim inverse standard
        deviations.

    Raises
    ------
    ImportError
        If ``kaldiio`` is not installed.
    ValueError
        If the stats matrix has an unexpected shape or a frame count < 1.
    """
    try:
        import kaldiio
    except ImportError as err:
        # Chain the cause so the original import failure stays visible.
        raise ImportError(
            "kaldiio is required to read cmvn.ark – "
            "pip install kaldiio"
        ) from err
    stats = kaldiio.load_mat(kaldi_cmvn_file)
    # Explicit raises instead of assert: asserts vanish under `python -O`.
    if stats.shape[0] != 2:
        raise ValueError(f"Unexpected CMVN shape: {stats.shape}")
    dim = stats.shape[-1] - 1
    count = stats[0, dim]
    if count < 1:
        raise ValueError(f"CMVN frame count must be >= 1, got {count}")
    floor = 1e-20  # variance floor guards against div-by-zero / negatives
    means: list[float] = []
    inv_std: list[float] = []
    for d in range(dim):
        mean = stats[0, d] / count
        means.append(float(mean))
        variance = (stats[1, d] / count) - mean * mean
        if variance < floor:
            variance = floor
        inv_std.append(1.0 / math.sqrt(variance))
    return dim, means, inv_std
# ---------------------------------------------------------------------------
# dict.txt → tokenizer assets
# ---------------------------------------------------------------------------
def build_tokenizer_assets(dict_path: str):
    """
    Parse dict.txt and build the three tokenizer JSON payloads.

    Each line of dict.txt is either ``<token> <id>`` or a bare ``<token>``
    (which gets the next sequential id); blank lines are skipped.

    Returns a 4-tuple:
        (tokenizer.json dict, tokenizer_config.json dict,
         special_tokens_map.json dict, token -> id vocab mapping)
    """
    vocab: dict[str, int] = {}
    with open(dict_path, encoding="utf-8") as fh:
        for raw in fh:
            fields = raw.strip().split()
            if not fields:
                continue
            if len(fields) == 1:
                vocab[fields[0]] = len(vocab)
            else:
                vocab[fields[0]] = int(fields[1])

    # Tokens ordered by id — shared by both JSON payloads below.
    by_id = sorted(vocab.items(), key=lambda item: item[1])

    def _flags(tok: str) -> dict:
        """Common per-token attributes; <...> tokens are marked special."""
        return {
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": tok.startswith("<") and tok.endswith(">"),
        }

    # --- tokenizer.json (WordLevel) ---
    tokenizer_json = {
        "version": "1.0",
        "truncation": None,
        "padding": None,
        "added_tokens": [
            {"id": tid, "content": tok, **_flags(tok)} for tok, tid in by_id
        ],
        "normalizer": None,
        "pre_tokenizer": {
            "type": "WhitespaceSplit",
        },
        "post_processor": None,
        "decoder": None,
        "model": {
            "type": "WordLevel",
            "vocab": vocab,
            "unk_token": "<unk>",
        },
    }

    # --- tokenizer_config.json ---
    tokenizer_config = {
        "tokenizer_class": "PreTrainedTokenizerFast",
        "model_type": "fireredlid",
        "bos_token": "<sos>",
        "eos_token": "<eos>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "added_tokens_decoder": {
            str(tid): {"content": tok, **_flags(tok)} for tok, tid in by_id
        },
    }

    # --- special_tokens_map.json ---
    special_tokens_map = {
        "bos_token": "<sos>",
        "eos_token": "<eos>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
    }
    return tokenizer_json, tokenizer_config, special_tokens_map, vocab
# ---------------------------------------------------------------------------
# config.json
# ---------------------------------------------------------------------------
def build_config(args_obj, vocab_size: int) -> dict:
    """Assemble the HF-style config.json dict from the checkpoint args.

    Every model hyper-parameter is read off the checkpoint's args
    namespace; the literal fallbacks are the FireRedLID training defaults
    used when the attribute is absent from an (older) checkpoint.
    """
    def _arg(name: str, fallback):
        # Missing attribute on the args namespace -> training default.
        return getattr(args_obj, name, fallback)

    return {
        "architectures": ["FireRedLIDForConditionalGeneration"],
        "model_type": "fireredlid",
        "is_encoder_decoder": True,
        "vocab_size": vocab_size,
        "lid_odim": _arg("lid_odim", vocab_size),
        "idim": _arg("idim", 80),
        "d_model": _arg("d_model", 1280),
        "n_head": _arg("n_head", 20),
        "n_layers_enc": _arg("n_layers_enc", 16),
        "n_layers_lid_dec": _arg("n_layers_lid_dec", 6),
        "kernel_size": _arg("kernel_size", 33),
        "residual_dropout": _arg("residual_dropout", 0.05),
        "dropout_rate": _arg("dropout_rate", 0.05),
        "pe_maxlen": _arg("pe_maxlen", 5000),
        "pad_token_id": _arg("pad_id", 2),
        "bos_token_id": _arg("sos_id", 3),
        "eos_token_id": _arg("eos_id", 4),
        "decoder_start_token_id": _arg("sos_id", 3),
        "tie_word_embeddings": True,
    }
# ---------------------------------------------------------------------------
# generation_config.json
# ---------------------------------------------------------------------------
def build_generation_config(config: dict) -> dict:
    """Derive generation_config.json from an already-built config dict.

    Copies the four token-id fields verbatim and pins decoding settings:
    max_new_tokens=2 (one language token + EOS) and temperature=1.25.
    """
    gen: dict = {"_from_model_config": True}
    for key in ("bos_token_id", "eos_token_id",
                "pad_token_id", "decoder_start_token_id"):
        gen[key] = config[key]
    gen["max_new_tokens"] = 2
    gen["temperature"] = 1.25
    return gen
# ---------------------------------------------------------------------------
# preprocessor_config.json
# ---------------------------------------------------------------------------
def build_preprocessor_config(
    cmvn_dim: int,
    means: list[float],
    inv_std: list[float],
) -> dict:
    """Build preprocessor_config.json for the feature extractor.

    Embeds the CMVN normalisation statistics (means / inverse std) plus
    fixed fbank extraction settings (16 kHz, 25 ms frames, 10 ms shift).
    """
    cfg: dict = {
        "feature_extractor_type": "FireRedLIDFeatureExtractor",
        "processor_class": "FireRedLIDProcessor",
        "feature_size": cmvn_dim,
        "sampling_rate": 16000,
        "num_mel_bins": cmvn_dim,
        "frame_length": 25,
        "frame_shift": 10,
        "dither": 0.0,
        "left_context": 3,
        "right_context": 3,
        "chunk_length": 30,
        "dim": cmvn_dim,
        "means": means,
    }
    # NOTE: the misspelled key "inverse_std_variences" is kept as-is —
    # downstream consumers read this exact key.
    cfg["inverse_std_variences"] = inv_std
    return cfg
# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: convert a FireRedLID checkpoint dir to HF/vLLM layout.

    Reads ``model.pth.tar``, ``dict.txt`` and ``cmvn.ark`` from ``--src``
    and writes ``model.safetensors`` plus the JSON config/tokenizer assets
    (and verbatim copies of dict.txt / cmvn.ark) into ``--dst``.
    """
    parser = argparse.ArgumentParser(
        description="Convert FireRedLID .pth.tar to vLLM-compatible format"
    )
    parser.add_argument("--src", required=True, help="Source directory (pretrained_models/FireRedLID)")
    parser.add_argument("--dst", required=True, help="Destination directory")
    args = parser.parse_args()
    src = Path(args.src)
    dst = Path(args.dst)
    dst.mkdir(parents=True, exist_ok=True)
    # ---- 1. Load checkpoint ----
    # NOTE(review): weights_only=False unpickles arbitrary objects (the
    # checkpoint stores its training args namespace) — only run this on
    # trusted checkpoint files.
    ckpt_path = src / "model.pth.tar"
    print(f"Loading checkpoint: {ckpt_path}")
    package = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    ckpt_args = package["args"]
    state_dict: dict[str, torch.Tensor] = package["model_state_dict"]
    num_keys_before = len(state_dict)
    print(f" Total keys in checkpoint: {num_keys_before}")
    # ---- 2. Filter: keep only encoder.* and lid_decoder.* keys ----
    # Only the encoder and LID decoder weights are exported; every other
    # key is collected into skipped_keys and reported below.
    filtered: dict[str, torch.Tensor] = {}
    skipped_keys: list[str] = []
    encoder_count = 0
    decoder_count = 0
    for k, v in state_dict.items():
        if k.startswith("encoder.") or k.startswith("lid_decoder."):
            filtered[k] = v
            if k.startswith("encoder."):
                encoder_count += 1
            else:
                decoder_count += 1
        else:
            skipped_keys.append(k)
    print(f" Encoder keys: {encoder_count}")
    print(f" Decoder keys: {decoder_count}")
    print(f" Skipped keys: {len(skipped_keys)}")
    if skipped_keys:
        # Show at most the first 10 skipped key names to keep output short.
        for k in skipped_keys[:10]:
            print(f" skipped: {k}")
        if len(skipped_keys) > 10:
            print(f" ... and {len(skipped_keys) - 10} more")
    # ---- 3. Save model.safetensors ----
    safetensors_path = dst / "model.safetensors"
    print(f"Saving {safetensors_path} ({len(filtered)} keys) ...")
    save_file(filtered, str(safetensors_path))
    # ---- 4. Read dict.txt & build tokenizer assets ----
    dict_path = src / "dict.txt"
    tokenizer_json, tokenizer_config, special_tokens_map, vocab = \
        build_tokenizer_assets(str(dict_path))
    vocab_size = len(vocab)
    print(f" Vocab size: {vocab_size}")
    with open(dst / "tokenizer.json", "w", encoding="utf-8") as f:
        json.dump(tokenizer_json, f, indent=2, ensure_ascii=False)
    with open(dst / "tokenizer_config.json", "w", encoding="utf-8") as f:
        json.dump(tokenizer_config, f, indent=2, ensure_ascii=False)
    with open(dst / "special_tokens_map.json", "w", encoding="utf-8") as f:
        json.dump(special_tokens_map, f, indent=2, ensure_ascii=False)
    # ---- 5. Build config.json ----
    config = build_config(ckpt_args, vocab_size)
    with open(dst / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    # ---- 6. Build generation_config.json ----
    gen_config = build_generation_config(config)
    with open(dst / "generation_config.json", "w", encoding="utf-8") as f:
        json.dump(gen_config, f, indent=2, ensure_ascii=False)
    # ---- 7. Read CMVN & build preprocessor_config.json ----
    cmvn_path = src / "cmvn.ark"
    cmvn_dim, means, inv_std = read_kaldi_cmvn(str(cmvn_path))
    preprocessor_config = build_preprocessor_config(cmvn_dim, means, inv_std)
    with open(dst / "preprocessor_config.json", "w", encoding="utf-8") as f:
        json.dump(preprocessor_config, f, indent=2, ensure_ascii=False)
    # ---- 8. Copy reference files ----
    # Copied verbatim (copy2 preserves metadata) so the converted dir is
    # self-contained.
    shutil.copy2(str(dict_path), str(dst / "dict.txt"))
    shutil.copy2(str(cmvn_path), str(dst / "cmvn.ark"))
    # ---- 9. Summary ----
    print("\n=== Conversion Summary ===")
    print(f" num_state_keys_before : {num_keys_before}")
    print(f" num_state_keys_after : {len(filtered)}")
    print(f" num_skipped_keys : {len(skipped_keys)}")
    print(f" vocab_size : {vocab_size}")
    print(f" encoder_key_count : {encoder_count}")
    print(f" decoder_key_count : {decoder_count}")
    print(f" output_dir : {dst}")
    print("Done.")
if __name__ == "__main__":
    main()