bdstar committed on
Commit
93a1d44
·
verified ·
1 Parent(s): 432712b

init commit for app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess, json, os, io, tempfile
3
+ from faster_whisper import WhisperModel
4
+ from ollama import Client as OllamaClient
5
+
6
+ # ---- CONFIG ----
7
+ LLM_MODEL = "llama3.2:3b" # or "mistral:7b", "qwen2.5:3b"
8
+ WHISPER_SIZE = "small" # "base", "small", "medium"
9
+ USE_SILERO = True # set False to use Coqui XTTS v2
10
+
11
+ # ---- STT (faster-whisper) ----
12
+ # Run on GPU if available: compute_type="float16", device="cuda"
13
+ stt_model = WhisperModel(WHISPER_SIZE, device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
14
+ compute_type="float16" if os.environ.get("CUDA_VISIBLE_DEVICES") else "int8")
15
+
16
+ def speech_to_text(audio_path: str) -> str:
17
+ segments, info = stt_model.transcribe(audio_path, beam_size=1, vad_filter=True)
18
+ text = "".join(seg.text for seg in segments).strip()
19
+ return text
20
+
21
+ # ---- LLM (Ollama) ----
22
+ ollama = OllamaClient(host="http://127.0.0.1:11434")
23
+
24
+ SYSTEM_PROMPT = """You are a friendly conversational English coach and voice assistant.
25
+ - First, understand the user's utterance.
26
+ - If there are mistakes (grammar/word choice/tense), provide a brief corrected sentence first, prefixed with "Correction:".
27
+ - In 1 short line, explain the key fix, prefixed with "Why:".
28
+ - Then continue the conversation naturally in one or two sentences.
29
+ - Be concise, supportive, and avoid long lectures.
30
+ Format:
31
+ Correction: <corrected sentence or "None">
32
+ Why: <very brief reason, or "N/A">
33
+ Reply: <your friendly response to keep the conversation going>"""
34
+
35
+ def chat_with_llm(history_messages, user_text):
36
+ # history_messages is a list of {"role": "...", "content": "..."}
37
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
38
+
39
+ for m in (history_messages or []):
40
+ role = m.get("role")
41
+ content = m.get("content")
42
+ if role in ("user", "assistant") and content:
43
+ messages.append({"role": role, "content": content})
44
+
45
+ messages.append({"role": "user", "content": user_text})
46
+ resp = ollama.chat(model=LLM_MODEL, messages=messages)
47
+ return resp["message"]["content"]
48
+
49
+
50
+ # ---- TTS ----
51
def tts_silero(text: str) -> str:
    """
    Synthesize *text* with Silero TTS (CPU-friendly) and return the path
    to the resulting WAV file.

    Fixes vs. original:
    - The torch.hub model is loaded once and cached on the function object;
      the original re-downloaded/re-initialized it on every call.
    - tempfile.mktemp (deprecated, race-prone) replaced with
      NamedTemporaryFile(delete=False), which atomically creates the file.
    """
    import torch, tempfile
    import soundfile as sf

    model = getattr(tts_silero, "_model", None)
    if model is None:
        # Newer torch.hub supports "trust_repo"; True avoids the first-run prompt
        obj = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language="en",
            speaker="v3_en",
            trust_repo=True,
        )
        # Handle both cases: either a single model, or a (model, extras) tuple
        model = obj[0] if isinstance(obj, (list, tuple)) else obj
        tts_silero._model = model  # cache across calls

    sample_rate = 48000
    speaker = "en_0"  # valid default voice in the v3_en pack
    audio = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)

    # Create the file safely, keep it on disk for the caller (delete=False)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fh:
        out_wav = fh.name
    sf.write(out_wav, audio, sample_rate)
    return out_wav
78
+
79
+
80
def tts_coqui_xtts(text: str) -> str:
    """
    Synthesize *text* with Coqui XTTS v2 (higher quality; GPU-friendly) and
    return the path to the resulting WAV file.

    Fixes vs. original:
    - The TTS model is instantiated once and cached on the function object;
      the original rebuilt it on every call (slow, memory-churning).
    - tempfile.mktemp (deprecated, race-prone) replaced with
      NamedTemporaryFile(delete=False).
    """
    from TTS.api import TTS

    tts = getattr(tts_coqui_xtts, "_tts", None)
    if tts is None:
        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
        tts_coqui_xtts._tts = tts  # cache across calls

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fh:
        out_wav = fh.name
    tts.tts_to_file(text=text, file_path=out_wav, speaker="female-en-5", language="en")
    return out_wav
89
+
90
def text_to_speech(text: str) -> str:
    """Dispatch *text* to the configured TTS backend; return the WAV path."""
    synthesize = tts_silero if USE_SILERO else tts_coqui_xtts
    return synthesize(text)
95
+
96
+ # ---- Gradio pipeline ----
97
def pipeline(audio, history):
    """One full voice-chat turn: audio -> STT -> LLM -> TTS.

    *audio* is either (sample_rate, np.ndarray) or a filepath, depending on
    the Gradio version. Returns (updated_history, wav_path_or_None, status_str)
    matching the outputs wired in the UI.

    Fixes vs. original:
    - The `for tag in [...]` loop only ever tested "Reply:" and broke on the
      first iteration — dead code replaced with a direct check.
    - tempfile.mktemp (deprecated, race-prone) replaced with
      NamedTemporaryFile(delete=False); unused numpy import dropped.
    """
    if audio is None:
        return history, None, "Please speak something."

    if isinstance(audio, tuple):
        # (sr, data) -> write a temp wav so the STT model gets a file path
        import soundfile as sf, tempfile
        sr, data = audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fh:
            audio_path = fh.name
        sf.write(audio_path, data.astype("float32"), sr)
    else:
        audio_path = audio  # already a path

    user_text = speech_to_text(audio_path)
    if not user_text:
        return history, None, "Didn't catch that—could you repeat?"

    reply = chat_with_llm(history, user_text)

    # Speak only the conversational "Reply:" section; fall back to the full
    # reply when the model didn't follow the SYSTEM_PROMPT format.
    if "Reply:" in reply:
        speak_text = reply.split("Reply:", 1)[1].strip()
    else:
        speak_text = reply

    wav_path = text_to_speech(speak_text)
    updated = (history or []) + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": reply},
    ]
    return updated, wav_path, ""
133
+
134
# ---- UI wiring (module-level side effect: builds the Gradio app) ----
with gr.Blocks(title="Voice Coach") as demo:
    gr.Markdown("## 🎙️ Interactive Voice Chat (with on-the-fly sentence correction)")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
        audio_out = gr.Audio(label="Assistant (TTS)", autoplay=True)  # autoplay speaks the reply
    chatbox = gr.Chatbot(type="messages", height=300)  # role/content dicts, matches pipeline output
    status = gr.Markdown()  # transient status / error text from pipeline
    btn = gr.Button("Send")

    # Two triggers for the same pipeline: a finished recording (change event)
    # or an explicit press of "Send" after recording.
    audio_in.change(pipeline, inputs=[audio_in, chatbox], outputs=[chatbox, audio_out, status])
    btn.click(pipeline, inputs=[audio_in, chatbox], outputs=[chatbox, audio_out, status])

if __name__ == "__main__":
    # share=True exposes a public Gradio link in addition to localhost
    demo.launch(share=True)