| |
| """Convert IndicF5 HF safetensors into a training-ready F5-TTS EMA checkpoint. |
| |
| This script adapts ai4bharat/IndicF5 weights for Sinhala fine-tuning with |
| custom vocab size by: |
| 1) stripping torch.compile `_orig_mod` key prefixes, |
| 2) dropping embedded vocoder parameters, |
| 3) dropping mismatched text embedding weights, and |
| 4) materializing a complete EMA state dict for strict trainer loading. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from ema_pytorch import EMA |
| from f5_tts.infer.utils_infer import get_tokenizer |
| from f5_tts.model import CFM, DiT |
| from safetensors.torch import load_file |
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the conversion script.

    Three optional string flags are registered, each with a default:
    ``--input`` (source safetensors file), ``--output`` (destination
    checkpoint file) and ``--vocab`` (tokenizer vocab file).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description="Convert IndicF5 checkpoint for Sinhala fine-tuning")

    # (flag, default value, help text) for every supported option.
    option_specs = (
        (
            "--input",
            "pretrained_models/model.safetensors",
            "Path to downloaded IndicF5 model.safetensors",
        ),
        (
            "--output",
            "pretrained_models/indicf5_for_sinhala.pt",
            "Output path for converted EMA checkpoint",
        ),
        (
            "--vocab",
            "data/sinhala_vocab/vocab.txt",
            "Tokenizer vocab path used by training",
        ),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)

    return cli
|
|
|
|
def main() -> int:
    """Convert an IndicF5 safetensors checkpoint into an EMA checkpoint file.

    Paths come from the command line (see ``build_parser``). The source
    tensors are filtered and renamed, loaded into an ``EMA``-wrapped
    ``CFM``/``DiT`` model whose text embedding size is taken from the vocab
    file, and the resulting EMA state dict is saved under the
    ``"ema_model_state_dict"`` key.

    Returns:
        0 on success.

    Raises:
        FileNotFoundError: If the input checkpoint or the vocab file does not
            exist.
    """
    args = build_parser().parse_args()

    in_path = Path(args.input)
    out_path = Path(args.output)
    vocab_path = Path(args.vocab)

    # Fail early with a clear message before any heavy loading happens.
    if not in_path.exists():
        raise FileNotFoundError(f"Input checkpoint not found: {in_path}")
    if not vocab_path.exists():
        raise FileNotFoundError(f"Vocab not found: {vocab_path}")

    print(f"[1/5] Loading safetensors from {in_path}")
    src_state = load_file(str(in_path), device="cpu")

    print("[2/5] Rewriting keys and filtering incompatible tensors")
    converted = {}
    dropped_vocoder = 0
    dropped_text_embed = 0

    for key, value in src_state.items():
        # Vocoder parameters are not part of the model built below.
        if key.startswith("vocoder."):
            dropped_vocoder += 1
            continue

        # Remove the torch.compile wrapper segment from the key path.
        new_key = key.replace("ema_model._orig_mod.", "ema_model.")

        # The text embedding is sized by the source vocab, so it is not carried
        # over; the model built below uses the vocab size from --vocab instead.
        if new_key == "ema_model.transformer.text_embed.text_embed.weight":
            dropped_text_embed += 1
            continue

        converted[new_key] = value

    print(f" kept tensors: {len(converted)}")
    print(f" dropped vocoder tensors: {dropped_vocoder}")
    print(f" dropped text_embed tensors: {dropped_text_embed}")

    print("[3/5] Building Sinhala-sized F5 model + EMA container")
    _, vocab_size = get_tokenizer(str(vocab_path), "custom")

    model = CFM(
        transformer=DiT(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
            text_num_embeds=vocab_size,
            mel_dim=100,
        ),
    )
    ema = EMA(model, include_online_model=False)

    print("[4/5] Loading converted weights into EMA with strict=False")
    # strict=False is required because the text embedding was dropped above.
    # It also hides every other key mismatch, so report what was not matched
    # instead of discarding the result.
    load_result = ema.load_state_dict(converted, strict=False)

    preview_limit = 20  # keep the console readable if many keys mismatch
    key_reports = (
        ("missing keys (kept at initial values)", list(load_result.missing_keys)),
        ("unexpected keys (ignored)", list(load_result.unexpected_keys)),
    )
    for label, names in key_reports:
        print(f" {label}: {len(names)}")
        for name in names[:preview_limit]:
            print(f"   - {name}")
        if len(names) > preview_limit:
            print(f"   ... and {len(names) - preview_limit} more")

    print(f"[5/5] Saving training-ready checkpoint to {out_path}")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"ema_model_state_dict": ema.state_dict()}, out_path)

    print("[OK] Conversion complete")
    print(f"Output: {out_path}")
    return 0
|
|
|
|
# Script entry point: run the conversion and use its return value as the
# process exit status.
if __name__ == "__main__":
    exit_code = main()
    raise SystemExit(exit_code)
|
|