|
|
import gradio as gr |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
try: |
|
|
from app.main import TTSService, STTService |
|
|
except ImportError: |
|
|
from .main import TTSService, STTService |
|
|
|
|
|
|
|
|
try: |
|
|
from config import TTS_MODELS, LANGUAGES, MODEL_MAPPING |
|
|
except ImportError: |
|
|
from ..config import TTS_MODELS, LANGUAGES, MODEL_MAPPING |
|
|
|
|
|
class SpikStudioUI:
    """Gradio front end for Spik Studio's text-to-speech and speech-to-text tabs.

    This class only builds the layout and maps UI selections onto service
    arguments; the actual synthesis/transcription is delegated to the
    ``tts_service`` and ``stt_service`` objects.
    """

    def __init__(self, tts_service=None, stt_service=None):
        """Create the UI, optionally injecting pre-built services.

        Args:
            tts_service: Object exposing ``generate_speech(text, model)``.
                A default ``TTSService()`` is constructed when ``None``.
            stt_service: Object exposing ``transcribe(audio_url=...,
                audio_path=..., language=..., model=...)``. A default
                ``STTService()`` is constructed when ``None``.
        """
        # Compare against None explicitly so an injected service that is
        # falsy (e.g. defines __bool__/__len__) is not silently replaced.
        self.tts_service = tts_service if tts_service is not None else TTSService()
        self.stt_service = stt_service if stt_service is not None else STTService()
        self.theme = gr.themes.Ocean()
        # Selectors below match the elem_id values assigned in _create_tts_tab.
        self.css = """
        #tts-tab {padding: 2rem;}
        #tts-input, #tts-voice, #tts-generate, #tts-audio {margin-bottom: 1rem;}
        """

    def create_ui(self):
        """Build and return the full ``gr.Blocks`` app (header + TTS + STT tabs)."""
        with gr.Blocks(theme=self.theme, css=self.css) as demo:
            self._create_header()
            self._create_tts_tab()
            self._create_stt_tab()
        return demo

    def _create_header(self):
        """Render the static title/intro banner at the top of the page."""
        gr.HTML("""
        <div style='text-align: center; margin-top: 40px;'>
            <span style='font-size: 3em; font-family: serif; font-weight: 400;'>Spik Studio</span><br>
            <span style='font-size: 2em; font-family: serif; color: #1784d6; font-weight: 600;'>Text-to-Speech & Speech-to-Text</span><br><br>
            <span style='font-size: 1.2em; font-family: sans-serif;'>
                Transform text to natural-sounding audio, or get instant transcripts from your recordings.<br>
            </span>
        </div>
        <div style='max-width:420px;margin:0 auto;'></div>
        """)

    def _create_tts_tab(self):
        """Build the "Text to Speech" tab and wire its button to ``_run_tts``."""
        # elem_id values connect these components to the rules in self.css.
        with gr.Tab("Text to Speech", elem_id="tts-tab"):
            with gr.Row():
                with gr.Column():
                    text = gr.Textbox(
                        label="Script",
                        placeholder="Enter your text here...",
                        lines=4,
                        elem_id="tts-input",
                    )
                    with gr.Row():
                        language = gr.Dropdown(
                            [lang[0] for lang in LANGUAGES],
                            value=LANGUAGES[0][0],
                            label="Language",
                        )
                        model = gr.Dropdown(
                            TTS_MODELS,
                            value=TTS_MODELS[0],
                            label="Model",
                            interactive=False,
                        )
                    btn = gr.Button("Generate Speech", elem_id="tts-generate")
                with gr.Column():
                    audio = gr.Audio(label="Generated Audio", elem_id="tts-audio")
                    meta = gr.Markdown()

            # _run_tts must return a 2-item result to fill [audio, meta].
            btn.click(
                self._run_tts,
                inputs=[text, language, model],
                outputs=[audio, meta],
            )

    def _create_stt_tab(self):
        """Build the "Speech to Text" tab and wire its button to ``_run_stt``."""
        with gr.Tab("Speech to Text"):
            with gr.Row():
                with gr.Column():
                    audio_url_input = gr.Textbox(
                        label="Enter Audio URL",
                        placeholder="https://example.com/audio.wav",
                    )
                    # Purely decorative separator between the two input modes.
                    gr.HTML("<span style='font-size: 1.2em; font-family: sans-serif;'>OR</span>")
                    audio_input = gr.Audio(type="filepath", label="Upload Audio")
                    with gr.Row():
                        stt_language = gr.Dropdown(
                            choices=[lang[0] for lang in LANGUAGES],
                            value=LANGUAGES[0][0],
                            label="Language",
                        )
                        stt_model = gr.Dropdown(
                            list(MODEL_MAPPING.keys()),
                            value=list(MODEL_MAPPING.keys())[0],
                            label="Model",
                            interactive=False,
                        )
                    stt_btn = gr.Button("Transcribe")
                with gr.Column():
                    transcript = gr.Textbox(
                        label="Transcript",
                        placeholder="Transcription will appear here...",
                        lines=24,
                    )

            stt_btn.click(
                self._run_stt,
                inputs=[audio_url_input, audio_input, stt_language, stt_model],
                outputs=transcript,
            )

    def _run_tts(self, text, language, model):
        """Synthesize ``text`` with the selected TTS ``model``.

        Returns:
            Whatever ``tts_service.generate_speech`` returns, which must be a
            2-item result for the ``[audio, meta]`` outputs. Blank or
            whitespace-only input short-circuits to ``(None, message)``
            without calling the service.
        """
        if not text or not text.strip():
            return None, "Please enter some text to synthesize."
        # NOTE(review): ``language`` is collected by the UI but not forwarded;
        # generate_speech is only called with (text, model). Confirm whether
        # TTSService accepts a language argument before wiring it through.
        return self.tts_service.generate_speech(text, model)

    def _run_stt(self, audio_url, audio_input, stt_language, stt_model):
        """Transcribe audio from a URL or an uploaded file.

        Args:
            audio_url: URL text from the textbox; surrounding whitespace is stripped.
            audio_input: Local file path from the upload widget, or ``None``.
            stt_language: Display name from ``LANGUAGES``; mapped to its code
                (falls back to ``"en"`` if not found).
            stt_model: Display name from ``MODEL_MAPPING``; mapped to the API
                model id (falls back to ``"whisper"``).

        Returns:
            The service's transcription result, or a prompt string when
            neither a URL nor a file was provided.
        """
        url = audio_url.strip() if isinstance(audio_url, str) else audio_url
        if not url and not audio_input:
            return "Please provide an audio URL or upload an audio file."

        # LANGUAGES is a sequence of (display_name, code) pairs.
        lang_code = next((code for name, code in LANGUAGES if name == stt_language), "en")
        api_model = MODEL_MAPPING.get(stt_model, "whisper")

        return self.stt_service.transcribe(
            audio_url=url,
            audio_path=audio_input,
            language=lang_code,
            model=api_model,
        )
|
|
|
|
|
def launch_ui(server_name="0.0.0.0", server_port=7860, share=False, debug=False):
    """Launch the Gradio UI.

    Args:
        server_name: Interface to bind (``"0.0.0.0"`` listens on all interfaces).
        server_port: TCP port for the Gradio server.
        share: Whether to request a public Gradio share link.
        debug: Forwarded to ``Blocks.launch``. Added so the ``__main__``
            entry point's ``launch_ui(debug=True)`` call no longer raises
            ``TypeError``.
    """
    app = SpikStudioUI()
    demo = app.create_ui()
    demo.launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
    )


if __name__ == "__main__":
    launch_ui(debug=True)
|
|
|