|
|
import torch
|
|
|
import soundfile as sf
|
|
|
import os
|
|
|
from .model import HexaTransformer
|
|
|
from .text_encoder import TextEncoder
|
|
|
from .config import HexaConfig
|
|
|
|
|
|
def run_tiny_test():
    """Smoke-test the HexaTransformer stack with a deliberately tiny config.

    Builds a small model on CPU, encodes one English sentence, runs a single
    forward pass, and writes a placeholder waveform so the whole pipeline
    (config -> model -> text encoder -> audio file) is exercised end to end.
    """
    print("Initializing Tiny Hexa Model for Code Verification...")

    # Hyper-parameters kept small on purpose so the test fits in memory.
    config = HexaConfig(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64,
        num_languages=15,
    )

    device = "cpu"
    net = HexaTransformer(config).to(device)
    net.eval()

    # Report total parameter count in millions.
    params = sum(p.numel() for p in net.parameters())
    print(f"Tiny Model Size: {params / 1e6:.2f} Million parameters")

    # Tokenize a short English sentence into model input ids.
    tokenizer = TextEncoder()
    text_ids = tokenizer.preprocess(
        "Hello world, testing tiny hexa.", lang_code='en'
    ).to(device)
    print(f"Encoded text shape: {text_ids.shape}")

    # Conditioning ids: pick index 0 for speaker, language, and emotion.
    speaker, language, emotion = (torch.tensor([0]).to(device) for _ in range(3))

    with torch.no_grad():
        output = net(text_ids, speaker, language, emotion)
    print(f"Forward pass successful. Output shape: {output.shape}")

    # Placeholder audio: random noise sized as frames * 256 — presumably
    # 256 samples per output frame; confirm against the vocoder hop size.
    dummy_wav = torch.randn(output.shape[1] * 256).numpy()
    sf.write("tiny_output.wav", dummy_wav, config.sample_rate)
    print("Saved tiny_output.wav")
|
|
|
|
|
|
# Run the tiny smoke test when this module is executed directly.
if __name__ == "__main__":



    run_tiny_test()
|
|
|
|