File size: 2,825 Bytes
4cdc4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607fd46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# app.py
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import os
from model.inference import TTSInference

# --- Page configuration ---
st.set_page_config(page_title="My Custom TTS Engine", layout="wide")

st.title("🎙️ Custom Architecture TTS Playground")
st.markdown(
    "This project demonstrates a custom PyTorch implementation of a Transformer-based TTS."
)

# Training checkpoints available for comparison (paths relative to the app root).
_CHECKPOINT_OPTIONS = [
    "checkpoints/checkpoint_epoch_50c.pth",
    "checkpoints/checkpoint_epoch_3c.pth",
    "checkpoints/checkpoint_epoch_8.pth",
]

# --- Sidebar: model controls ---
with st.sidebar:
    st.header("Model Settings")
    checkpoint = st.selectbox("Select Checkpoint", _CHECKPOINT_OPTIONS)
    # Force CPU for Hugging Face free tier to prevent CUDA errors
    device = st.radio("Device", ["cpu"])
    st.info("Load a specific training checkpoint to compare progress.")

# --- CRITICAL FIX FOR CLOUD: Cache the model ---
@st.cache_resource
def load_engine(ckpt_path, dev):
    """Build (and cache via Streamlit) a TTSInference engine for a checkpoint.

    Returns None when the checkpoint file does not exist on disk, so the UI
    can show a friendly error instead of crashing.
    """
    if os.path.exists(ckpt_path):
        return TTSInference(checkpoint_path=ckpt_path, device=dev)
    # Checkpoint file missing (not uploaded yet) — signal the caller via None.
    return None

# Initialize the (cached) inference engine; None means the checkpoint is missing.
tts_engine = load_engine(checkpoint, device)

# --- Main input area ---
text_input = st.text_area(
    "Enter Text to Speak:",
    "Deep learning is fascinating.",
    height=100,
)

col1, col2 = st.columns([1, 2])

with col1:
    if st.button("Generate Audio", type="primary"):
        if tts_engine is None:
            # Engine failed to load because the checkpoint file is absent.
            st.error(f"⚠️ Error: Could not find '{checkpoint}'. Did you upload it to the 'checkpoints' folder on Hugging Face?")
        else:
            with st.spinner("Running Inference..."):
                # Call your backend: returns the waveform, its sample rate, and
                # the intermediate mel spectrogram for visualization.
                audio_data, sample_rate, mel_spec = tts_engine.predict(text_input)

                # Play Audio
                st.success("Generation Complete!")
                st.audio(audio_data, sample_rate=sample_rate)

                # --- VISUALIZATION ---
                st.subheader("Mel Spectrogram Analysis")
                fig, ax = plt.subplots(figsize=(10, 3))
                im = ax.imshow(mel_spec, aspect='auto', origin='lower', cmap='inferno')
                fig.colorbar(im, ax=ax)
                # Use the Axes/Figure API (not the implicit pyplot state machine)
                # so the labels always land on this figure.
                ax.set_title("Generated Mel Spectrogram")
                ax.set_xlabel("Time Frames")
                ax.set_ylabel("Mel Channels")
                st.pyplot(fig)
                # BUGFIX: close the figure. Streamlit re-runs this script on every
                # interaction, and unclosed matplotlib figures accumulate in
                # pyplot's global registry — a memory leak on a long-lived server.
                plt.close(fig)

with col2:
    st.subheader("Architecture Details")
    # Illustrative (non-executable) sketch of the model, shown as documentation.
    _ARCH_SNIPPET = """
    class TextToMel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = TransformerEncoder(...)
            self.decoder = TransformerDecoder(...)
            
        def forward(self, text):
            # 1. Embed text
            # 2. Add Positional Encodings
            # 3. Predict Mel Frames
            return mel_output
    """
    st.code(_ARCH_SNIPPET, language="python")