Mohit0708 committed on
Commit
4cdc4e7
·
verified ·
1 Parent(s): 607fd46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -64
app.py CHANGED
@@ -1,65 +1,78 @@
1
- # app.py
2
- import streamlit as st
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- from model.inference import TTSInference
6
-
7
- # Page Config
8
- st.set_page_config(page_title="My Custom TTS Engine", layout="wide")
9
-
10
- st.title("🎙️ Custom Architecture TTS Playground")
11
- st.markdown("This project demonstrates a custom PyTorch implementation of a Transformer-based TTS.")
12
-
13
- # Sidebar for Model Controls
14
- with st.sidebar:
15
- st.header("Model Settings")
16
- checkpoint = st.selectbox("Select Checkpoint", ["checkpoints/checkpoint_epoch_50c.pth", "checkpoints/checkpoint_epoch_3c.pth", "checkpoints/checkpoint_epoch_8.pth"])
17
- device = st.radio("Device", ["cpu", "cuda"])
18
- st.info("Load a specific training checkpoint to compare progress.")
19
-
20
- # Initialize the Inference Engine
21
- # (In a real app, use @st.cache_resource to load this once)
22
- tts_engine = TTSInference(checkpoint_path=checkpoint, device=device)
23
-
24
- # Main Input Area
25
- text_input = st.text_area("Enter Text to Speak:", "Deep learning is fascinating.", height=100)
26
-
27
- col1, col2 = st.columns([1, 2])
28
-
29
- with col1:
30
- if st.button("Generate Audio", type="primary"):
31
- with st.spinner("Running Inference..."):
32
- # Call your backend
33
- audio_data, sample_rate, mel_spec = tts_engine.predict(text_input)
34
-
35
- # Play Audio
36
- st.success("Generation Complete!")
37
- st.audio(audio_data, sample_rate=sample_rate)
38
-
39
- # --- VISUALIZATION (Crucial for Path 2) ---
40
- # Showing the spectrogram proves you understand the data, not just the result.
41
- st.subheader("Mel Spectrogram Analysis")
42
- fig, ax = plt.subplots(figsize=(10, 3))
43
- im = ax.imshow(mel_spec, aspect='auto', origin='lower', cmap='inferno')
44
- plt.colorbar(im, ax=ax)
45
- plt.title("Generated Mel Spectrogram")
46
- plt.xlabel("Time Frames")
47
- plt.ylabel("Mel Channels")
48
- st.pyplot(fig)
49
-
50
- with col2:
51
- st.subheader("Architecture Details")
52
- st.code("""
53
-
54
- class TextToMel(nn.Module):
55
- def __init__(self):
56
- super().__init__()
57
- self.encoder = TransformerEncoder(...)
58
- self.decoder = TransformerDecoder(...)
59
-
60
- def forward(self, text):
61
- # 1. Embed text
62
- # 2. Add Positional Encodings
63
- # 3. Predict Mel Frames
64
- return mel_output
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  """, language="python")
 
# app.py
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import os

from model.inference import TTSInference

# --- Page configuration ---
st.set_page_config(page_title="My Custom TTS Engine", layout="wide")

st.title("🎙️ Custom Architecture TTS Playground")
st.markdown("This project demonstrates a custom PyTorch implementation of a Transformer-based TTS.")

# Training checkpoints offered in the sidebar selector.
CHECKPOINT_OPTIONS = [
    "checkpoints/checkpoint_epoch_50c.pth",
    "checkpoints/checkpoint_epoch_3c.pth",
    "checkpoints/checkpoint_epoch_8.pth",
]

# --- Sidebar: model controls ---
with st.sidebar:
    st.header("Model Settings")
    checkpoint = st.selectbox("Select Checkpoint", CHECKPOINT_OPTIONS)
    # Only CPU is offered: the Hugging Face free tier has no GPU, and
    # exposing "cuda" here would fail at inference time.
    device = st.radio("Device", ["cpu"])
    st.info("Load a specific training checkpoint to compare progress.")
25
+
26
# --- Model loading (cached) ---
# NOTE(review): st.cache_resource keys on the arguments. The previous version
# cached the None returned for a missing checkpoint file, so uploading the
# file later kept serving the stale None until a full app restart. The
# existence check now lives OUTSIDE the cached builder, so a rerun after the
# upload picks the file up.

@st.cache_resource
def _build_engine(ckpt_path: str, dev: str):
    """Construct the TTS inference engine once per (checkpoint, device) pair."""
    return TTSInference(checkpoint_path=ckpt_path, device=dev)


def load_engine(ckpt_path: str, dev: str):
    """Return a cached TTSInference for ckpt_path, or None if the file is absent.

    Returning None instead of raising lets the UI render a friendly error
    when the checkpoint has not been uploaded yet.
    """
    if not os.path.exists(ckpt_path):
        return None  # not cached: a later upload is found on the next rerun
    return _build_engine(ckpt_path, dev)


# Initialize the inference engine for the currently selected settings.
tts_engine = load_engine(checkpoint, device)
35
+
36
# --- Main input area ---
text_input = st.text_area("Enter Text to Speak:", "Deep learning is fascinating.", height=100)

col1, col2 = st.columns([1, 2])

with col1:
    if st.button("Generate Audio", type="primary"):
        if tts_engine is None:
            st.error(f"⚠️ Error: Could not find '{checkpoint}'. Did you upload it to the 'checkpoints' folder on Hugging Face?")
        else:
            with st.spinner("Running Inference..."):
                # Backend returns the raw waveform, its sample rate, and the
                # mel spectrogram used to synthesize it.
                audio_data, sample_rate, mel_spec = tts_engine.predict(text_input)

                # Play the generated audio.
                st.success("Generation Complete!")
                st.audio(audio_data, sample_rate=sample_rate)

                # --- Visualization ---
                # Use the object-oriented Matplotlib API (fig/ax) rather than
                # the pyplot state machine: Streamlit reruns this script on
                # every interaction, and global pyplot state is not rerun-safe.
                st.subheader("Mel Spectrogram Analysis")
                fig, ax = plt.subplots(figsize=(10, 3))
                im = ax.imshow(mel_spec, aspect='auto', origin='lower', cmap='inferno')
                fig.colorbar(im, ax=ax)
                ax.set_title("Generated Mel Spectrogram")
                ax.set_xlabel("Time Frames")
                ax.set_ylabel("Mel Channels")
                st.pyplot(fig)
                # Close the figure; otherwise each rerun leaks a live figure.
                plt.close(fig)
63
+
64
with col2:
    st.subheader("Architecture Details")
    # Static snippet illustrating the model's high-level structure.
    arch_snippet = """
class TextToMel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder(...)
        self.decoder = TransformerDecoder(...)

    def forward(self, text):
        # 1. Embed text
        # 2. Add Positional Encodings
        # 3. Predict Mel Frames
        return mel_output
"""
    st.code(arch_snippet, language="python")