File size: 1,626 Bytes
e729286 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
import soundfile as sf
import os
from .model import HexaTransformer
from .text_encoder import TextEncoder
from .config import HexaConfig
def run_tiny_test(
    *,
    device="cpu",
    output_path="tiny_output.wav",
    hop_length=256,
    noise_amplitude=0.1,
):
    """
    Test the architecture with a tiny config to fit in memory.

    Builds a small ``HexaTransformer``, encodes a short English sentence,
    runs one forward pass without gradients, and writes a placeholder WAV
    file so the audio-saving path is exercised as well. The WAV content is
    random noise sized from the model output; it is NOT a rendering of the
    predicted features.

    Args:
        device: Torch device the model and all inputs are placed on.
        output_path: Where the placeholder WAV file is written.
        hop_length: Audio samples generated per output frame when sizing the
            placeholder waveform. The default 256 is the value previously
            hard-coded here; NOTE(review): confirm it matches the real hop.
        noise_amplitude: Standard deviation of the placeholder noise. Kept
            well below 1.0 so the samples fit the [-1, 1] float range that
            16-bit PCM WAV output can represent.

    Raises:
        ValueError: If ``hop_length`` is not positive or ``noise_amplitude``
            is negative.
    """
    # Validate cheap arguments before paying for model construction.
    if hop_length <= 0:
        raise ValueError(f"hop_length must be positive, got {hop_length}")
    if noise_amplitude < 0:
        raise ValueError(
            f"noise_amplitude must be non-negative, got {noise_amplitude}"
        )

    print("Initializing Tiny Hexa Model for Code Verification...")
    # Override the config with a small model so the check fits in memory.
    config = HexaConfig(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64,
        num_languages=15,
    )

    model = HexaTransformer(config)
    model.to(device)
    model.eval()  # disable training-only behaviour (e.g. dropout) for the check

    params = sum(p.numel() for p in model.parameters())
    print(f"Tiny Model Size: {params / 1e6:.2f} Million parameters")

    # Process text. NOTE(review): assumes preprocess() returns a tensor of ids.
    text = "Hello world, testing tiny hexa."
    encoder = TextEncoder()
    text_ids = encoder.preprocess(text, lang_code='en').to(device)
    print(f"Encoded text shape: {text_ids.shape}")

    # Conditioning inputs: index 0 for speaker, language and emotion, batch of one.
    speaker = torch.tensor([0], device=device)
    language = torch.tensor([0], device=device)
    emotion = torch.tensor([0], device=device)

    # Forward pass without autograd bookkeeping.
    with torch.no_grad():
        output = model(text_ids, speaker, language, emotion)
    print(f"Forward pass successful. Output shape: {output.shape}")

    # Placeholder audio. The output is expected to be (B, Frames, Mel_Channels)
    # -- TODO confirm -- so dim 1 is treated as the frame count.
    # Unit-variance noise would put ~32% of samples outside [-1, 1], which the
    # default 16-bit PCM WAV encoding cannot represent, so scale and clamp it.
    num_samples = output.shape[1] * hop_length
    dummy_wav = (torch.randn(num_samples) * noise_amplitude).clamp(-1.0, 1.0).numpy()
    sf.write(output_path, dummy_wav, config.sample_rate)
    print(f"Saved {output_path}")
if __name__ == "__main__":
    # Direct-execution entry point. The module uses package-relative imports
    # (``from .model import ...``), so it must be launched as a module inside
    # its package (``python -m <package>.<module>``), not as a bare script path.
    run_tiny_test()
|