# app.py
"""Streamlit playground for a custom Transformer-based TTS model.

Lets the user pick a training checkpoint, type text, run inference on CPU,
listen to the generated audio, and inspect the predicted mel spectrogram.
Designed for the Hugging Face Spaces free tier (CPU only, cached model).
"""

import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

from model.inference import TTSInference

# Page Config
st.set_page_config(page_title="My Custom TTS Engine", layout="wide")
st.title("🎙️ Custom Architecture TTS Playground")
st.markdown("This project demonstrates a custom PyTorch implementation of a Transformer-based TTS.")

# Sidebar for Model Controls
with st.sidebar:
    st.header("Model Settings")
    checkpoint = st.selectbox("Select Checkpoint", [
        "checkpoints/checkpoint_epoch_50c.pth",
        "checkpoints/checkpoint_epoch_3c.pth",
        "checkpoints/checkpoint_epoch_8.pth"
    ])
    # Force CPU for Hugging Face free tier to prevent CUDA errors
    device = st.radio("Device", ["cpu"])
    st.info("Load a specific training checkpoint to compare progress.")


# --- CRITICAL FIX FOR CLOUD: Cache the model ---
@st.cache_resource
def load_engine(ckpt_path, dev):
    """Build (and cache) a TTSInference engine for the given checkpoint.

    Cached with st.cache_resource so the model is loaded once per
    (checkpoint, device) pair instead of on every Streamlit rerun.

    Returns None when the checkpoint file is missing (e.g. not yet
    uploaded to the Space) so the UI can show a friendly error instead
    of crashing at load time.
    """
    if not os.path.exists(ckpt_path):
        return None  # Return None if file isn't uploaded yet
    return TTSInference(checkpoint_path=ckpt_path, device=dev)


# Initialize the Inference Engine
tts_engine = load_engine(checkpoint, device)

# Main Input Area
text_input = st.text_area("Enter Text to Speak:", "Deep learning is fascinating.", height=100)

col1, col2 = st.columns([1, 2])

with col1:
    if st.button("Generate Audio", type="primary"):
        if tts_engine is None:
            # FIX: the original message was a string literal broken across a
            # raw newline (a syntax error); reconstructed as one f-string.
            st.error(f"⚠️ Error: Could not find '{checkpoint}'. Did you upload it to the 'checkpoints' folder on Hugging Face?")
        else:
            with st.spinner("Running Inference..."):
                # Call your backend
                audio_data, sample_rate, mel_spec = tts_engine.predict(text_input)

                # Play Audio
                st.success("Generation Complete!")
                st.audio(audio_data, sample_rate=sample_rate)

                # --- VISUALIZATION ---
                st.subheader("Mel Spectrogram Analysis")
                fig, ax = plt.subplots(figsize=(10, 3))
                im = ax.imshow(mel_spec, aspect='auto', origin='lower', cmap='inferno')
                fig.colorbar(im, ax=ax)
                # Use the object-oriented API consistently instead of mixing
                # in pyplot global state (plt.title/xlabel/ylabel).
                ax.set_title("Generated Mel Spectrogram")
                ax.set_xlabel("Time Frames")
                ax.set_ylabel("Mel Channels")
                st.pyplot(fig)
                # FIX: close the figure after rendering — matplotlib keeps
                # every open figure alive, which leaks memory across
                # Streamlit reruns in a long-running app.
                plt.close(fig)

with col2:
    st.subheader("Architecture Details")
    st.code("""
class TextToMel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder(...)
        self.decoder = TransformerDecoder(...)

    def forward(self, text):
        # 1. Embed text
        # 2. Add Positional Encodings
        # 3. Predict Mel Frames
        return mel_output
""", language="python")