Modern_TalkNET / app.py
harmonicsnail's picture
Add inference model + Gradio app
9496a85
# app.py
import gradio as gr
import os
from model_inference import NetTALKWrapper
# Weights file to load. Set the NETTALK_STATE_DICT environment variable to
# point at a different filename if needed.
STATE_DICT = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")

# Instantiate the model once at import time; predictions reuse this instance.
try:
    model = NetTALKWrapper(state_dict_path=STATE_DICT)
except Exception as e:
    # Fail fast with a clear message in the startup logs. Chaining with
    # ``from e`` records the original error as the direct cause, so its
    # traceback is reported alongside this one.
    raise RuntimeError(f"Failed to load model: {e}") from e
def predict_phonemes(word: str):
    """Predict the phoneme string for a single word.

    Args:
        word: Text entered by the user. Leading and trailing whitespace is
            ignored; ``None`` is treated like an empty string.

    Returns:
        A ``(text, audio)`` pair. ``text`` is the result of
        ``model.predict_string`` for the cleaned word, or the prompt
        ``"Please enter a word"`` when the input is empty or whitespace-only.
        ``audio`` is always ``None`` because no speech is synthesized here.
    """
    # Normalize once so validation and prediction see the same text; the
    # previous version validated the stripped text but predicted on the raw one.
    cleaned = word.strip() if word else ""
    if not cleaned:
        return "Please enter a word", None
    # No audio yet (TTS can be added later), so the second slot stays None.
    return model.predict_string(cleaned), None
# Custom stylesheet for the page: a centered, width-capped container and a
# dark gradient background with light text. Built from adjacent literals, one
# rule per line; the leading and trailing newlines are part of the value.
css = (
    "\n"
    ".gradio-container { max-width: 900px; margin: auto; }\n"
    "body { background: linear-gradient(135deg,#071024,#081226); color: #e6eef8; }\n"
)
# Assemble the page: heading and description, an input row, then the outputs.
# NOTE(review): nesting was reconstructed (textbox and button inside the row,
# outputs below it) because the indentation was not preserved; confirm layout.
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🧠 NetTALK phoneme predictor")
    gr.Markdown(
        "Enter a word and get ARPAbet phonemes predicted by the trained model."
    )

    # Input box and trigger button.
    with gr.Row():
        word = gr.Textbox(
            label="Enter word",
            placeholder="example: 'computer'",
            lines=1,
        )
        btn = gr.Button("Predict")

    out_ph = gr.Textbox(label="Predicted ARPAbet Phonemes")
    # Hidden placeholder reserved for future audio output.
    out_audio = gr.Audio(
        label="Synthesized audio (optional)",
        visible=False,
    )

    # Clicking the button sends the textbox value to the predictor and routes
    # its two return values to the phoneme box and the audio component.
    btn.click(
        fn=predict_phonemes,
        inputs=word,
        outputs=[out_ph, out_audio],
    )
if __name__ == "__main__":
    # Only start the server when executed as a script, not when imported.
    # Listens on all network interfaces at port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )