Mohit0708 committed on
Commit
607fd46
·
verified ·
1 Parent(s): be29b5b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# app.py
"""Streamlit playground for a custom Transformer-based TTS model.

The sidebar selects a training checkpoint and a device; the main area takes
text, runs it through ``TTSInference.predict`` and shows the resulting audio
player together with a plot of the returned mel spectrogram.
"""
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from model.inference import TTSInference

# Page Config
st.set_page_config(page_title="My Custom TTS Engine", layout="wide")

st.title("🎙️ Custom Architecture TTS Playground")
st.markdown("This project demonstrates a custom PyTorch implementation of a Transformer-based TTS.")

# Sidebar for Model Controls
with st.sidebar:
    st.header("Model Settings")
    checkpoint = st.selectbox(
        "Select Checkpoint",
        [
            "checkpoints/checkpoint_epoch_50c.pth",
            "checkpoints/checkpoint_epoch_3c.pth",
            "checkpoints/checkpoint_epoch_8.pth",
        ],
    )
    device = st.radio("Device", ["cpu", "cuda"])
    st.info("Load a specific training checkpoint to compare progress.")


@st.cache_resource
def _load_engine(checkpoint_path: str, device: str) -> TTSInference:
    """Build the inference engine once per (checkpoint_path, device) pair.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the checkpoint would be reloaded each time.
    NOTE(review): the cached engine object is shared between reruns and
    sessions -- assumes ``predict`` does not keep per-request state; confirm.
    """
    return TTSInference(checkpoint_path=checkpoint_path, device=device)


# Initialize the Inference Engine (cached, see _load_engine)
tts_engine = _load_engine(checkpoint, device)

# Main Input Area
text_input = st.text_area("Enter Text to Speak:", "Deep learning is fascinating.", height=100)

col1, col2 = st.columns([1, 2])

with col1:
    if st.button("Generate Audio", type="primary"):
        if not text_input.strip():
            # Nothing to synthesize; skip inference instead of feeding the
            # model an empty string.
            st.warning("Please enter some text to synthesize.")
        else:
            with st.spinner("Running Inference..."):
                # Call your backend
                audio_data, sample_rate, mel_spec = tts_engine.predict(text_input)

            # Play Audio
            st.success("Generation Complete!")
            st.audio(audio_data, sample_rate=sample_rate)

            # --- VISUALIZATION (Crucial for Path 2) ---
            # Showing the spectrogram proves you understand the data, not just the result.
            st.subheader("Mel Spectrogram Analysis")
            fig, ax = plt.subplots(figsize=(10, 3))
            im = ax.imshow(mel_spec, aspect='auto', origin='lower', cmap='inferno')
            # Use the figure/axes objects directly rather than pyplot's
            # implicit "current axes" global state.
            fig.colorbar(im, ax=ax)
            ax.set_title("Generated Mel Spectrogram")
            ax.set_xlabel("Time Frames")
            ax.set_ylabel("Mel Channels")
            st.pyplot(fig)
            # pyplot keeps every figure it created alive until closed; release
            # it so repeated generations do not accumulate figures in memory.
            plt.close(fig)

with col2:
    st.subheader("Architecture Details")
    st.code("""

class TextToMel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder(...)
        self.decoder = TransformerDecoder(...)

    def forward(self, text):
        # 1. Embed text
        # 2. Add Positional Encodings
        # 3. Predict Mel Frames
        return mel_output
""", language="python")