Spaces:
Runtime error
Runtime error
| import torch | |
| import soundfile as sf | |
| import os | |
| from .model import HexaTransformer | |
| from .text_encoder import TextEncoder | |
| from .config import HexaConfig | |
def run_tiny_test():
    """Smoke-test the Hexa architecture end to end with a scaled-down config.

    Builds a small ``HexaTransformer`` on CPU, prints its parameter count,
    encodes one English sentence, runs a single gradient-free forward pass,
    and finally writes random noise to ``tiny_output.wav`` in the current
    working directory. Progress is reported with ``print``; nothing is
    returned.
    """
    print("Initializing Tiny Hexa Model for Code Verification...")

    # Scaled-down hyperparameters so the model fits in memory.
    tiny_overrides = {
        "dim": 512,
        "depth": 6,
        "heads": 8,
        "dim_head": 64,
        "num_languages": 15,
    }
    cfg = HexaConfig(**tiny_overrides)

    target_device = "cpu"
    model = HexaTransformer(cfg)
    model.to(target_device)
    model.eval()

    # Count every parameter element registered on the model.
    params = 0
    for weight in model.parameters():
        params += weight.numel()
    print(f"Tiny Model Size: {params / 1e6:.2f} Million parameters")

    # Turn the sample sentence into model input ids.
    sample_sentence = "Hello world, testing tiny hexa."
    text_encoder = TextEncoder()
    text_ids = text_encoder.preprocess(sample_sentence, lang_code='en').to(target_device)
    print(f"Encoded text shape: {text_ids.shape}")

    # Conditioning inputs: three independent single-element tensors, all id 0.
    speaker, language, emotion = (
        torch.tensor([0]).to(target_device) for _ in range(3)
    )

    # Single forward pass with autograd disabled.
    with torch.no_grad():
        output = model(text_ids, speaker, language, emotion)
    print(f"Forward pass successful. Output shape: {output.shape}")

    # No vocoder is involved here: the saved file is pure random noise whose
    # length is derived from the model output. Per the original note the output
    # is laid out as (B, Frames, Mel_Channels), so shape[1] is the frame count.
    samples_per_frame = 256
    num_samples = output.shape[1] * samples_per_frame
    dummy_wav = torch.randn(num_samples).numpy()
    sf.write("tiny_output.wav", dummy_wav, cfg.sample_rate)
    print("Saved tiny_output.wav")
if __name__ == "__main__":
    # Run the smoke test only when executed as a program, not on import.
    # NOTE: this file uses package-relative imports, so it has to be launched
    # as a module of its package (``python -m <package>.<module>``) rather
    # than as a bare script path.
    run_tiny_test()