sinhala-tts / scripts /convert_indicf5_checkpoint.py
outlawmold's picture
Add MPS training stability fixes and experiment logs
19655a1
#!/usr/bin/env python3
"""Convert IndicF5 HF safetensors into a training-ready F5-TTS EMA checkpoint.
This script adapts ai4bharat/IndicF5 weights for Sinhala fine-tuning with a
custom vocabulary size by:
1) stripping torch.compile `_orig_mod` key prefixes,
2) dropping embedded vocoder parameters,
3) dropping mismatched text embedding weights, and
4) materializing a complete EMA state dict for strict trainer loading.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from ema_pytorch import EMA
from f5_tts.infer.utils_infer import get_tokenizer
from f5_tts.model import CFM, DiT
from safetensors.torch import load_file
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the IndicF5 -> Sinhala converter.

    Every option is optional and falls back to the repository's default
    layout (``pretrained_models/`` for checkpoints, ``data/sinhala_vocab/``
    for the tokenizer vocabulary).
    """
    cli = argparse.ArgumentParser(
        description="Convert IndicF5 checkpoint for Sinhala fine-tuning"
    )
    # (flag, default value, help text) for each string-valued option.
    option_specs = (
        (
            "--input",
            "pretrained_models/model.safetensors",
            "Path to downloaded IndicF5 model.safetensors",
        ),
        (
            "--output",
            "pretrained_models/indicf5_for_sinhala.pt",
            "Output path for converted EMA checkpoint",
        ),
        (
            "--vocab",
            "data/sinhala_vocab/vocab.txt",
            "Tokenizer vocab path used by training",
        ),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)
    return cli
# State-dict key (after `_orig_mod` stripping) of the text embedding table.
# Its row count depends on the tokenizer vocabulary, so it is re-initialized
# for the new vocab instead of being copied from IndicF5.
_TEXT_EMBED_KEY = "ema_model.transformer.text_embed.text_embed.weight"


def _convert_state_dict(
    src_state: dict[str, torch.Tensor],
) -> tuple[dict[str, torch.Tensor], int, int]:
    """Rewrite IndicF5 tensor names and drop tensors that cannot be reused.

    Args:
        src_state: Tensors loaded from the IndicF5 safetensors file.

    Returns:
        ``(converted, dropped_vocoder, dropped_text_embed)``. ``converted``
        maps EMA state-dict keys to reusable tensors. The two counters report
        how many source tensors were skipped in each category.
    """
    converted: dict[str, torch.Tensor] = {}
    dropped_vocoder = 0
    dropped_text_embed = 0
    for key, value in src_state.items():
        # The EMA container built in main() wraps only a CFM/DiT model, so
        # the bundled vocoder tensors have no destination.
        if key.startswith("vocoder."):
            dropped_vocoder += 1
            continue
        # torch.compile stores wrapped modules under `_orig_mod`; strip it so
        # the names match an uncompiled model.
        new_key = key.replace("ema_model._orig_mod.", "ema_model.")
        if new_key == _TEXT_EMBED_KEY:
            dropped_text_embed += 1
            continue
        converted[new_key] = value
    return converted, dropped_vocoder, dropped_text_embed


def _check_load_result(
    missing_keys: list[str],
    unexpected_keys: list[str],
    param_names: set[str],
) -> None:
    """Fail if the non-strict load left a reusable model parameter untouched.

    ``strict=False`` is required because the text embedding is dropped on
    purpose, and the EMA wrapper's own buffers may be absent from the source
    file. On its own, though, it would also accept a checkpoint whose key
    prefix did not match at all. That would yield a randomly initialized
    "pretrained" model without any error.

    Args:
        missing_keys: ``missing_keys`` from ``load_state_dict``.
        unexpected_keys: ``unexpected_keys`` from ``load_state_dict``.
        param_names: Names of the target module's parameters.

    Raises:
        RuntimeError: If any parameter other than the text embedding received
            no weights from the source checkpoint.
    """
    # Only learnable parameters are fatal. Buffers (e.g. ema_pytorch's
    # `initted`/`step`) are fine to leave at their constructor values.
    uninitialized = [
        key
        for key in missing_keys
        if key in param_names and key != _TEXT_EMBED_KEY
    ]
    missing_buffers = [
        key for key in missing_keys if key not in param_names
    ]
    if missing_buffers:
        print(f" non-parameter keys left at init values: {len(missing_buffers)}")
    if unexpected_keys:
        preview = ", ".join(unexpected_keys[:5])
        print(
            f" WARNING: {len(unexpected_keys)} source tensors were ignored "
            f"(e.g. {preview})"
        )
    if uninitialized:
        preview = ", ".join(uninitialized[:5])
        raise RuntimeError(
            f"{len(uninitialized)} model parameters received no weights from "
            f"the source checkpoint (e.g. {preview}); the key rewriting does "
            "not match this checkpoint's layout"
        )


def main() -> int:
    """Convert an IndicF5 safetensors file into a training-ready EMA checkpoint.

    Options come from ``build_parser()``. The output is a ``torch.save``-d dict
    with a single ``"ema_model_state_dict"`` entry. That entry holds the full
    state dict of an ``EMA`` wrapper around a CFM/DiT model whose text
    embedding is sized for the tokenizer vocabulary at ``--vocab``.

    Returns:
        0 on success (used as the process exit code).

    Raises:
        FileNotFoundError: If the input checkpoint or the vocab file is missing.
        RuntimeError: If a model parameter other than the text embedding did
            not receive weights from the source checkpoint, or if tensor shapes
            disagree with the hard-coded architecture below.
    """
    args = build_parser().parse_args()
    in_path = Path(args.input)
    out_path = Path(args.output)
    vocab_path = Path(args.vocab)
    if not in_path.exists():
        raise FileNotFoundError(f"Input checkpoint not found: {in_path}")
    if not vocab_path.exists():
        raise FileNotFoundError(f"Vocab not found: {vocab_path}")

    print(f"[1/5] Loading safetensors from {in_path}")
    src_state = load_file(str(in_path), device="cpu")

    print("[2/5] Rewriting keys and filtering incompatible tensors")
    converted, dropped_vocoder, dropped_text_embed = _convert_state_dict(src_state)
    print(f" kept tensors: {len(converted)}")
    print(f" dropped vocoder tensors: {dropped_vocoder}")
    print(f" dropped text_embed tensors: {dropped_text_embed}")

    print("[3/5] Building Sinhala-sized F5 model + EMA container")
    _, vocab_size = get_tokenizer(str(vocab_path), "custom")
    print(f" tokenizer vocab size: {vocab_size}")
    # Architecture hyperparameters are hard-coded and must match the IndicF5
    # checkpoint. A mismatch surfaces as a size error in load_state_dict
    # (which checks shapes even with strict=False).
    model = CFM(
        transformer=DiT(
            dim=1024,
            depth=22,
            heads=16,
            ff_mult=2,
            text_dim=512,
            conv_layers=4,
            text_num_embeds=vocab_size,
            mel_dim=100,
        ),
    )
    ema = EMA(model, include_online_model=False)

    print("[4/5] Loading converted weights into EMA with strict=False")
    load_result = ema.load_state_dict(converted, strict=False)
    _check_load_result(
        load_result.missing_keys,
        load_result.unexpected_keys,
        {name for name, _ in ema.named_parameters()},
    )

    print(f"[5/5] Saving training-ready checkpoint to {out_path}")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"ema_model_state_dict": ema.state_dict()}, out_path)
    print("[OK] Conversion complete")
    print(f"Output: {out_path}")
    return 0
if __name__ == "__main__":
    # Hand main()'s integer return value to the interpreter as the exit status.
    status = main()
    raise SystemExit(status)