bdstar committed on
Commit
c812b33
·
verified ·
1 Parent(s): b286a6f

update app file

Browse files
Files changed (1) hide show
  1. app.py +167 -167
app.py CHANGED
@@ -1,167 +1,167 @@
1
- import gradio as gr
2
- import subprocess, json, os, io, tempfile
3
- from faster_whisper import WhisperModel
4
- from ollama import Client as OllamaClient
5
-
6
- # ---- CONFIG ----
7
- LLM_MODEL = "llama3.2:3b" # or "mistral:7b", "qwen2.5:3b"
8
- WHISPER_SIZE = "small" # "base", "small", "medium"
9
- USE_SILERO = True # set False to use Coqui XTTS v2
10
-
11
- import os
12
- USE_REMOTE_OLLAMA = bool(os.getenv("OLLAMA_HOST"))
13
-
14
- if not USE_REMOTE_OLLAMA:
15
- # Transformers fallback for Spaces (CPU-friendly small instruct model)
16
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
17
- HF_CHAT_MODEL = os.getenv("HF_CHAT_MODEL", "google/gemma-2-2b-it") # small instruct model that runs on CPU
18
- _tok = AutoTokenizer.from_pretrained(HF_CHAT_MODEL)
19
- _mdl = AutoModelForCausalLM.from_pretrained(HF_CHAT_MODEL, torch_dtype="auto", device_map="auto")
20
- gen = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=256)
21
-
22
-
23
- # ---- STT (faster-whisper) ----
24
- # Run on GPU if available: compute_type="float16", device="cuda"
25
- stt_model = WhisperModel(WHISPER_SIZE, device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
26
- compute_type="float16" if os.environ.get("CUDA_VISIBLE_DEVICES") else "int8")
27
-
28
- def speech_to_text(audio_path: str) -> str:
29
- segments, info = stt_model.transcribe(audio_path, beam_size=1, vad_filter=True)
30
- text = "".join(seg.text for seg in segments).strip()
31
- return text
32
-
33
- # ---- LLM (Ollama) ----
34
- ollama = OllamaClient(host="http://127.0.0.1:11434")
35
-
36
- SYSTEM_PROMPT = """You are a friendly conversational English coach and voice assistant.
37
- - First, understand the user's utterance.
38
- - If there are mistakes (grammar/word choice/tense), provide a brief corrected sentence first, prefixed with "Correction:".
39
- - In 1 short line, explain the key fix, prefixed with "Why:".
40
- - Then continue the conversation naturally in one or two sentences.
41
- - Be concise, supportive, and avoid long lectures.
42
- Format:
43
- Correction: <corrected sentence or "None">
44
- Why: <very brief reason, or "N/A">
45
- Reply: <your friendly response to keep the conversation going>"""
46
-
47
- def chat_with_llm(history_messages, user_text):
48
- if USE_REMOTE_OLLAMA:
49
- messages = [{"role": "system", "content": SYSTEM_PROMPT}]
50
- for m in (history_messages or []):
51
- if m.get("role") in ("user", "assistant") and m.get("content"):
52
- messages.append({"role": m["role"], "content": m["content"]})
53
- messages.append({"role": "user", "content": user_text})
54
- resp = ollama.chat(model=LLM_MODEL, messages=messages)
55
- return resp["message"]["content"]
56
- else:
57
- # Simple prompt stitching for the fallback pipeline
58
- history_text = "\n".join(
59
- [f"User: {m['content']}" if m["role"]=="user" else f"Assistant: {m['content']}"
60
- for m in (history_messages or [])]
61
- )
62
- prompt = f"{SYSTEM_PROMPT}\n{history_text}\nUser: {user_text}\nAssistant:"
63
- out = gen(prompt)[0]["generated_text"]
64
- # Return only the new assistant chunk after the prompt
65
- return out.split("Assistant:", 1)[-1].strip()
66
-
67
-
68
-
69
- # ---- TTS ----
70
- def tts_silero(text: str) -> str:
71
- """
72
- Return path to a WAV file synthesized by Silero (CPU-friendly).
73
- Works across recent torch.hub return signatures.
74
- """
75
- import torch, tempfile
76
- import soundfile as sf
77
-
78
- # Newer torch.hub supports "trust_repo"; set to True or 'check'
79
- obj = torch.hub.load(
80
- repo_or_dir="snakers4/silero-models",
81
- model="silero_tts",
82
- language="en",
83
- speaker="v3_en",
84
- trust_repo=True # or 'check' to be prompted the first time
85
- )
86
-
87
- # Handle both cases: either a single model, or a (model, something) tuple
88
- model = obj[0] if isinstance(obj, (list, tuple)) else obj
89
-
90
- sample_rate = 48000
91
- speaker = "en_0" # valid default voice in v3_en pack
92
- audio = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)
93
-
94
- out_wav = tempfile.mktemp(suffix=".wav")
95
- sf.write(out_wav, audio, sample_rate)
96
- return out_wav
97
-
98
-
99
- def tts_coqui_xtts(text: str) -> str:
100
- """
101
- Returns path to a WAV file synthesized by Coqui XTTS v2 (higher quality; GPU-friendly).
102
- """
103
- from TTS.api import TTS
104
- tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
105
- out_wav = tempfile.mktemp(suffix=".wav")
106
- tts.tts_to_file(text=text, file_path=out_wav, speaker="female-en-5", language="en")
107
- return out_wav
108
-
109
- def text_to_speech(text: str) -> str:
110
- if USE_SILERO:
111
- return tts_silero(text)
112
- else:
113
- return tts_coqui_xtts(text)
114
-
115
- # ---- Gradio pipeline ----
116
- def pipeline(audio, history):
117
- # audio is (sample_rate, np.array) OR a filepath (depends on Gradio version)
118
- # Normalize to a temp wav file
119
- if audio is None:
120
- return history, None, "Please speak something."
121
-
122
- if isinstance(audio, tuple):
123
- # (sr, data) -> write wav
124
- import soundfile as sf, numpy as np, tempfile
125
- sr, data = audio
126
- tmp_in = tempfile.mktemp(suffix=".wav")
127
- sf.write(tmp_in, data.astype("float32"), sr)
128
- audio_path = tmp_in
129
- else:
130
- audio_path = audio # path already
131
-
132
- user_text = speech_to_text(audio_path)
133
- if not user_text:
134
- return history, None, "Didn't catch that—could you repeat?"
135
-
136
- reply = chat_with_llm(history, user_text)
137
-
138
- # Extract the "Reply:" line for TTS; speak only the conversational reply
139
- speak_text = reply
140
- for tag in ["Reply:", "Correction:", "Why:"]:
141
- # Try to find "Reply:" block
142
- if "Reply:" in reply:
143
- speak_text = reply.split("Reply:", 1)[1].strip()
144
- break
145
-
146
- wav_path = text_to_speech(speak_text)
147
- updated = (history or []) + [
148
- {"role": "user", "content": user_text},
149
- {"role": "assistant", "content": reply},
150
- ]
151
- return updated, wav_path, ""
152
-
153
- with gr.Blocks(title="Voice Coach") as demo:
154
- gr.Markdown("## 🎙️ Interactive Voice Chat (with on-the-fly sentence correction)")
155
- with gr.Row():
156
- audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
157
- audio_out = gr.Audio(label="Assistant (TTS)", autoplay=True)
158
- chatbox = gr.Chatbot(type="messages", height=300)
159
- status = gr.Markdown()
160
- btn = gr.Button("Send")
161
-
162
- # Use continuous recording or press "Send" after recording
163
- audio_in.change(pipeline, inputs=[audio_in, chatbox], outputs=[chatbox, audio_out, status])
164
- btn.click(pipeline, inputs=[audio_in, chatbox], outputs=[chatbox, audio_out, status])
165
-
166
- if __name__ == "__main__":
167
- demo.launch(share=True)
 
1
import gradio as gr
import subprocess, json, os, io, tempfile
from faster_whisper import WhisperModel
from ollama import Client as OllamaClient

# ---- CONFIG ----
LLM_MODEL = "llama3.2:3b"   # or "mistral:7b", "qwen2.5:3b"
WHISPER_SIZE = "small"      # "base", "small", "medium"
USE_SILERO = True           # set False to use Coqui XTTS v2

# An OLLAMA_HOST env var signals a reachable remote Ollama server; otherwise
# we fall back to a local Transformers pipeline (see below).
# (Fix: removed a duplicate `import os` — os is already imported above.)
USE_REMOTE_OLLAMA = bool(os.getenv("OLLAMA_HOST"))
14
if not USE_REMOTE_OLLAMA:
    # No Ollama server configured: load a small, CPU-friendly instruct model
    # via Transformers as the chat backend (suitable for HF Spaces).
    # NOTE: the name `pipeline` imported here is shadowed later by the Gradio
    # handler of the same name, but `gen` is built before that happens.
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    HF_CHAT_MODEL = os.getenv("HF_CHAT_MODEL", "google/gemma-2-2b-it")
    _tokenizer = AutoTokenizer.from_pretrained(HF_CHAT_MODEL)
    _model = AutoModelForCausalLM.from_pretrained(
        HF_CHAT_MODEL, torch_dtype="auto", device_map="auto"
    )
    gen = pipeline("text-generation", model=_model, tokenizer=_tokenizer, max_new_tokens=256)
22
+
23
# ---- STT (faster-whisper) ----
# Use GPU float16 when CUDA is exposed via CUDA_VISIBLE_DEVICES, else CPU int8.
_has_cuda = bool(os.environ.get("CUDA_VISIBLE_DEVICES"))
stt_model = WhisperModel(
    WHISPER_SIZE,
    device="cuda" if _has_cuda else "cpu",
    compute_type="float16" if _has_cuda else "int8",
)
27
+
28
def speech_to_text(audio_path: str) -> str:
    """Transcribe the audio file at *audio_path* and return the stripped text."""
    segments, _info = stt_model.transcribe(audio_path, beam_size=1, vad_filter=True)
    return "".join(segment.text for segment in segments).strip()
32
+
33
# ---- LLM (Ollama) ----
# NOTE(review): the Ollama client instantiation was commented out here
# (`ollama = OllamaClient(host="http://127.0.0.1:11434")`), yet the remote
# branch of chat_with_llm still refers to a module-level `ollama` — confirm
# a client exists before USE_REMOTE_OLLAMA is exercised.

# Instruction template for the coaching assistant; the "Reply:" tag is parsed
# downstream to pick the text that gets spoken via TTS.
SYSTEM_PROMPT = """You are a friendly conversational English coach and voice assistant.
- First, understand the user's utterance.
- If there are mistakes (grammar/word choice/tense), provide a brief corrected sentence first, prefixed with "Correction:".
- In 1 short line, explain the key fix, prefixed with "Why:".
- Then continue the conversation naturally in one or two sentences.
- Be concise, supportive, and avoid long lectures.
Format:
Correction: <corrected sentence or "None">
Why: <very brief reason, or "N/A">
Reply: <your friendly response to keep the conversation going>"""
46
+
47
def chat_with_llm(history_messages, user_text):
    """Return the assistant's reply for *user_text* given the chat history.

    history_messages: list of {"role": "user"|"assistant", "content": str}
    dicts (Gradio "messages" format), or None.
    Uses the remote Ollama server when OLLAMA_HOST is set; otherwise falls
    back to the local Transformers `gen` pipeline.
    """
    if USE_REMOTE_OLLAMA:
        # Bug fix: the module-level `ollama` client was commented out in this
        # commit, so this branch raised NameError. Build the client here from
        # the same OLLAMA_HOST env var that enables this branch.
        client = OllamaClient(host=os.environ["OLLAMA_HOST"])
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        for m in (history_messages or []):
            if m.get("role") in ("user", "assistant") and m.get("content"):
                messages.append({"role": m["role"], "content": m["content"]})
        messages.append({"role": "user", "content": user_text})
        resp = client.chat(model=LLM_MODEL, messages=messages)
        return resp["message"]["content"]

    # Fallback: stitch a plain-text prompt for the local text-generation pipeline.
    history_text = "\n".join(
        f"User: {m['content']}" if m["role"] == "user" else f"Assistant: {m['content']}"
        for m in (history_messages or [])
    )
    prompt = f"{SYSTEM_PROMPT}\n{history_text}\nUser: {user_text}\nAssistant:"
    out = gen(prompt)[0]["generated_text"]
    # The pipeline echoes the prompt; keep only the newly generated assistant text.
    return out.split("Assistant:", 1)[-1].strip()
66
+
67
+
68
+
69
# ---- TTS ----
def tts_silero(text: str) -> str:
    """
    Synthesize *text* with Silero TTS (CPU-friendly) and return the WAV path.

    Handles both torch.hub return shapes: a bare model, or a (model, ...) tuple.
    """
    import torch, tempfile
    import soundfile as sf

    # Newer torch.hub supports "trust_repo"; set to True or 'check'
    obj = torch.hub.load(
        repo_or_dir="snakers4/silero-models",
        model="silero_tts",
        language="en",
        speaker="v3_en",
        trust_repo=True,  # or 'check' to be prompted the first time
    )
    model = obj[0] if isinstance(obj, (list, tuple)) else obj

    sample_rate = 48000
    speaker = "en_0"  # valid default voice in the v3_en pack
    audio = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)

    # Fix: tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile
    # actually creates the file, closing the TOCTOU window.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        out_wav = f.name
    sf.write(out_wav, audio, sample_rate)
    return out_wav
97
+
98
+
99
def tts_coqui_xtts(text: str) -> str:
    """
    Synthesize *text* with Coqui XTTS v2 (higher quality; GPU-friendly) and
    return the path to the resulting WAV file.

    NOTE(review): the TTS model is re-instantiated on every call, which is
    slow — consider caching it at module level if this backend is enabled.
    """
    from TTS.api import TTS
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
    # Fix: tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile
    # actually creates the file, closing the TOCTOU window.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        out_wav = f.name
    tts.tts_to_file(text=text, file_path=out_wav, speaker="female-en-5", language="en")
    return out_wav
108
+
109
def text_to_speech(text: str) -> str:
    """Dispatch to the configured TTS backend; return a WAV file path."""
    return tts_silero(text) if USE_SILERO else tts_coqui_xtts(text)
114
+
115
# ---- Gradio pipeline ----
def pipeline(audio, history):
    """Run one full STT -> LLM -> TTS turn for a recording.

    audio: (sample_rate, np.ndarray) tuple or a filepath, depending on the
    Gradio version/config.  history: chat history in "messages" format.
    Returns (updated_history, wav_path_or_None, status_text).
    """
    if audio is None:
        return history, None, "Please speak something."

    if isinstance(audio, tuple):
        # (sr, data) -> persist to a temp wav for the transcriber.
        # Fix: tempfile.mktemp is deprecated/race-prone; use NamedTemporaryFile.
        import soundfile as sf, tempfile
        sr, data = audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            audio_path = f.name
        sf.write(audio_path, data.astype("float32"), sr)
    else:
        audio_path = audio  # already a filepath

    user_text = speech_to_text(audio_path)
    if not user_text:
        return history, None, "Didn't catch that—could you repeat?"

    reply = chat_with_llm(history, user_text)

    # Speak only the conversational "Reply:" part (skip Correction/Why lines);
    # fall back to the full reply if the model ignored the format.
    # Fix: the original looped over ["Reply:", "Correction:", "Why:"] but never
    # used the loop variable — the body always tested "Reply:", so the loop
    # was a no-op wrapper around this single conditional.
    speak_text = reply.split("Reply:", 1)[1].strip() if "Reply:" in reply else reply

    wav_path = text_to_speech(speak_text)
    updated = (history or []) + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": reply},
    ]
    return updated, wav_path, ""
152
+
153
with gr.Blocks(title="Voice Coach") as demo:
    gr.Markdown("## 🎙️ Interactive Voice Chat (with on-the-fly sentence correction)")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
        audio_out = gr.Audio(label="Assistant (TTS)", autoplay=True)
    chatbox = gr.Chatbot(type="messages", height=300)
    status = gr.Markdown()
    btn = gr.Button("Send")

    # Wire both triggers (new recording, or explicit Send) to the same handler.
    _inputs = [audio_in, chatbox]
    _outputs = [chatbox, audio_out, status]
    audio_in.change(pipeline, inputs=_inputs, outputs=_outputs)
    btn.click(pipeline, inputs=_inputs, outputs=_outputs)
165
+
166
if __name__ == "__main__":
    # share=True exposes a public gradio.live URL (useful outside Spaces).
    demo.launch(share=True)