#!/usr/bin/env python3 """Convert IndicF5 HF safetensors into a training-ready F5-TTS EMA checkpoint. This script adapts ai4bharat/IndicF5 weights for Sinhala fine-tuning with custom vocab size by: 1) stripping torch.compile `_orig_mod` key prefixes, 2) dropping embedded vocoder parameters, 3) dropping mismatched text embedding weights, and 4) materializing a complete EMA state dict for strict trainer loading. """ from __future__ import annotations import argparse from pathlib import Path import torch from ema_pytorch import EMA from f5_tts.infer.utils_infer import get_tokenizer from f5_tts.model import CFM, DiT from safetensors.torch import load_file def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Convert IndicF5 checkpoint for Sinhala fine-tuning") parser.add_argument( "--input", default="pretrained_models/model.safetensors", help="Path to downloaded IndicF5 model.safetensors", ) parser.add_argument( "--output", default="pretrained_models/indicf5_for_sinhala.pt", help="Output path for converted EMA checkpoint", ) parser.add_argument( "--vocab", default="data/sinhala_vocab/vocab.txt", help="Tokenizer vocab path used by training", ) return parser def main() -> int: args = build_parser().parse_args() in_path = Path(args.input) out_path = Path(args.output) vocab_path = Path(args.vocab) if not in_path.exists(): raise FileNotFoundError(f"Input checkpoint not found: {in_path}") if not vocab_path.exists(): raise FileNotFoundError(f"Vocab not found: {vocab_path}") print(f"[1/5] Loading safetensors from {in_path}") src_state = load_file(str(in_path), device="cpu") print("[2/5] Rewriting keys and filtering incompatible tensors") converted = {} dropped_vocoder = 0 dropped_text_embed = 0 for key, value in src_state.items(): if key.startswith("vocoder."): dropped_vocoder += 1 continue new_key = key.replace("ema_model._orig_mod.", "ema_model.") if new_key == "ema_model.transformer.text_embed.text_embed.weight": dropped_text_embed += 1 continue converted[new_key] = value print(f" kept tensors: {len(converted)}") print(f" dropped vocoder tensors: {dropped_vocoder}") print(f" dropped text_embed tensors: {dropped_text_embed}") print("[3/5] Building Sinhala-sized F5 model + EMA container") _, vocab_size = get_tokenizer(str(vocab_path), "custom") model = CFM( transformer=DiT( dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100, ), ) ema = EMA(model, include_online_model=False) print("[4/5] Loading converted weights into EMA with strict=False") ema.load_state_dict(converted, strict=False) print(f"[5/5] Saving training-ready checkpoint to {out_path}") out_path.parent.mkdir(parents=True, exist_ok=True) torch.save({"ema_model_state_dict": ema.state_dict()}, out_path) print("[OK] Conversion complete") print(f"Output: {out_path}") return 0 if __name__ == "__main__": raise SystemExit(main())