# Renamed from streamlit_app.py to app.py (Hugging Face Space, commit 921e6ec)
# app.py
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import os
from model.inference import TTSInference
# --- Page configuration ---
st.set_page_config(page_title="My Custom TTS Engine", layout="wide")
st.title("🎙️ Custom Architecture TTS Playground")
st.markdown("This project demonstrates a custom PyTorch implementation of a Transformer-based TTS.")

# --- Sidebar: model controls ---
with st.sidebar:
    st.header("Model Settings")
    # Checkpoint paths are relative to the Space repo root; the files must be
    # uploaded to the 'checkpoints' folder for loading to succeed.
    checkpoint = st.selectbox("Select Checkpoint", [
        "checkpoints/checkpoint_epoch_50c.pth",
        "checkpoints/checkpoint_epoch_3c.pth",
        "checkpoints/checkpoint_epoch_8.pth",
    ])
    # Force CPU for Hugging Face free tier to prevent CUDA errors
    device = st.radio("Device", ["cpu"])
    st.info("Load a specific training checkpoint to compare progress.")
# --- Model loading ---
# CRITICAL for cloud deployment: cache the engine so the checkpoint is not
# re-loaded from disk on every Streamlit rerun.
@st.cache_resource
def load_engine(ckpt_path: str, dev: str):
    """Load and cache a TTSInference engine for the given checkpoint.

    Returns None when the checkpoint file does not exist yet (e.g. it has
    not been uploaded to the Space), so the UI can show a friendly error
    instead of crashing.
    """
    if not os.path.exists(ckpt_path):
        return None  # caller must handle the missing-checkpoint case
    return TTSInference(checkpoint_path=ckpt_path, device=dev)

# Initialize the inference engine for the currently selected checkpoint.
tts_engine = load_engine(checkpoint, device)
# --- Main input area ---
text_input = st.text_area("Enter Text to Speak:", "Deep learning is fascinating.", height=100)

col1, col2 = st.columns([1, 2])

with col1:
    if st.button("Generate Audio", type="primary"):
        if tts_engine is None:
            st.error(f"⚠️ Error: Could not find '{checkpoint}'. Did you upload it to the 'checkpoints' folder on Hugging Face?")
        else:
            with st.spinner("Running Inference..."):
                # Run the custom TTS backend on the entered text.
                # NOTE(review): assumes predict() returns (waveform, sample_rate,
                # mel_spectrogram) — confirm against model/inference.py.
                audio_data, sample_rate, mel_spec = tts_engine.predict(text_input)

            st.success("Generation Complete!")
            st.audio(audio_data, sample_rate=sample_rate)

            # --- Mel-spectrogram visualization ---
            st.subheader("Mel Spectrogram Analysis")
            fig, ax = plt.subplots(figsize=(10, 3))
            im = ax.imshow(mel_spec, aspect='auto', origin='lower', cmap='inferno')
            # Use the object-oriented API (fig/ax) rather than the pyplot state
            # machine so the labels always target this figure.
            fig.colorbar(im, ax=ax)
            ax.set_title("Generated Mel Spectrogram")
            ax.set_xlabel("Time Frames")
            ax.set_ylabel("Mel Channels")
            st.pyplot(fig)
            plt.close(fig)  # free the figure; avoids leaking memory across reruns
with col2:
    st.subheader("Architecture Details")
    # Static, illustrative snippet of the model architecture (the paste had
    # lost the snippet's indentation; restored here so it renders correctly).
    st.code("""
class TextToMel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder(...)
        self.decoder = TransformerDecoder(...)

    def forward(self, text):
        # 1. Embed text
        # 2. Add Positional Encodings
        # 3. Predict Mel Frames
        return mel_output
""", language="python")