# hexa-tts-5b / src/inference.py
# Hexa09's picture
# Upload folder using huggingface_hub
# e729286 verified
import torch
import soundfile as sf
import os
from .model import build_model
from .text_encoder import TextEncoder
from .config import HexaConfig
def generate_audio(text, output_path, lang='en', speaker_id=0, emotion_id=0):
    """Generate placeholder audio from text using the Hexa 5B model.

    Builds the (randomly initialized) model, runs a single forward pass to
    verify the architecture, then writes random noise of a plausible length
    to ``output_path`` in place of a real vocoder stage.

    Args:
        text: Input text to synthesize.
        output_path: Path of the WAV file to write.
        lang: Language code passed to the text encoder (default ``'en'``).
        speaker_id: Speaker index; clamped into ``[0, config.num_speakers - 1]``.
        emotion_id: Emotion index; clamped into ``[0, config.num_emotions - 1]``.

    Note:
        The model has random weights and the vocoder is simulated, so the
        written audio is noise — this function only exercises the pipeline.
    """
    print("Initializing Hexa 5B TTS System...")

    # 1. Load configuration (sample rate, embedding table sizes, etc.).
    config = HexaConfig()

    # 2. Build the model (architecture only; random weights for demo).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    model = build_model()
    model.to(device)
    model.eval()

    # 3. Encode the text into token IDs on the target device.
    encoder = TextEncoder()
    print(f"Processing text: '{text}' ({lang})")
    text_ids = encoder.preprocess(text, lang_code=lang).to(device)

    # 4. Prepare conditioning tensors; clamp IDs into the valid embedding range
    #    so out-of-range caller values cannot index past the embedding tables.
    speaker_tensor = torch.tensor([speaker_id]).to(device).clamp(0, config.num_speakers - 1)
    # NOTE(review): `lang` is never mapped to a language index — this always
    # conditions on language 0 regardless of the `lang` argument.
    language_tensor = torch.tensor([0]).to(device)  # Placeholder mapping
    emotion_tensor = torch.tensor([emotion_id]).to(device).clamp(0, config.num_emotions - 1)

    # 5. Single forward pass under no_grad (a real autoregressive model
    #    would decode in a loop here).
    with torch.no_grad():
        mel_output = model(text_ids, speaker_tensor, language_tensor, emotion_tensor)

    print(f"Model forward pass successful. Output shape: {mel_output.shape}")
    print("Note: Since this is an untrained model, the output is random noise.")

    # 6. Dummy vocoder (simulated). In production, HifiGAN would convert
    #    Mel -> audio; 256 looks like an assumed hop length — TODO confirm.
    sr = config.sample_rate
    dummy_audio = torch.randn(mel_output.shape[1] * 256)  # Approx length

    # Save the (noise) waveform as a WAV at the configured sample rate.
    sf.write(output_path, dummy_audio.cpu().numpy(), sr)
    print(f"Saved generated (random) audio to: {output_path}")
if __name__ == "__main__":
    # Smoke test: synthesize one utterance with a non-default emotion index.
    generate_audio(
        "Hello, this is Hexa TTS.",
        "test_output.wav",
        lang='en',
        emotion_id=5  # e.g. 'Happy'
    )