| |
| |
| """ |
| Convert a FireRedLID .pth.tar checkpoint into a vLLM / HF-compatible model |
| directory containing: |
| |
| config.json |
| generation_config.json |
| preprocessor_config.json |
| tokenizer.json |
| tokenizer_config.json |
| special_tokens_map.json |
| model.safetensors |
| dict.txt (copied verbatim) |
| cmvn.ark (copied verbatim) |
| |
| Usage |
| ----- |
| python3 vllm-abo/tools/convert_fireredlid_checkpoint.py \ |
| --src pretrained_models/FireRedLID \ |
| --dst converted_models/FireRedLID-vllm |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import shutil |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from safetensors.torch import save_file |
|
|
|
|
| |
| |
| |
|
|
def read_kaldi_cmvn(kaldi_cmvn_file: str):
    """Read a Kaldi CMVN stats file and return ``(dim, means, inv_std)``.

    The loaded stats matrix has shape ``(2, dim + 1)``: row 0 holds the
    per-dimension feature sums with the frame count in the last column,
    row 1 holds the per-dimension sums of squares.

    Parameters
    ----------
    kaldi_cmvn_file:
        Path to the Kaldi CMVN archive (e.g. ``cmvn.ark``).

    Returns
    -------
    tuple[int, list[float], list[float]]
        Feature dimension, per-dimension means, and per-dimension inverse
        standard deviations (variance floored at 1e-20 so we never divide
        by zero on constant dimensions).

    Raises
    ------
    ImportError
        If the optional ``kaldiio`` dependency is not installed.
    """
    try:
        import kaldiio
    except ImportError as exc:
        # Chain the original exception so the underlying import failure
        # stays visible in the traceback.
        raise ImportError(
            "kaldiio is required to read cmvn.ark – "
            "pip install kaldiio"
        ) from exc
    stats = kaldiio.load_mat(kaldi_cmvn_file)
    assert stats.shape[0] == 2, f"Unexpected CMVN shape: {stats.shape}"
    dim = stats.shape[-1] - 1
    count = stats[0, dim]
    assert count >= 1
    floor = 1e-20
    # Vectorized with numpy instead of a per-dimension Python loop.
    means_arr = stats[0, :dim] / count
    variances = stats[1, :dim] / count - means_arr * means_arr
    variances = np.maximum(variances, floor)
    inv_std_arr = 1.0 / np.sqrt(variances)
    means = [float(m) for m in means_arr]
    inv_std = [float(s) for s in inv_std_arr]
    return dim, means, inv_std
|
|
|
|
| |
| |
| |
|
|
def build_tokenizer_assets(dict_path: str):
    """
    Read dict.txt and produce the three tokenizer JSON files:
    tokenizer.json, tokenizer_config.json, special_tokens_map.json
    """
    # Parse the vocabulary: "token index" per line; a bare token gets the
    # next free index; blank lines are skipped.
    vocab: dict[str, int] = {}
    with open(dict_path, encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if not fields:
                continue
            token = fields[0]
            index = int(fields[1]) if len(fields) >= 2 else len(vocab)
            vocab[token] = index

    def _token_entry(token: str) -> dict:
        # Shared per-token metadata; tokens written as "<...>" are marked
        # special, everything else is a regular word-level token.
        return {
            "content": token,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": token.startswith("<") and token.endswith(">"),
        }

    # Tokens ordered by their integer id, used by both JSON payloads below.
    by_index = sorted(vocab.items(), key=lambda item: item[1])

    tokenizer_json = {
        "version": "1.0",
        "truncation": None,
        "padding": None,
        "added_tokens": [
            {"id": index, **_token_entry(token)} for token, index in by_index
        ],
        "normalizer": None,
        "pre_tokenizer": {
            "type": "WhitespaceSplit",
        },
        "post_processor": None,
        "decoder": None,
        "model": {
            "type": "WordLevel",
            "vocab": vocab,
            "unk_token": "<unk>",
        },
    }

    tokenizer_config = {
        "tokenizer_class": "PreTrainedTokenizerFast",
        "model_type": "fireredlid",
        "bos_token": "<sos>",
        "eos_token": "<eos>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "added_tokens_decoder": {
            str(index): _token_entry(token) for token, index in by_index
        },
    }

    special_tokens_map = {
        "bos_token": "<sos>",
        "eos_token": "<eos>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
    }

    return tokenizer_json, tokenizer_config, special_tokens_map, vocab
|
|
|
|
| |
| |
| |
|
|
def build_config(args_obj, vocab_size: int) -> dict:
    """Build config.json from the checkpoint's args namespace."""
    config = {
        "architectures": ["FireRedLIDForConditionalGeneration"],
        "model_type": "fireredlid",
        "is_encoder_decoder": True,
        "vocab_size": vocab_size,
    }
    # (config key, checkpoint-args attribute, default) triples; each value
    # is taken from the checkpoint when present, else the default.
    field_specs = (
        ("lid_odim", "lid_odim", vocab_size),
        ("idim", "idim", 80),
        ("d_model", "d_model", 1280),
        ("n_head", "n_head", 20),
        ("n_layers_enc", "n_layers_enc", 16),
        ("n_layers_lid_dec", "n_layers_lid_dec", 6),
        ("kernel_size", "kernel_size", 33),
        ("residual_dropout", "residual_dropout", 0.05),
        ("dropout_rate", "dropout_rate", 0.05),
        ("pe_maxlen", "pe_maxlen", 5000),
        ("pad_token_id", "pad_id", 2),
        ("bos_token_id", "sos_id", 3),
        ("eos_token_id", "eos_id", 4),
        ("decoder_start_token_id", "sos_id", 3),
    )
    for config_key, attr_name, default in field_specs:
        config[config_key] = getattr(args_obj, attr_name, default)
    config["tie_word_embeddings"] = True
    return config
|
|
|
|
| |
| |
| |
|
|
def build_generation_config(config: dict) -> dict:
    """Derive generation_config.json from the already-built model config.

    Token ids are copied from *config*; ``max_new_tokens`` and
    ``temperature`` are fixed decoding defaults used by this converter.
    """
    gen_config: dict = {"_from_model_config": True}
    for token_key in (
        "bos_token_id",
        "eos_token_id",
        "pad_token_id",
        "decoder_start_token_id",
    ):
        gen_config[token_key] = config[token_key]
    gen_config["max_new_tokens"] = 2
    gen_config["temperature"] = 1.25
    return gen_config
|
|
|
|
| |
| |
| |
|
|
def build_preprocessor_config(
    cmvn_dim: int,
    means: list[float],
    inv_std: list[float],
) -> dict:
    """Assemble preprocessor_config.json from the CMVN statistics.

    NOTE(review): the key "inverse_std_variences" is misspelled but is
    presumably read back verbatim by the downstream processor — it must
    stay byte-identical; do not "fix" the spelling here.
    """
    # Filterbank / framing options for the feature extractor.
    fbank_options = {
        "feature_size": cmvn_dim,
        "sampling_rate": 16000,
        "num_mel_bins": cmvn_dim,
        "frame_length": 25,
        "frame_shift": 10,
        "dither": 0.0,
        "left_context": 3,
        "right_context": 3,
        "chunk_length": 30,
    }
    # Per-dimension normalization statistics.
    cmvn_options = {
        "dim": cmvn_dim,
        "means": means,
        "inverse_std_variences": inv_std,
    }
    return {
        "feature_extractor_type": "FireRedLIDFeatureExtractor",
        "processor_class": "FireRedLIDProcessor",
        **fbank_options,
        **cmvn_options,
    }
|
|
|
|
| |
| |
| |
|
|
def _write_json(path: Path, payload: dict) -> None:
    """Serialize *payload* to *path* as pretty-printed UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def _split_state_dict(
    state_dict: dict[str, torch.Tensor],
) -> tuple[dict[str, torch.Tensor], list[str], int, int]:
    """Keep only encoder / LID-decoder weights from the checkpoint.

    Returns ``(kept, skipped_keys, encoder_count, decoder_count)``.
    """
    kept: dict[str, torch.Tensor] = {}
    skipped: list[str] = []
    encoder_count = 0
    decoder_count = 0
    for key, tensor in state_dict.items():
        if key.startswith("encoder."):
            kept[key] = tensor
            encoder_count += 1
        elif key.startswith("lid_decoder."):
            kept[key] = tensor
            decoder_count += 1
        else:
            skipped.append(key)
    return kept, skipped, encoder_count, decoder_count


def main():
    """CLI entry point: convert --src checkpoint dir into --dst HF layout."""
    parser = argparse.ArgumentParser(
        description="Convert FireRedLID .pth.tar to vLLM-compatible format"
    )
    parser.add_argument("--src", required=True, help="Source directory (pretrained_models/FireRedLID)")
    parser.add_argument("--dst", required=True, help="Destination directory")
    args = parser.parse_args()

    src = Path(args.src)
    dst = Path(args.dst)
    dst.mkdir(parents=True, exist_ok=True)

    # --- load checkpoint -------------------------------------------------
    ckpt_path = src / "model.pth.tar"
    print(f"Loading checkpoint: {ckpt_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects (the "args"
    # namespace is stored in the checkpoint); only run on trusted files.
    package = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    ckpt_args = package["args"]
    state_dict: dict[str, torch.Tensor] = package["model_state_dict"]

    num_keys_before = len(state_dict)
    print(f"  Total keys in checkpoint: {num_keys_before}")

    # --- keep only encoder + LID-decoder weights -------------------------
    filtered, skipped_keys, encoder_count, decoder_count = \
        _split_state_dict(state_dict)

    print(f"  Encoder keys: {encoder_count}")
    print(f"  Decoder keys: {decoder_count}")
    print(f"  Skipped keys: {len(skipped_keys)}")
    if skipped_keys:
        for k in skipped_keys[:10]:
            print(f"    skipped: {k}")
        if len(skipped_keys) > 10:
            print(f"    ... and {len(skipped_keys) - 10} more")

    # --- weights ---------------------------------------------------------
    safetensors_path = dst / "model.safetensors"
    print(f"Saving {safetensors_path} ({len(filtered)} keys) ...")
    save_file(filtered, str(safetensors_path))

    # --- tokenizer assets ------------------------------------------------
    dict_path = src / "dict.txt"
    tokenizer_json, tokenizer_config, special_tokens_map, vocab = \
        build_tokenizer_assets(str(dict_path))
    vocab_size = len(vocab)
    print(f"  Vocab size: {vocab_size}")

    _write_json(dst / "tokenizer.json", tokenizer_json)
    _write_json(dst / "tokenizer_config.json", tokenizer_config)
    _write_json(dst / "special_tokens_map.json", special_tokens_map)

    # --- model / generation configs --------------------------------------
    config = build_config(ckpt_args, vocab_size)
    _write_json(dst / "config.json", config)
    _write_json(dst / "generation_config.json", build_generation_config(config))

    # --- feature-extractor config derived from CMVN stats ----------------
    cmvn_path = src / "cmvn.ark"
    cmvn_dim, means, inv_std = read_kaldi_cmvn(str(cmvn_path))
    _write_json(
        dst / "preprocessor_config.json",
        build_preprocessor_config(cmvn_dim, means, inv_std),
    )

    # --- copy raw assets needed verbatim at inference time ---------------
    shutil.copy2(str(dict_path), str(dst / "dict.txt"))
    shutil.copy2(str(cmvn_path), str(dst / "cmvn.ark"))

    print("\n=== Conversion Summary ===")
    print(f"  num_state_keys_before : {num_keys_before}")
    print(f"  num_state_keys_after  : {len(filtered)}")
    print(f"  num_skipped_keys      : {len(skipped_keys)}")
    print(f"  vocab_size            : {vocab_size}")
    print(f"  encoder_key_count     : {encoder_count}")
    print(f"  decoder_key_count     : {decoder_count}")
    print(f"  output_dir            : {dst}")
    print("Done.")


if __name__ == "__main__":
    main()
|
|