bdstar committed on
Commit
a0cbe1c
·
verified ·
1 Parent(s): c990ff3

init commit for all files

Browse files
Files changed (4) hide show
  1. app.py +167 -0
  2. apt.txt +1 -0
  3. requirements.txt +94 -0
  4. runtime.txt +1 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library imports (one per line, PEP 8 grouping).
# NOTE: the original imported `os` twice (here and again below); once suffices.
import io
import json
import os
import subprocess
import tempfile

import gradio as gr
from faster_whisper import WhisperModel
from ollama import Client as OllamaClient

# ---- CONFIG ----
LLM_MODEL = "llama3.2:3b"   # or "mistral:7b", "qwen2.5:3b"
WHISPER_SIZE = "small"      # "base", "small", "medium"
USE_SILERO = True           # set False to use Coqui XTTS v2

# Remote Ollama is used only when OLLAMA_HOST is set in the environment.
USE_REMOTE_OLLAMA = bool(os.getenv("OLLAMA_HOST"))

if not USE_REMOTE_OLLAMA:
    # Transformers fallback for Spaces (CPU-friendly small instruct model).
    # Loaded eagerly at import time so the first request is not slow.
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    HF_CHAT_MODEL = os.getenv("HF_CHAT_MODEL", "google/gemma-2-2b-it")  # small instruct model that runs on CPU
    _tok = AutoTokenizer.from_pretrained(HF_CHAT_MODEL)
    _mdl = AutoModelForCausalLM.from_pretrained(HF_CHAT_MODEL, torch_dtype="auto", device_map="auto")
    # NOTE(review): `pipeline` (transformers) is later shadowed by the module-level
    # `def pipeline(...)` Gradio handler; `gen` is bound before that, so it works,
    # but the name collision is worth knowing about.
    gen = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=256)


# ---- STT (faster-whisper) ----
# GPU (float16) when CUDA is visible, otherwise CPU with int8 quantization.
# The CUDA check is hoisted into one flag instead of being evaluated twice.
_HAS_CUDA = bool(os.environ.get("CUDA_VISIBLE_DEVICES"))
stt_model = WhisperModel(
    WHISPER_SIZE,
    device="cuda" if _HAS_CUDA else "cpu",
    compute_type="float16" if _HAS_CUDA else "int8",
)
27
+
28
def speech_to_text(audio_path: str) -> str:
    """Transcribe the audio file at *audio_path* with faster-whisper.

    Returns the concatenated segment text, stripped of surrounding whitespace
    (empty string when nothing was recognized).
    """
    segments, _info = stt_model.transcribe(audio_path, beam_size=1, vad_filter=True)
    pieces = [segment.text for segment in segments]
    return "".join(pieces).strip()
32
+
33
# ---- LLM (Ollama) ----
# Honor OLLAMA_HOST when set — it is the same variable that flips
# USE_REMOTE_OLLAMA above. The original always pointed at 127.0.0.1,
# so a configured remote host was silently ignored.
ollama = OllamaClient(host=os.getenv("OLLAMA_HOST", "http://127.0.0.1:11434"))

SYSTEM_PROMPT = """You are a friendly conversational English coach and voice assistant.
- First, understand the user's utterance.
- If there are mistakes (grammar/word choice/tense), provide a brief corrected sentence first, prefixed with "Correction:".
- In 1 short line, explain the key fix, prefixed with "Why:".
- Then continue the conversation naturally in one or two sentences.
- Be concise, supportive, and avoid long lectures.
Format:
Correction: <corrected sentence or "None">
Why: <very brief reason, or "N/A">
Reply: <your friendly response to keep the conversation going>"""
46
+
47
def chat_with_llm(history_messages, user_text):
    """Produce the assistant reply for *user_text* given the chat history.

    history_messages: list of {"role": ..., "content": ...} dicts, or None.
    Returns the raw model output (expected to follow the Correction/Why/Reply
    format requested by SYSTEM_PROMPT).
    """
    if USE_REMOTE_OLLAMA:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        for m in (history_messages or []):
            # Forward only well-formed user/assistant turns.
            if m.get("role") in ("user", "assistant") and m.get("content"):
                messages.append({"role": m["role"], "content": m["content"]})
        messages.append({"role": "user", "content": user_text})
        resp = ollama.chat(model=LLM_MODEL, messages=messages)
        return resp["message"]["content"]

    # Fallback: simple prompt stitching for the local text-generation pipeline.
    history_text = "\n".join(
        f"User: {m['content']}" if m["role"] == "user" else f"Assistant: {m['content']}"
        for m in (history_messages or [])
    )
    prompt = f"{SYSTEM_PROMPT}\n{history_text}\nUser: {user_text}\nAssistant:"
    out = gen(prompt)[0]["generated_text"]
    # BUG FIX: the original used out.split("Assistant:", 1), which splits at the
    # FIRST "Assistant:" — a *history* line whenever history is non-empty — so
    # old turns leaked into the reply. The HF text-generation pipeline echoes
    # the prompt verbatim, so slice it off; fall back to taking the text after
    # the LAST "Assistant:" if that echo assumption ever breaks.
    if out.startswith(prompt):
        return out[len(prompt):].strip()
    return out.rsplit("Assistant:", 1)[-1].strip()
66
+
67
+
68
+
69
+ # ---- TTS ----
70
def tts_silero(text: str) -> str:
    """
    Synthesize *text* with Silero TTS (CPU-friendly) and return a WAV path.

    The torch.hub model is loaded once and cached on the function object —
    the original re-fetched and re-initialized it on every call.
    Works across recent torch.hub return signatures.
    """
    import torch
    import tempfile
    import soundfile as sf

    model = getattr(tts_silero, "_model", None)
    if model is None:
        # Newer torch.hub supports "trust_repo"; set to True or 'check'.
        obj = torch.hub.load(
            repo_or_dir="snakers4/silero-models",
            model="silero_tts",
            language="en",
            speaker="v3_en",
            trust_repo=True,  # or 'check' to be prompted the first time
        )
        # Handle both cases: a single model, or a (model, something) tuple.
        model = obj[0] if isinstance(obj, (list, tuple)) else obj
        tts_silero._model = model

    sample_rate = 48000
    speaker = "en_0"  # valid default voice in the v3_en pack
    audio = model.apply_tts(text=text, speaker=speaker, sample_rate=sample_rate)

    # tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile
    # (delete=False, so the path survives the context) is the safe equivalent.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_wav = tmp.name
    sf.write(out_wav, audio, sample_rate)
    return out_wav
97
+
98
+
99
def tts_coqui_xtts(text: str) -> str:
    """
    Synthesize *text* with Coqui XTTS v2 (higher quality; GPU-friendly) and
    return the path to the resulting WAV file.

    The TTS model is constructed once and cached on the function object —
    the original rebuilt the whole model on every call, which is very slow.
    """
    from TTS.api import TTS

    tts = getattr(tts_coqui_xtts, "_tts", None)
    if tts is None:
        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
        tts_coqui_xtts._tts = tts

    # tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile is safe.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_wav = tmp.name
    tts.tts_to_file(text=text, file_path=out_wav, speaker="female-en-5", language="en")
    return out_wav
108
+
109
def text_to_speech(text: str) -> str:
    """Dispatch *text* to the configured TTS backend and return a WAV path."""
    synth = tts_silero if USE_SILERO else tts_coqui_xtts
    return synth(text)
114
+
115
+ # ---- Gradio pipeline ----
116
def pipeline(audio, history):
    """Run one full voice-chat turn: STT -> LLM -> TTS.

    audio:   either a (sample_rate, np.ndarray) tuple or a filepath string,
             depending on the Gradio version / Audio component configuration.
    history: chat history as a list of {"role", "content"} message dicts.

    Returns (updated_history, wav_path_or_None, status_markdown).
    """
    if audio is None:
        return history, None, "Please speak something."

    if isinstance(audio, tuple):
        # (sr, data) -> write a temporary WAV for the STT model.
        import soundfile as sf
        sr, data = audio
        # NamedTemporaryFile replaces the deprecated, race-prone tempfile.mktemp;
        # the unused local numpy import from the original is dropped.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_in = tmp.name
        sf.write(tmp_in, data.astype("float32"), sr)
        audio_path = tmp_in
    else:
        audio_path = audio  # already a filepath

    user_text = speech_to_text(audio_path)
    if not user_text:
        return history, None, "Didn't catch that—could you repeat?"

    reply = chat_with_llm(history, user_text)

    # Speak only the conversational "Reply:" portion when present. (The
    # original looped over ["Reply:", "Correction:", "Why:"] but only ever
    # tested "Reply:" — the loop was dead weight; a single `if` is equivalent.)
    speak_text = reply
    if "Reply:" in reply:
        speak_text = reply.split("Reply:", 1)[1].strip()

    wav_path = text_to_speech(speak_text)
    updated = (history or []) + [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": reply},
    ]
    return updated, wav_path, ""
152
+
153
# ---- Gradio UI ----
with gr.Blocks(title="Voice Coach") as demo:
    gr.Markdown("## 🎙️ Interactive Voice Chat (with on-the-fly sentence correction)")
    with gr.Row():
        mic = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
        tts_player = gr.Audio(label="Assistant (TTS)", autoplay=True)
    chat = gr.Chatbot(type="messages", height=300)
    status_md = gr.Markdown()
    send_btn = gr.Button("Send")

    # A finished recording triggers the pipeline automatically; the "Send"
    # button runs the same handler explicitly after recording.
    io_spec = dict(inputs=[mic, chat], outputs=[chat, tts_player, status_md])
    mic.change(pipeline, **io_spec)
    send_btn.click(pipeline, **io_spec)

if __name__ == "__main__":
    demo.launch(share=True)
apt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-doc==0.0.3
3
+ annotated-types==0.7.0
4
+ antlr4-python3-runtime==4.9.3
5
+ anyio==4.11.0
6
+ av==16.0.1
7
+ Brotli==1.1.0
8
+ certifi==2025.10.5
9
+ cffi==2.0.0
10
+ charset-normalizer==3.4.4
11
+ click==8.3.0
12
+ colorama==0.4.6
13
+ coloredlogs==15.0.1
14
+ contourpy==1.3.3
15
+ ctranslate2==4.6.0
16
+ cycler==0.12.1
17
+ fastapi==0.120.3
18
+ faster-whisper==1.2.0
19
+ ffmpy==0.6.4
20
+ filelock==3.19.1
21
+ flatbuffers==25.9.23
22
+ fonttools==4.60.1
23
+ fsspec==2025.9.0
24
+ gradio==5.49.1
25
+ gradio_client==1.13.3
26
+ groovy==0.1.2
27
+ h11==0.16.0
28
+ hf-xet==1.2.0
29
+ httpcore==1.0.9
30
+ httpx==0.28.1
31
+ huggingface-hub==1.0.1
32
+ humanfriendly==10.0
33
+ idna==3.11
34
+ importlib_resources==6.5.2
35
+ Jinja2==3.1.6
36
+ kiwisolver==1.4.9
37
+ markdown-it-py==4.0.0
38
+ MarkupSafe==2.1.5
39
+ matplotlib==3.10.7
40
+ mdurl==0.1.2
41
+ mpmath==1.3.0
42
+ networkx==3.5
43
+ numpy==2.3.4
44
+ ollama==0.6.0
45
+ omegaconf==2.3.0
46
+ onnxruntime==1.23.2
47
+ orjson==3.11.4
48
+ packaging==25.0
49
+ pandas==2.3.3
50
+ pillow==10.4.0
51
+ protobuf==6.33.0
52
+ pycparser==2.23
53
+ pydantic==2.11.10
54
+ pydantic_core==2.33.2
55
+ pydub==0.25.1
56
+ Pygments==2.19.2
57
+ pyparsing==3.2.5
58
+ pyreadline3==3.5.4
59
+ python-dateutil==2.9.0.post0
60
+ python-multipart==0.0.20
61
+ pytz==2025.2
62
+ PyYAML==6.0.3
63
+ requests==2.32.5
64
+ rich==14.2.0
65
+ ruff==0.14.3
66
+ safehttpx==0.1.7
67
+ semantic-version==2.10.0
68
+ shellingham==1.5.4
69
+ silero==0.5.1
70
+ six==1.17.0
71
+ sniffio==1.3.1
72
+ sounddevice==0.5.3
73
+ soundfile==0.13.1
74
+ starlette==0.49.1
75
+ sympy==1.14.0
76
+ tokenizers==0.22.1
77
+ tomlkit==0.12.0
78
+ torch==2.9.0
79
+ torchaudio==2.9.0
80
+ tqdm==4.67.1
81
+ typer==0.20.0
82
+ typer-slim==0.20.0
83
+ typing-inspection==0.4.2
84
+ typing_extensions==4.15.0
85
+ tzdata==2025.2
86
+ urllib3==2.5.0
87
+ uvicorn==0.38.0
88
+ websockets==15.0.1
89
+ transformers==4.44.2
90
+ accelerate
91
+ safetensors
92
+ sentencepiece
93
+ protobuf
94
+ einops
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.11