# hexa-tts-5b / src/test_tiny.py
# Uploaded by Hexa09 via "Upload folder using huggingface_hub" (commit e729286, verified).
import torch
import soundfile as sf
import os
from .model import HexaTransformer
from .text_encoder import TextEncoder
from .config import HexaConfig
def run_tiny_test(*, device="cpu", output_path="tiny_output.wav", hop_length=256):
    """
    Smoke-test the Hexa architecture with a tiny config that fits in memory.

    Builds a small ``HexaTransformer`` from an overridden ``HexaConfig``,
    encodes one short English sentence with ``TextEncoder``, runs a single
    forward pass without gradient tracking, and writes a placeholder WAV file
    whose length is derived from the number of output frames.

    Args:
        device: Torch device the model and input tensors are placed on.
        output_path: Destination of the placeholder WAV file.
        hop_length: Audio samples generated per output frame when sizing the
            placeholder WAV (256 was previously hard-coded here).
            NOTE(review): confirm this matches the real vocoder hop length.

    Returns:
        The tensor produced by the model's forward pass.
    """
    print("Initializing Tiny Hexa Model for Code Verification...")

    # Override the default config with a tiny scale so the model fits in memory.
    config = HexaConfig(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64,
        num_languages=15,
    )
    model = HexaTransformer(config).to(device)
    model.eval()

    params = sum(p.numel() for p in model.parameters())
    print(f"Tiny Model Size: {params / 1e6:.2f} Million parameters")

    # Process text
    text = "Hello world, testing tiny hexa."
    encoder = TextEncoder()
    text_ids = encoder.preprocess(text, lang_code='en').to(device)
    print(f"Encoded text shape: {text_ids.shape}")

    # Conditioning ids: a batch of one, each set to index 0.
    # NOTE(review): assumes language index 0 corresponds to the 'en' code
    # passed to the encoder above — confirm against the language table.
    speaker = torch.tensor([0], device=device)
    language = torch.tensor([0], device=device)
    emotion = torch.tensor([0], device=device)

    # Forward pass (inference only, no autograd graph).
    with torch.no_grad():
        output = model(text_ids, speaker, language, emotion)
    print(f"Forward pass successful. Output shape: {output.shape}")

    # Placeholder audio only: no vocoder runs here, so emit noise whose
    # duration matches the predicted frame count. Output is assumed to be
    # (B, Frames, Mel_Channels), so dim 1 is the frame axis.
    # The noise is scaled down and clamped to [-1, 1]. Float samples outside
    # that range clip or distort when soundfile converts them to the WAV's
    # default integer PCM subtype, and unit-variance noise would exceed that
    # range on roughly a third of its samples.
    amplitude = 0.1
    num_samples = output.shape[1] * hop_length
    dummy_wav = (amplitude * torch.randn(num_samples)).clamp(-1.0, 1.0).numpy()
    sf.write(output_path, dummy_wav, config.sample_rate)
    print(f"Saved {output_path}")
    return output
if __name__ == "__main__":
    # Script entry point. The package-relative imports at the top of this
    # file mean it must be run as a module (``python -m ...``) from within
    # its package rather than as a bare file path.
    run_tiny_test()