Nguyen5 commited on
Commit
fcc2090
·
1 Parent(s): 090a936
Files changed (2) hide show
  1. app.py +566 -228
  2. speech_io.py +593 -215
app.py CHANGED
@@ -1,13 +1,12 @@
1
- # app.py – Prüfungsrechts-Chatbot (RAG + Sprache, UI kiểu ChatGPT + VAD)
2
-
3
- from __future__ import annotations
4
-
5
  import os
6
- from dataclasses import dataclass, field
7
- from typing import Any, List
8
-
9
  import gradio as gr
10
  from gradio_pdf import PDF
 
11
 
12
  from load_documents import load_all_documents
13
  from split_documents import split_documents
@@ -15,9 +14,72 @@ from vectorstore import build_vectorstore
15
  from retriever import get_retriever
16
  from llm import load_llm
17
  from rag_pipeline import answer
18
- from speech_io import transcribe_audio, synthesize_speech
 
 
 
 
 
 
 
19
 
20
- ASR_LANGUAGE_HINT = os.getenv("ASR_LANGUAGE", "de") # "auto" = Auto-Detect
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # =====================================================
23
  # INITIALISIERUNG (global)
@@ -43,6 +105,121 @@ pdf_meta = next(d.metadata for d in docs if d.metadata.get("type") == "pdf")
43
  hg_meta = next(d.metadata for d in docs if d.metadata.get("type") == "hg")
44
  hg_url = hg_meta.get("viewer_url")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # =====================================================
47
  # Quellen formatieren – Markdown für Chat
48
  # =====================================================
@@ -61,255 +238,416 @@ def format_sources(src):
61
  return "\n".join(out)
62
 
63
  # =====================================================
64
- # State Management (wie Gradio Guide)
65
- # =====================================================
66
- @dataclass
67
- class AppState:
68
- conversation: list = field(default_factory=list) # LLM-History (role/content)
69
- stopped: bool = False
70
- model_outs: Any = None
71
-
72
- # =====================================================
73
- # CORE CHAT-FUNKTION (Text + Mikro)
74
  # =====================================================
75
- def chat_fn(
76
- text_input: str,
77
- audio_path: str,
78
- history: list,
79
- state: AppState,
80
- ):
81
- # Ensure history is list of dicts
82
- if history is None or not isinstance(history, list):
83
- history = []
84
-
85
- # Convert old style [[u, a], ...] → new style messages
86
- new_history = []
87
- for h in history:
88
- if isinstance(h, dict):
89
- new_history.append(h)
90
- elif isinstance(h, list) and len(h) == 2:
91
- new_history.append({"role": "user", "content": h[0]})
92
- new_history.append({"role": "assistant", "content": h[1]})
93
-
94
- history = new_history
95
-
96
  text = (text_input or "").strip()
97
-
98
- # Audio-only input transcribe
99
- if audio_path and not text:
100
- text = transcribe_audio(audio_path, language=ASR_LANGUAGE_HINT).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  if not text:
103
- return history, state, "", None
104
 
105
- # Retrieve RAG answer
106
- ans, sources = answer(text, retriever, llm)
 
 
 
107
  bot_msg = ans + format_sources(sources)
 
 
 
108
 
109
- # Append user / assistant messages
110
- history.append({"role": "user", "content": text})
111
- history.append({"role": "assistant", "content": bot_msg})
 
 
112
 
113
- # Also update state
114
- state.conversation.append({"role": "user", "content": text})
115
- state.conversation.append({"role": "assistant", "content": bot_msg})
116
 
117
- return history, state, "", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
 
 
 
 
119
 
120
  # =====================================================
121
- # CSS + JS (VAD) nach Gradio Guide adaptiert
122
  # =====================================================
123
- CUSTOM_STYLE_AND_VAD = """
124
- <style>
125
- html, body {height: auto !important; overflow-y: auto !important;}
126
- .gradio-container {max-width: 960px; margin: 0 auto; padding: 12px;}
127
- #chat-wrap {position: relative;}
128
- #chat-input-row {transform: translateY(-28px); margin-bottom: -28px;}
129
-
130
- /* ChatGPT-like Bottom Bar */
131
- #chat-input-row {
132
- align-items: center;
133
- gap: 8px;
134
- padding: 8px 12px;
135
- border: 1px solid rgba(0,0,0,0.08);
136
- border-radius: 9999px;
137
- background: var(--background-primary);
138
- box-shadow: 0 2px 6px rgba(0,0,0,0.06);
139
- }
140
-
141
- /* Textbox inside pill */
142
- #chat-textbox textarea {
143
- min-height: 42px;
144
- max-height: 120px;
145
- border: none !important;
146
- background: transparent !important;
147
- box-shadow: none !important;
148
- resize: none;
149
- padding-left: 0;
150
- }
151
-
152
- /* Icon buttons (plus, mic, send) */
153
- .icon-btn, .compact-btn {
154
- width: 32px;
155
- height: 32px;
156
- border-radius: 9999px !important;
157
- display: inline-flex;
158
- align-items: center;
159
- justify-content: center;
160
- border: 1px solid rgba(0,0,0,0.08) !important;
161
- background: #f7f7f8 !important;
162
- box-shadow: none !important;
163
- }
164
- .send-btn {
165
- background: #111 !important;
166
- color: #fff !important;
167
- border-color: #111 !important;
168
- }
169
-
170
- /* Make audio mic compact and borderless */
171
- #chat-audio {min-width: 32px; border: none !important; background: transparent !important;}
172
- #chat-audio .wrap, #chat-audio .audio-wrap, #chat-audio .audio-controls {max-width: 32px;}
173
- #chat-textbox textarea {border: none !important; outline: none !important;}
174
- @media (max-width: 768px) { #chat-input-row {transform: none; margin-bottom: 0;} }
175
- </style>
176
-
177
- <script>
178
- /*
179
- * Voice Activity Detection (VAD) nach Gradio Guide:
180
- * Nutzt @ricky0123/vad-web, um automatisch auf die
181
- * .record-button / .stop-button der Audio-Komponente zu klicken.
182
- */
183
- async function init_vad() {
184
- try {
185
- const script1 = document.createElement("script");
186
- script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.14.0/dist/ort.js";
187
- document.head.appendChild(script1);
188
-
189
- const script2 = document.createElement("script");
190
- script2.onload = async () => {
191
- console.log("VAD JS geladen");
192
- const recordButton = document.querySelector('.record-button');
193
- if (recordButton) {
194
- recordButton.textContent = "Just start talking";
195
- }
196
- const myvad = await vad.MicVAD.new({
197
- onSpeechStart: () => {
198
- const record = document.querySelector('.record-button');
199
- const player = document.querySelector('#streaming-out');
200
- if (record && (!player || player.paused)) {
201
- console.log("VAD: speech start → record.click()");
202
- record.click();
203
- }
204
- },
205
- onSpeechEnd: (audio) => {
206
- const stop = document.querySelector('.stop-button');
207
- if (stop) {
208
- console.log("VAD: speech end → stop.click()");
209
- stop.click();
210
- }
211
- }
212
- });
213
- myvad.start();
214
- };
215
- script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.7/dist/bundle.min.js";
216
- document.head.appendChild(script2);
217
- } catch (e) {
218
- console.log("VAD init Fehler:", e);
219
- }
220
- }
221
- if (typeof window !== "undefined") {
222
- window.addEventListener("load", init_vad);
223
- }
224
- </script>
225
- """
226
 
227
  # =====================================================
228
- # UI – GRADIO (ChatGPT-artig + VAD)
229
  # =====================================================
230
- with gr.Blocks(title="Prüfungsrechts-Chatbot (RAG + Sprache)") as demo:
231
- gr.HTML(CUSTOM_STYLE_AND_VAD)
232
-
233
- gr.Markdown("# 🧑‍⚖️ Prüfungsrechts-Chatbot")
234
- gr.Markdown(
235
- "Dieser Chatbot beantwortet Fragen **ausschließlich** aus der "
236
- "Prüfungsordnung (PDF) und dem Hochschulgesetz NRW. "
237
- "Du kannst Text eingeben oder einfach anfangen zu sprechen – "
238
- "die Aufnahme startet/stopt automatisch (Voice Activity Detection)."
239
- )
240
-
241
- with gr.Column(elem_id="chat-wrap"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  chatbot = gr.Chatbot(
243
- label="Chat",
244
- height=380,
 
 
245
  )
246
-
247
- # globaler State für Konversation usw.
248
- state = gr.State(value=AppState())
249
-
250
- # Eingabezeile à la ChatGPT: Plus + Text + Mikro + Senden
251
- with gr.Row(elem_id="chat-input-row"):
252
- attach_btn = gr.UploadButton(
253
- "+",
254
- file_types=["file"],
255
- file_count="multiple",
256
- elem_classes=["icon-btn"],
257
- scale=1,
258
  )
 
 
259
  chat_text = gr.Textbox(
260
- elem_id="chat-textbox",
261
  label=None,
262
- placeholder="Stelle irgendeine Frage oder sprich einfach los …",
263
  lines=1,
264
- max_lines=6,
265
- autofocus=True,
266
  scale=8,
 
267
  )
 
 
268
  chat_audio = gr.Audio(
269
- elem_id="chat-audio",
270
- label="🎤",
271
  sources=["microphone"],
272
  type="filepath",
273
  format="wav",
274
- streaming=False, # wichtig: record/stop Buttons für VAD
275
  interactive=True,
276
- scale=1,
277
  show_label=False,
 
278
  )
279
- send_btn = gr.Button(
280
- "➤",
281
- elem_classes=["compact-btn", "send-btn"],
282
- scale=1,
283
- )
284
-
285
- # Senden bei Enter (Text)
286
- chat_text.submit(
287
- chat_fn,
288
- [chat_text, chat_audio, chatbot, state],
289
- [chatbot, state, chat_text, chat_audio],
290
- )
291
- # Audio-Stop (manuell oder durch VAD) → ganze Pipeline
292
- chat_audio.change(
293
- chat_fn,
294
- [chat_text, chat_audio, chatbot, state],
295
- [chatbot, state, chat_text, chat_audio],
296
- )
297
- # Senden-Button
298
- send_btn.click(
299
- chat_fn,
300
- [chat_text, chat_audio, chatbot, state],
301
- [chatbot, state, chat_text, chat_audio],
302
- )
303
-
304
- # Quellen & Dokumente kompakt unterhalb
305
- with gr.Accordion("Quellen & Dokumente", open=False):
306
- gr.Markdown("### 📄 Prüfungsordnung (PDF)")
307
- PDF(pdf_meta["pdf_url"], height=250)
308
- gr.Markdown("### 📘 Hochschulgesetz NRW")
309
- if isinstance(hg_url, str) and hg_url.startswith("http"):
310
- gr.Markdown(f"[Im Viewer öffnen]({hg_url})")
311
- else:
312
- gr.Markdown("Viewer-Link nicht verfügbar.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  if __name__ == "__main__":
315
- demo.queue().launch(ssr_mode=False, show_error=True)
 
 
 
 
 
 
1
+ # app.py – Prüfungsrechts-Chatbot (RAG + Sprache, UI kiểu ChatGPT) với các tính năng nâng cao
2
+ #
 
 
3
import os
import time
from dataclasses import dataclass, field
from typing import Optional, Dict, Any

import gradio as gr
import numpy as np
from gradio_pdf import PDF
10
 
11
  from load_documents import load_all_documents
12
  from split_documents import split_documents
 
14
  from retriever import get_retriever
15
  from llm import load_llm
16
  from rag_pipeline import answer
17
+ from speech_io import transcribe_audio, synthesize_speech, transcribe_with_groq, detect_voice_activity
18
+
19
+ # Cấu hình môi trường
20
+ ASR_LANGUAGE_HINT = os.getenv("ASR_LANGUAGE", "de")
21
+ USE_GROQ = os.getenv("USE_GROQ", "false").lower() == "true"
22
+ GROQ_MODEL = os.getenv("GROQ_MODEL", "whisper-large-v3-turbo")
23
+ ENABLE_VAD = os.getenv("ENABLE_VAD", "true").lower() == "true"
24
+ VAD_THRESHOLD = float(os.getenv("VAD_THRESHOLD", "0.5"))
25
 
26
+ # =====================================================
27
+ # STATE MANAGEMENT - Quản lý trạng thái hội thoại liền mạch
28
+ # =====================================================
29
+ @dataclass
30
+ class ConversationState:
31
+ """Quản lý trạng thái hội thoại"""
32
+ messages: list
33
+ last_audio_time: float
34
+ is_listening: bool
35
+ vad_confidence: float
36
+ conversation_context: str
37
+ whisper_model: str
38
+ language: str
39
+
40
+ def __init__(self):
41
+ self.messages = []
42
+ self.last_audio_time = 0
43
+ self.is_listening = False
44
+ self.vad_confidence = 0.0
45
+ self.conversation_context = ""
46
+ self.whisper_model = os.getenv("WHISPER_MODEL", "base")
47
+ self.language = ASR_LANGUAGE_HINT
48
+
49
+ def add_message(self, role: str, content: str):
50
+ """Thêm message vào hội thoại"""
51
+ self.messages.append({
52
+ "role": role,
53
+ "content": content,
54
+ "timestamp": time.time()
55
+ })
56
+ # Cập nhật context (giữ lại 5 message gần nhất)
57
+ if len(self.messages) > 10:
58
+ self.messages = self.messages[-10:]
59
+
60
+ # Cập nhật context cho hội thoại
61
+ self._update_context()
62
+
63
+ def _update_context(self):
64
+ """Cập nhật context từ hội thoại"""
65
+ context_parts = []
66
+ for msg in self.messages[-5:]: # Giữ 5 message gần nhất
67
+ prefix = "User" if msg["role"] == "user" else "Assistant"
68
+ context_parts.append(f"{prefix}: {msg['content']}")
69
+ self.conversation_context = "\n".join(context_parts)
70
+
71
+ def get_recent_context(self, num_messages: int = 3) -> str:
72
+ """Lấy context gần đây"""
73
+ recent = self.messages[-num_messages:] if self.messages else []
74
+ return "\n".join([f"{m['role']}: {m['content']}" for m in recent])
75
+
76
+ def reset(self):
77
+ """Reset trạng thái hội thoại"""
78
+ self.messages = []
79
+ self.conversation_context = ""
80
+
81
+ # Khởi tạo state
82
+ state = ConversationState()
83
 
84
  # =====================================================
85
  # INITIALISIERUNG (global)
 
105
  hg_meta = next(d.metadata for d in docs if d.metadata.get("type") == "hg")
106
  hg_url = hg_meta.get("viewer_url")
107
 
108
+ # =====================================================
109
+ # BENCHMARK WHISPER MODELS
110
+ # =====================================================
111
+ def benchmark_whisper_models(audio_path: str) -> Dict[str, Any]:
112
+ """Benchmark các model Whisper khác nhau"""
113
+ import torch
114
+ from transformers import pipeline
115
+
116
+ models_to_test = ["tiny", "base", "small", "medium"]
117
+ results = {}
118
+
119
+ for model_size in models_to_test:
120
+ model_id = f"openai/whisper-{model_size}"
121
+
122
+ try:
123
+ print(f"Testing {model_id}...")
124
+
125
+ # Measure memory usage
126
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
127
+ memory_before = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
128
+
129
+ # Load and transcribe
130
+ start_time = time.time()
131
+
132
+ asr_pipeline = pipeline(
133
+ task="automatic-speech-recognition",
134
+ model=model_id,
135
+ device="cpu",
136
+ return_timestamps=False,
137
+ chunk_length_s=8,
138
+ stride_length_s=(1, 1),
139
+ )
140
+
141
+ # Load audio
142
+ import soundfile as sf
143
+ data, sr = sf.read(audio_path)
144
+
145
+ # Transcribe
146
+ result = asr_pipeline({"array": data, "sampling_rate": sr})
147
+ transcription = result.get("text", "")
148
+
149
+ end_time = time.time()
150
+
151
+ # Memory after
152
+ memory_after = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
153
+
154
+ results[model_size] = {
155
+ "transcription": transcription,
156
+ "time_taken": end_time - start_time,
157
+ "memory_used": memory_after - memory_before,
158
+ "model_size": model_size
159
+ }
160
+
161
+ print(f" Time: {end_time - start_time:.2f}s")
162
+
163
+ except Exception as e:
164
+ print(f" Error with {model_id}: {e}")
165
+ results[model_size] = {"error": str(e)}
166
+
167
+ return results
168
+
169
+ # =====================================================
170
+ # VOICE ACTIVITY DETECTION
171
+ # =====================================================
172
+ def handle_voice_activity(audio_data: Optional[np.ndarray], sample_rate: int) -> Dict[str, Any]:
173
+ """Xử lý phát hiện hoạt động giọng nói"""
174
+ if audio_data is None or len(audio_data) == 0:
175
+ return {"is_speech": False, "confidence": 0.0}
176
+
177
+ vad_result = detect_voice_activity(audio_data, sample_rate, threshold=VAD_THRESHOLD)
178
+
179
+ # Cập nhật state
180
+ if vad_result["is_speech"]:
181
+ state.last_audio_time = time.time()
182
+ state.vad_confidence = vad_result["confidence"]
183
+
184
+ return vad_result
185
+
186
+ # =====================================================
187
+ # TRANSCRIBE WITH OPTIMIZED PIPELINE
188
+ # =====================================================
189
+ def transcribe_audio_optimized(audio_path: str, language: Optional[str] = None) -> str:
190
+ """Transcribe audio với pipeline tối ưu"""
191
+ if USE_GROQ:
192
+ print("Using Groq for transcription...")
193
+ return transcribe_with_groq(audio_path, language=language)
194
+ else:
195
+ return transcribe_audio(audio_path, language=language)
196
+
197
+ # =====================================================
198
+ # CONVERSATIONAL INTELLIGENCE
199
+ # =====================================================
200
+ def enhance_conversation_context(user_input: str, history: list) -> str:
201
+ """Tăng cường context hội thoại với LLM"""
202
+ # Tạo prompt có context
203
+ context = state.get_recent_context(3)
204
+
205
+ prompt = f"""Context from previous conversation:
206
+ {context}
207
+
208
+ Current user input: {user_input}
209
+
210
+ Based on the context, provide a concise summary or additional context that might help answer this question better:"""
211
+
212
+ # Gọi LLM để xử lý context (có thể dùng model nhỏ hơn cho việc này)
213
+ try:
214
+ # Ở đây có thể tích hợp với một LLM nhỏ để xử lý context
215
+ # Tạm thời trả về context đơn giản
216
+ if context:
217
+ return f"Context from conversation: {context}\n\nQuestion: {user_input}"
218
+ else:
219
+ return user_input
220
+ except:
221
+ return user_input
222
+
223
  # =====================================================
224
  # Quellen formatieren – Markdown für Chat
225
  # =====================================================
 
238
  return "\n".join(out)
239
 
240
  # =====================================================
241
+ # CORE CHAT-FUNKTION với tất cả tính năng mới
 
 
 
 
 
 
 
 
 
242
  # =====================================================
243
+ def chat_fn(text_input, audio_path, history, lang_sel, use_vad):
244
+ """
245
+ text_input: Textbox-Inhalt (str)
246
+ audio_path: Pfad zu WAV/FLAC vom Mikro (gr.Audio, type="filepath")
247
+ history: Liste von OpenAI-ähnlichen Messages (role, content)
248
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  text = (text_input or "").strip()
250
+
251
+ # Xử VAD nếu được bật
252
+ if use_vad and ENABLE_VAD and audio_path:
253
+ import soundfile as sf
254
+ try:
255
+ audio_data, sample_rate = sf.read(audio_path)
256
+ vad_result = handle_voice_activity(audio_data, sample_rate)
257
+
258
+ if vad_result["is_speech"]:
259
+ print(f"Voice activity detected with confidence: {vad_result['confidence']:.2f}")
260
+ else:
261
+ print("No voice activity detected")
262
+ if not text:
263
+ return history, "", None, "Bereit (keine Sprache erkannt)"
264
+ except Exception as e:
265
+ print(f"VAD error: {e}")
266
+
267
+ # Transcribe audio nếu có
268
+ if (not text) and audio_path:
269
+ state.last_audio_time = time.time()
270
+
271
+ # Chọn phương thức transcribe
272
+ if USE_GROQ:
273
+ spoken = transcribe_with_groq(audio_path, language=lang_sel)
274
+ else:
275
+ spoken = transcribe_audio(audio_path, language=lang_sel)
276
+
277
+ text = spoken.strip()
278
+
279
+ if text:
280
+ # Tăng cường context hội thoại
281
+ enhanced_text = enhance_conversation_context(text, history)
282
+ state.add_message("user", text)
283
+ print(f"✅ Transkribiert: {text}")
284
 
285
  if not text:
286
+ return history, "", None, "Bereit"
287
 
288
+ # Tăng cường context cho câu hỏi
289
+ question_with_context = enhance_conversation_context(text, history)
290
+
291
+ # RAG-Antwort berechnen với context
292
+ ans, sources = answer(question_with_context, retriever, llm)
293
  bot_msg = ans + format_sources(sources)
294
+
295
+ # Thêm vào state
296
+ state.add_message("assistant", ans)
297
 
298
+ # History aktualisieren (ChatGPT-Style)
299
+ history = history + [
300
+ {"role": "user", "content": text},
301
+ {"role": "assistant", "content": bot_msg},
302
+ ]
303
 
304
+ status_text = f"Bereit | Model: {state.whisper_model} | VAD: {'On' if use_vad else 'Off'}"
305
+ return history, "", None, status_text
 
306
 
307
+ # =====================================================
308
+ # FUNCTIONS FOR UI CONTROLS
309
+ # =====================================================
310
+ def toggle_vad(use_vad):
311
+ """Toggle Voice Activity Detection"""
312
+ global ENABLE_VAD
313
+ ENABLE_VAD = use_vad
314
+ status = "EIN" if use_vad else "AUS"
315
+ return f"Voice Activity Detection: {status}"
316
+
317
+ def change_whisper_model(model_size):
318
+ """Đổi Whisper model"""
319
+ state.whisper_model = model_size
320
+ os.environ["WHISPER_MODEL"] = model_size
321
+ return f"Whisper Model: {model_size}"
322
+
323
+ def run_benchmark(audio_path):
324
+ """Chạy benchmark các model Whisper"""
325
+ if not audio_path:
326
+ return "Bitte wählen Sie eine Audiodatei für den Benchmark aus."
327
+
328
+ results = benchmark_whisper_models(audio_path)
329
+
330
+ # Format results
331
+ report = ["## 📊 Whisper Model Benchmark", ""]
332
+ for model_size, result in results.items():
333
+ if "error" in result:
334
+ report.append(f"**{model_size}**: Fehler - {result['error']}")
335
+ else:
336
+ report.append(
337
+ f"**{model_size}**: {result['time_taken']:.2f}s | "
338
+ f"Speicher: {result['memory_used'] / 1024**2:.1f}MB | "
339
+ f"Text: {result['transcription'][:100]}..."
340
+ )
341
+
342
+ return "\n".join(report)
343
 
344
+ def clear_conversation():
345
+ """Xóa hội thoại"""
346
+ state.reset()
347
+ return [], "Hội thoại đã được xóa"
348
 
349
  # =====================================================
350
+ # LAST ANSWER TTS (für Button "Antwort erneut vorlesen")
351
  # =====================================================
352
+ def read_last_answer(history):
353
+ if not history:
354
+ return None
355
+
356
+ for msg in reversed(history):
357
+ if msg.get("role") == "assistant":
358
+ return synthesize_speech(msg.get("content", ""))
359
+
360
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
362
  # =====================================================
363
+ # UI – GRADIO với tất cả tính năng mới
364
  # =====================================================
365
+ with gr.Blocks(title="Prüfungsrechts-Chatbot (RAG + Sprache) - Enhanced", theme=gr.themes.Soft()) as demo:
366
+ # CSS Styling nâng cao
367
+ gr.HTML("""
368
+ <style>
369
+ .gradio-container {
370
+ max-width: 1200px;
371
+ margin: 0 auto;
372
+ padding: 20px;
373
+ }
374
+
375
+ .header {
376
+ text-align: center;
377
+ margin-bottom: 30px;
378
+ }
379
+
380
+ .control-panel {
381
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
382
+ padding: 20px;
383
+ border-radius: 15px;
384
+ margin-bottom: 20px;
385
+ color: white;
386
+ }
387
+
388
+ .stats-bar {
389
+ background: #f8f9fa;
390
+ border-radius: 10px;
391
+ padding: 10px;
392
+ margin: 10px 0;
393
+ border-left: 4px solid #667eea;
394
+ }
395
+
396
+ .vad-indicator {
397
+ display: inline-block;
398
+ width: 12px;
399
+ height: 12px;
400
+ border-radius: 50%;
401
+ margin-right: 8px;
402
+ }
403
+
404
+ .vad-active {
405
+ background-color: #10b981;
406
+ box-shadow: 0 0 10px #10b981;
407
+ }
408
+
409
+ .vad-inactive {
410
+ background-color: #ef4444;
411
+ }
412
+
413
+ .model-selector {
414
+ background: white;
415
+ padding: 15px;
416
+ border-radius: 10px;
417
+ margin: 10px 0;
418
+ }
419
+
420
+ .chat-container {
421
+ background: white;
422
+ border-radius: 15px;
423
+ padding: 20px;
424
+ box-shadow: 0 10px 40px rgba(0,0,0,0.1);
425
+ }
426
+
427
+ .input-row {
428
+ background: #f8fafc;
429
+ border-radius: 25px;
430
+ padding: 5px 20px;
431
+ border: 2px solid #e2e8f0;
432
+ transition: all 0.3s ease;
433
+ }
434
+
435
+ .input-row:focus-within {
436
+ border-color: #667eea;
437
+ box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1);
438
+ }
439
+
440
+ .feature-badge {
441
+ display: inline-block;
442
+ padding: 4px 12px;
443
+ background: #e0e7ff;
444
+ color: #4f46e5;
445
+ border-radius: 20px;
446
+ font-size: 12px;
447
+ margin: 2px;
448
+ }
449
+ </style>
450
+ """)
451
+
452
+ # Header
453
+ with gr.Column(elem_classes=["header"]):
454
+ gr.Markdown("# 🧑‍⚖️ Prüfungsrechts-Chatbot - Enhanced")
455
+ gr.Markdown("### Intelligent Voice Interface with Advanced Features")
456
+
457
+ # Feature badges
458
+ gr.HTML("""
459
+ <div style="text-align: center; margin: 10px 0;">
460
+ <span class="feature-badge">🎤 Voice Activity Detection</span>
461
+ <span class="feature-badge">⚡ Groq Optimization</span>
462
+ <span class="feature-badge">🧠 Conversational AI</span>
463
+ <span class="feature-badge">📊 Model Benchmarking</span>
464
+ <span class="feature-badge">🔄 State Management</span>
465
+ </div>
466
+ """)
467
+
468
+ # Control Panel
469
+ with gr.Column(elem_classes=["control-panel"]):
470
+ gr.Markdown("### 🎛️ Control Panel")
471
+
472
+ with gr.Row():
473
+ with gr.Column(scale=2):
474
+ # Model Selection
475
+ model_selector = gr.Dropdown(
476
+ choices=["tiny", "base", "small", "medium"],
477
+ value=state.whisper_model,
478
+ label="Whisper Model",
479
+ info="Chọn model cho speech recognition"
480
+ )
481
+
482
+ # VAD Control
483
+ vad_toggle = gr.Checkbox(
484
+ value=ENABLE_VAD,
485
+ label="Enable Voice Activity Detection",
486
+ info="Tự động phát hiện khi người dùng nói"
487
+ )
488
+
489
+ # Language Selection
490
+ lang_selector = gr.Dropdown(
491
+ choices=["de", "en", "auto"],
492
+ value=ASR_LANGUAGE_HINT,
493
+ label="Speech Recognition Language"
494
+ )
495
+
496
+ with gr.Column(scale=1):
497
+ # Stats Display
498
+ status_display = gr.Textbox(
499
+ label="System Status",
500
+ value="Bereit",
501
+ interactive=False
502
+ )
503
+
504
+ # Clear Conversation Button
505
+ clear_btn = gr.Button("🗑️ Clear Conversation", variant="secondary")
506
+
507
+ # Benchmark Section
508
+ benchmark_audio = gr.Audio(
509
+ label="Benchmark Audio",
510
+ type="filepath",
511
+ visible=False
512
+ )
513
+ benchmark_btn = gr.Button("📊 Run Model Benchmark", variant="secondary")
514
+ benchmark_output = gr.Markdown()
515
+
516
+ # Main Chat Interface
517
+ with gr.Column(elem_classes=["chat-container"]):
518
  chatbot = gr.Chatbot(
519
+ label="Conversation",
520
+ height=400,
521
+ bubble_full_width=False,
522
+ show_copy_button=True
523
  )
524
+
525
+ # Input Row với VAD Indicator
526
+ with gr.Row(elem_classes=["input-row"]):
527
+ # VAD Indicator
528
+ vad_indicator = gr.HTML(
529
+ f"""
530
+ <div class="vad-indicator {'vad-active' if state.is_listening else 'vad-inactive'}"></div>
531
+ <span>VAD: {'Active' if state.is_listening else 'Inactive'}</span>
532
+ """
 
 
 
533
  )
534
+
535
+ # Text Input
536
  chat_text = gr.Textbox(
 
537
  label=None,
538
+ placeholder="Stelle eine Frage oder spreche ins Mikrofon...",
539
  lines=1,
540
+ max_lines=4,
 
541
  scale=8,
542
+ container=False
543
  )
544
+
545
+ # Audio Input
546
  chat_audio = gr.Audio(
 
 
547
  sources=["microphone"],
548
  type="filepath",
549
  format="wav",
550
+ streaming=True,
551
  interactive=True,
 
552
  show_label=False,
553
+ scale=1
554
  )
555
+
556
+ # Send Button
557
+ send_btn = gr.Button("Senden", variant="primary", scale=1)
558
+
559
+ # TTS Controls
560
+ with gr.Row():
561
+ tts_btn = gr.Button("🔊 Antwort vorlesen", variant="secondary")
562
+ tts_audio = gr.Audio(label="Audio Output", interactive=False)
563
+
564
+ # Documents Section
565
+ with gr.Accordion("📚 Quellen & Dokumente", open=False):
566
+ with gr.Tabs():
567
+ with gr.TabItem("Prüfungsordnung (PDF)"):
568
+ PDF(pdf_meta["pdf_url"], height=300)
569
+
570
+ with gr.TabItem("Hochschulgesetz NRW"):
571
+ if isinstance(hg_url, str) and hg_url.startswith("http"):
572
+ gr.Markdown(f"### [Im Viewer öffnen]({hg_url})")
573
+ gr.HTML(f'<iframe src="{hg_url}" width="100%" height="500px"></iframe>')
574
+ else:
575
+ gr.Markdown("Viewer-Link nicht verfügbar.")
576
+
577
+ # Event Handlers
578
+ # Model Selection
579
+ model_selector.change(
580
+ change_whisper_model,
581
+ inputs=[model_selector],
582
+ outputs=[status_display]
583
+ )
584
+
585
+ # VAD Toggle
586
+ vad_toggle.change(
587
+ toggle_vad,
588
+ inputs=[vad_toggle],
589
+ outputs=[status_display]
590
+ )
591
+
592
+ # Clear Conversation
593
+ clear_btn.click(
594
+ clear_conversation,
595
+ outputs=[chatbot, status_display]
596
+ )
597
+
598
+ # Benchmark
599
+ benchmark_btn.click(
600
+ run_benchmark,
601
+ inputs=[benchmark_audio],
602
+ outputs=[benchmark_output]
603
+ )
604
+
605
+ # Main Chat Function
606
+ send_btn.click(
607
+ chat_fn,
608
+ inputs=[chat_text, chat_audio, chatbot, lang_selector, vad_toggle],
609
+ outputs=[chatbot, chat_text, chat_audio, status_display]
610
+ )
611
+
612
+ chat_text.submit(
613
+ chat_fn,
614
+ inputs=[chat_text, chat_audio, chatbot, lang_selector, vad_toggle],
615
+ outputs=[chatbot, chat_text, chat_audio, status_display]
616
+ )
617
+
618
+ # Real-time transcription với VAD
619
+ def handle_streaming_audio(audio_path, use_vad):
620
+ if audio_path and use_vad:
621
+ import soundfile as sf
622
+ try:
623
+ audio_data, sr = sf.read(audio_path)
624
+ vad_result = handle_voice_activity(audio_data, sr)
625
+
626
+ if vad_result["is_speech"]:
627
+ text = transcribe_audio_optimized(audio_path, language=lang_selector.value)
628
+ return text, f"VAD Active (Confidence: {vad_result['confidence']:.2f})"
629
+ except Exception as e:
630
+ print(f"Streaming error: {e}")
631
+
632
+ return "", status_display.value
633
+
634
+ chat_audio.stream(
635
+ handle_streaming_audio,
636
+ inputs=[chat_audio, vad_toggle],
637
+ outputs=[chat_text, status_display]
638
+ )
639
+
640
+ # TTS
641
+ tts_btn.click(
642
+ read_last_answer,
643
+ inputs=[chatbot],
644
+ outputs=[tts_audio]
645
+ )
646
 
647
  if __name__ == "__main__":
648
+ demo.queue(max_size=20).launch(
649
+ server_name="0.0.0.0",
650
+ server_port=7860,
651
+ share=False,
652
+ debug=True
653
+ )
speech_io.py CHANGED
@@ -1,248 +1,626 @@
1
  """
2
- speech_io.py Final FIXED Version
3
- ✔ No 'prompt' in generate_kwargs
4
- Fully HF Whisper-compatible
5
- Supports Groq Whisper
6
- Stable for HuggingFace Spaces
 
 
7
  """
8
 
9
  import os
 
 
10
  import numpy as np
11
  import soundfile as sf
12
- import difflib
13
  import re
14
- from scipy.signal import butter, filtfilt, resample
15
- from transformers import pipeline
16
-
17
- # Optional Groq import
18
- try:
19
- from groq import Groq
20
- except:
21
- Groq = None
22
-
23
- ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny")
24
- USE_GROQ = os.getenv("USE_GROQ_WHISPER", "0").lower() in ("1", "true", "yes")
25
- GROQ_MODEL = os.getenv("GROQ_WHISPER_MODEL", "whisper-large-v3-turbo")
26
- ASR_DEFAULT_LANGUAGE = os.getenv("ASR_LANGUAGE", "de")
27
- ASR_MAX_DURATION_S = int(os.getenv("ASR_MAX_DURATION_S", "30"))
28
-
29
  TTS_MODEL_ID = os.getenv("TTS_MODEL_ID", "facebook/mms-tts-deu")
30
- TTS_ENABLED = os.getenv("TTS_ENABLED", "1").lower() not in ("0", "false", "no")
31
 
32
- _asr = None
33
- _tts = None
34
- _groq = None
 
 
35
 
36
- # ======================================================
37
- # Helpers
38
- # ======================================================
39
 
40
- def butter_highpass_filter(data, cutoff=60, fs=16000, order=4):
41
- nyq = 0.5 * fs
42
- norm_cutoff = cutoff / nyq
43
- b, a = butter(order, norm_cutoff, btype="high")
44
- return filtfilt(b, a, data)
45
-
46
- def apply_fade(audio, sr, ms=10):
47
- n = int(sr * ms / 1000)
48
- if n * 2 >= len(audio):
49
- return audio
50
- fadein = np.linspace(0, 1, n)
51
- fadeout = np.linspace(1, 0, n)
52
- audio[:n] *= fadein
53
- audio[-n:] *= fadeout
54
- return audio
55
 
56
- # ======================================================
57
- # Whisper LOCAL
58
- # ======================================================
59
 
60
- def get_asr_pipeline():
61
- global _asr
62
- if _asr is None:
63
- print(f">>> Lade lokales Whisper-Modell: {ASR_MODEL_ID}")
64
- _asr = pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  task="automatic-speech-recognition",
66
- model=ASR_MODEL_ID,
67
  device="cpu",
68
  return_timestamps=False,
69
  chunk_length_s=8,
70
  stride_length_s=(1, 1),
71
  )
72
- return _asr
73
-
74
- # ======================================================
75
- # Whisper GROQ
76
- # ======================================================
77
-
78
- def get_groq_client():
79
- global _groq
80
- if _groq is None:
81
- key = os.getenv("GROQ_API_KEY")
82
- if not key:
83
- raise RuntimeError("GROQ_API_KEY fehlt.")
84
- _groq = Groq(api_key=key)
85
- print(">>> Groq Client bereit.")
86
- return _groq
87
-
88
- def _groq_transcribe(audio_path, language):
89
- client = get_groq_client()
90
-
91
- lang = None
92
- if language and language.lower() != "auto":
93
- lang = language
94
 
95
- with open(audio_path, "rb") as f:
96
- resp = client.audio.transcriptions.with_raw_response.create(
97
- file=("audio.wav", f),
98
- model=GROQ_MODEL,
99
- response_format="verbose_json",
100
- language=lang,
101
- ).parse()
102
-
103
- segments = resp.segments or []
104
- if not segments:
105
- return ""
106
-
107
- if segments[0].get("no_speech_prob", 0) > 0.7:
108
- return ""
109
-
110
- return resp.text.strip()
111
-
112
- # ======================================================
113
- # LOCAL WHISPER STT
114
- # ======================================================
115
 
116
- def _local_transcribe(audio_path, language, max_duration_s):
117
- data, sr = sf.read(audio_path, always_2d=False)
 
 
 
 
 
 
 
 
 
 
118
 
119
- if data.ndim > 1:
120
- data = data.mean(axis=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- data = np.asarray(data, dtype=np.float32)
123
- data = np.clip(data, -1, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  try:
126
- data = butter_highpass_filter(data, 60, sr)
127
- except:
128
- pass
129
-
130
- m = np.max(np.abs(data))
131
- if m > 0:
132
- data = data / m
133
-
134
- rms = float(np.sqrt(np.mean(data ** 2)))
135
- if rms < 5e-5:
136
- print(">>> Audio zu leise.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  return ""
138
-
139
- if sr != 16000:
140
- target_len = int(len(data) * 16000 / sr)
141
- data = resample(data, target_len)
142
- sr = 16000
143
-
144
- idx = np.where(np.abs(data) > 0.02)[0]
145
- if idx.size:
146
- data = data[idx[0]: idx[-1] + 1]
147
-
148
- dur = len(data) / sr
149
- if dur < 0.3:
150
- print(">>> Audio zu kurz.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  return ""
152
 
153
- if len(data) > sr * max_duration_s:
154
- data = data[: sr * max_duration_s]
155
-
156
- asr = get_asr_pipeline()
157
-
158
- # ---- FIXED generate_kwargs: ALLOWED ONLY ----
159
- gen = {
160
- "task": "transcribe",
161
- "temperature": 0.0,
162
- "num_beams": 1,
163
- "compression_ratio_threshold": 2.4,
164
- "logprob_threshold": -1.0,
165
- "no_speech_threshold": 0.6,
166
- "no_repeat_ngram_size": 3,
167
- }
168
-
169
- if language and language.lower() != "auto":
170
- gen["language"] = language
171
-
172
- print(">>> Transkribiere Audio (lokal)…")
173
-
174
- result = asr(
175
- {"array": data, "sampling_rate": sr},
176
- generate_kwargs=gen,
177
- )
178
-
179
- text = (result.get("text", "") if isinstance(result, dict) else result).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- # Domain cleanup
182
- vocab = [
183
- "prüfung","prüfungsordnung","hochschulgesetz","modul","klausur",
184
- "immatrikulation","exmatrikulation","anmeldung","wiederholung"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  ]
186
-
 
 
 
 
 
 
 
 
 
 
 
 
187
  tokens = text.split()
188
- fixed = []
189
- for t in tokens:
190
- m = difflib.get_close_matches(t.lower(), vocab, n=1, cutoff=0.82)
191
- fixed.append(m[0] if m else t)
192
-
193
- return " ".join(fixed)
194
-
195
- # ======================================================
196
- # PUBLIC STT WRAPPER
197
- # ======================================================
198
-
199
- def transcribe_audio(audio_path, language=None, max_duration_s=ASR_MAX_DURATION_S):
200
- if not audio_path:
201
- return ""
202
-
203
- # Try Groq first
204
- if USE_GROQ:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  try:
206
- return _groq_transcribe(audio_path, language)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  except Exception as e:
208
- print("Groq Fehler → fallback lokal:", e)
209
-
210
- return _local_transcribe(audio_path, language, max_duration_s)
211
-
212
- # ======================================================
213
- # TTS
214
- # ======================================================
215
-
216
- def get_tts_pipeline():
217
- global _tts
218
- if _tts is None:
219
- print(">>> Lade TTS:", TTS_MODEL_ID)
220
- _tts = pipeline("text-to-speech", model=TTS_MODEL_ID)
221
- return _tts
222
-
223
- def synthesize_speech(text: str):
224
- if not text or not TTS_ENABLED:
225
- return None
226
-
227
- tts = get_tts_pipeline()
228
- out = tts(text)
229
-
230
- audio = np.array(out["audio"], dtype=np.float32)
231
- sr = out.get("sampling_rate", 16000)
232
-
233
- if audio.ndim > 1:
234
- audio = audio.squeeze()
235
-
236
- try:
237
- audio = butter_highpass_filter(audio, 60, sr)
238
- except:
239
- pass
240
-
241
- maxv = np.max(np.abs(audio))
242
- if maxv > 0:
243
- audio = audio / maxv
244
-
245
- audio = apply_fade(audio, sr)
246
- audio = np.clip(audio * 32767, -32768, 32767).astype(np.int16)
247
-
248
- return sr, audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ speech_io.py - Enhanced Version
3
+
4
+ Sprachbasierte Ein-/Ausgabe với:
5
+ - Speech-to-Text (STT) với Whisper (nhiều phiên bản + Groq)
6
+ - Text-to-Speech (TTS) với MMS-TTS Deutsch
7
+ - Voice Activity Detection (VAD)
8
+ - Model Benchmarking
9
  """
10
 
11
  import os
12
+ import time
13
+ from typing import Optional, Tuple, Dict, Any, Union
14
  import numpy as np
15
  import soundfile as sf
16
+ from scipy.signal import butter, filtfilt, resample, sosfiltfilt
17
  import re
18
+ import difflib
19
+ import requests
20
+ import json
21
+ from dataclasses import dataclass
22
+
23
+ # ========================================================
24
+ # CẤU HÌNH
25
+ # ========================================================
26
+ # Model Selection
27
+ WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base") # tiny, base, small, medium
28
+ ASR_MODEL_ID = f"openai/whisper-{WHISPER_MODEL}"
 
 
 
 
29
  TTS_MODEL_ID = os.getenv("TTS_MODEL_ID", "facebook/mms-tts-deu")
 
30
 
31
+ # Groq Configuration
32
+ USE_GROQ = os.getenv("USE_GROQ", "false").lower() == "true"
33
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
34
+ GROQ_MODEL = os.getenv("GROQ_MODEL", "whisper-large-v3-turbo")
35
+ GROQ_API_URL = "https://api.groq.com/openai/v1/audio/transcriptions"
36
 
37
+ # VAD Configuration
38
+ ENABLE_VAD = os.getenv("ENABLE_VAD", "true").lower() == "true"
39
+ VAD_THRESHOLD = float(os.getenv("VAD_THRESHOLD", "0.5"))
40
 
41
+ # Other Configs
42
+ ASR_DEFAULT_LANGUAGE = os.getenv("ASR_LANGUAGE", "de")
43
+ TTS_ENABLED = os.getenv("TTS_ENABLED", "1").lower() not in ("0", "false", "no")
44
+ ASR_PROMPT = os.getenv("ASR_PROMPT", "Dies ist ein Diktat in deutscher Sprache.")
45
+ ASR_MAX_DURATION_S = int(os.getenv("ASR_MAX_DURATION_S", "30"))
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # Cache for models
48
+ _asr_cache = {}
49
+ _tts = None
50
 
51
+ # ========================================================
52
+ # DATA CLASSES
53
+ # ========================================================
54
+ @dataclass
55
+ class TranscriptionResult:
56
+ text: str
57
+ confidence: float
58
+ language: str
59
+ processing_time: float
60
+ model: str
61
+
62
+ @dataclass
63
+ class VADResult:
64
+ is_speech: bool
65
+ confidence: float
66
+ speech_segments: list
67
+ energy: float
68
+
69
+ # ========================================================
70
+ # MODEL LOADING WITH CACHE
71
+ # ========================================================
72
+ def get_asr_pipeline(model_size: str = None):
73
+ """Lấy ASR pipeline với cache"""
74
+ global _asr_cache
75
+
76
+ if model_size is None:
77
+ model_size = WHISPER_MODEL
78
+
79
+ model_id = f"openai/whisper-{model_size}"
80
+
81
+ if model_id not in _asr_cache:
82
+ print(f">>> Lade ASR Modell: {model_id}")
83
+
84
+ from transformers import pipeline
85
+
86
+ _asr_cache[model_id] = pipeline(
87
  task="automatic-speech-recognition",
88
+ model=model_id,
89
  device="cpu",
90
  return_timestamps=False,
91
  chunk_length_s=8,
92
  stride_length_s=(1, 1),
93
  )
94
+
95
+ return _asr_cache[model_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ def get_tts_pipeline():
98
+ """Lấy TTS pipeline"""
99
+ global _tts
100
+ if _tts is None:
101
+ print(f">>> Lade TTS Modell: {TTS_MODEL_ID}")
102
+
103
+ from transformers import pipeline
104
+
105
+ _tts = pipeline(
106
+ task="text-to-speech",
107
+ model=TTS_MODEL_ID,
108
+ )
109
+ return _tts
 
 
 
 
 
 
 
110
 
111
+ # ========================================================
112
+ # AUDIO PROCESSING UTILITIES
113
+ # ========================================================
114
+ def butter_highpass_filter(data, cutoff=60, fs=16000, order=4):
115
+ """Highpass filter để loại bỏ noise tần số thấp"""
116
+ if len(data) == 0:
117
+ return data
118
+
119
+ nyq = 0.5 * fs
120
+ norm_cutoff = cutoff / nyq
121
+ sos = butter(order, norm_cutoff, btype="high", output='sos')
122
+ return sosfiltfilt(sos, data)
123
 
124
+ def apply_fade(audio, sr, fade_in_ms=10, fade_out_ms=10):
125
+ """Áp dụng fade in/out để tránh pop"""
126
+ if len(audio) == 0:
127
+ return audio
128
+
129
+ fade_in_samples = int(sr * fade_in_ms / 1000)
130
+ fade_out_samples = int(sr * fade_out_ms / 1000)
131
+
132
+ if fade_in_samples * 2 >= len(audio):
133
+ return audio
134
+
135
+ # Fade in
136
+ fade_in_curve = np.linspace(0, 1, fade_in_samples)
137
+ audio[:fade_in_samples] *= fade_in_curve
138
+
139
+ # Fade out
140
+ fade_out_curve = np.linspace(1, 0, fade_out_samples)
141
+ audio[-fade_out_samples:] *= fade_out_curve
142
+
143
+ return audio
144
 
145
+ def normalize_audio(audio_data: np.ndarray) -> np.ndarray:
146
+ """Chuẩn hóa audio"""
147
+ if len(audio_data) == 0:
148
+ return audio_data
149
+
150
+ # Chuyển đổi sang float32 nếu cần
151
+ if audio_data.dtype != np.float32:
152
+ audio_data = audio_data.astype(np.float32)
153
+
154
+ # Normalize về [-1, 1]
155
+ max_val = np.max(np.abs(audio_data))
156
+ if max_val > 0:
157
+ audio_data = audio_data / max_val
158
+
159
+ return audio_data
160
+
161
+ def resample_audio(audio_data: np.ndarray, orig_sr: int, target_sr: int = 16000) -> np.ndarray:
162
+ """Resample audio về target sample rate"""
163
+ if orig_sr == target_sr:
164
+ return audio_data
165
+
166
+ target_len = int(len(audio_data) * target_sr / orig_sr)
167
+ return resample(audio_data, target_len)
168
+
169
+ # ========================================================
170
+ # VOICE ACTIVITY DETECTION (VAD)
171
+ # ========================================================
172
+ def detect_voice_activity(
173
+ audio_data: np.ndarray,
174
+ sample_rate: int,
175
+ threshold: float = 0.5,
176
+ frame_duration_ms: int = 30
177
+ ) -> VADResult:
178
+ """
179
+ Phát hiện hoạt động giọng nói
180
+ """
181
+ if len(audio_data) == 0:
182
+ return VADResult(
183
+ is_speech=False,
184
+ confidence=0.0,
185
+ speech_segments=[],
186
+ energy=0.0
187
+ )
188
+
189
+ # Tính toán energy
190
+ energy = np.mean(audio_data ** 2)
191
+
192
+ # Frame-based analysis
193
+ frame_size = int(sample_rate * frame_duration_ms / 1000)
194
+ num_frames = len(audio_data) // frame_size
195
+
196
+ speech_frames = 0
197
+ speech_segments = []
198
+ current_segment = None
199
+
200
+ for i in range(num_frames):
201
+ start_idx = i * frame_size
202
+ end_idx = start_idx + frame_size
203
+ frame = audio_data[start_idx:end_idx]
204
+
205
+ # Tính frame energy
206
+ frame_energy = np.mean(frame ** 2)
207
+
208
+ # Kiểm tra zero-crossing rate (ZCR) để phân biệt speech/noise
209
+ zcr = np.mean(np.abs(np.diff(np.sign(frame))))
210
+
211
+ # Kết hợp các đặc trưng để phát hiện speech
212
+ is_speech_frame = (frame_energy > threshold * energy) and (zcr < 0.3)
213
+
214
+ if is_speech_frame:
215
+ speech_frames += 1
216
+
217
+ if current_segment is None:
218
+ current_segment = [start_idx / sample_rate, end_idx / sample_rate]
219
+ else:
220
+ current_segment[1] = end_idx / sample_rate
221
+ else:
222
+ if current_segment is not None:
223
+ speech_segments.append(current_segment)
224
+ current_segment = None
225
+
226
+ if current_segment is not None:
227
+ speech_segments.append(current_segment)
228
+
229
+ # Tính confidence
230
+ confidence = speech_frames / max(num_frames, 1)
231
+ is_speech = confidence > 0.1 # Ít nhất 10% frames là speech
232
+
233
+ return VADResult(
234
+ is_speech=is_speech,
235
+ confidence=confidence,
236
+ speech_segments=speech_segments,
237
+ energy=energy
238
+ )
239
 
240
+ def vad_preprocess(audio_data: np.ndarray, sample_rate: int) -> np.ndarray:
241
+ """Tiền xử lý audio cho VAD"""
242
+ # Normalize
243
+ audio_data = normalize_audio(audio_data)
244
+
245
+ # Highpass filter
246
+ audio_data = butter_highpass_filter(audio_data, cutoff=80, fs=sample_rate)
247
+
248
+ return audio_data
249
+
250
+ # ========================================================
251
+ # SPEECH-TO-TEXT CORE FUNCTIONS
252
+ # ========================================================
253
+ def transcribe_with_groq(
254
+ audio_path: str,
255
+ language: Optional[str] = None,
256
+ prompt: Optional[str] = None
257
+ ) -> str:
258
+ """
259
+ Transcribe audio sử dụng Groq Cloud API
260
+ """
261
+ if not GROQ_API_KEY:
262
+ print(">>> Groq API key nicht gefunden. Verwende lokales Modell.")
263
+ return transcribe_audio(audio_path, language)
264
+
265
  try:
266
+ # Đọc audio file
267
+ with open(audio_path, 'rb') as audio_file:
268
+ files = {
269
+ 'file': (os.path.basename(audio_path), audio_file, 'audio/wav')
270
+ }
271
+
272
+ data = {
273
+ 'model': GROQ_MODEL,
274
+ 'response_format': 'json',
275
+ }
276
+
277
+ if language and language != 'auto':
278
+ data['language'] = language
279
+
280
+ if prompt:
281
+ data['prompt'] = prompt
282
+
283
+ headers = {
284
+ 'Authorization': f'Bearer {GROQ_API_KEY}'
285
+ }
286
+
287
+ print(f">>> Sende Anfrage an Groq API (Modell: {GROQ_MODEL})...")
288
+ start_time = time.time()
289
+
290
+ response = requests.post(
291
+ GROQ_API_URL,
292
+ headers=headers,
293
+ files=files,
294
+ data=data,
295
+ timeout=30
296
+ )
297
+
298
+ processing_time = time.time() - start_time
299
+
300
+ if response.status_code == 200:
301
+ result = response.json()
302
+ text = result.get('text', '').strip()
303
+ print(f">>> Groq Transkription ({processing_time:.2f}s): {text}")
304
+ return text
305
+ else:
306
+ print(f">>> Groq Fehler {response.status_code}: {response.text}")
307
+ # Fallback to local model
308
+ return transcribe_audio(audio_path, language)
309
+
310
+ except Exception as e:
311
+ print(f">>> Groq Fehler: {e}")
312
+ return transcribe_audio(audio_path, language)
313
+
314
+ def transcribe_audio(
315
+ audio_path: str,
316
+ language: Optional[str] = None,
317
+ model_size: Optional[str] = None,
318
+ max_duration_s: int = ASR_MAX_DURATION_S
319
+ ) -> str:
320
+ """
321
+ Transcribe audio với Whisper local
322
+ """
323
+ if audio_path is None or not os.path.exists(audio_path):
324
+ print(">>> Kein Audio gefunden.")
325
  return ""
326
+
327
+ try:
328
+ # Đọc audio file
329
+ data, sr = sf.read(audio_path, always_2d=False)
330
+
331
+ if data is None or data.size == 0:
332
+ print(">>> Audio leer.")
333
+ return ""
334
+
335
+ # Chuyển sang mono nếu cần
336
+ if len(data.shape) > 1:
337
+ data = np.mean(data, axis=1)
338
+
339
+ # Tiền xử lý audio
340
+ data = normalize_audio(data)
341
+
342
+ # Resample về 16kHz nếu cần
343
+ TARGET_SR = 16000
344
+ if sr != TARGET_SR:
345
+ data = resample_audio(data, sr, TARGET_SR)
346
+ sr = TARGET_SR
347
+
348
+ # Lọc noise
349
+ try:
350
+ data = butter_highpass_filter(data, cutoff=60, fs=sr)
351
+ except:
352
+ pass
353
+
354
+ # Kiểm tra audio quality
355
+ duration_s = len(data) / sr
356
+ rms = float(np.sqrt(np.mean(data ** 2)))
357
+ peak = float(np.max(np.abs(data)))
358
+
359
+ print(f">>> Audio stats – Dauer: {duration_s:.2f}s, RMS: {rms:.6f}, Peak: {peak:.6f}")
360
+
361
+ # Kiểm tra điều kiện tối thiểu
362
+ if duration_s < 0.3 or rms < 3e-4 or peak < 8e-4:
363
+ print(">>> Audio zu kurz oder zu leise.")
364
+ return ""
365
+
366
+ # Giới hạn độ dài
367
+ MAX_SAMPLES = sr * max_duration_s
368
+ if len(data) > MAX_SAMPLES:
369
+ data = data[:MAX_SAMPLES]
370
+ print(f">>> Audio auf {max_duration_s}s gekürzt.")
371
+
372
+ # Chọn model
373
+ if model_size is None:
374
+ model_size = WHISPER_MODEL
375
+
376
+ asr = get_asr_pipeline(model_size)
377
+
378
+ # Cấu hình transcribe
379
+ lang = language
380
+ if not lang and ASR_DEFAULT_LANGUAGE and ASR_DEFAULT_LANGUAGE.lower() != "auto":
381
+ lang = ASR_DEFAULT_LANGUAGE
382
+ if isinstance(lang, str) and lang.lower() == "auto":
383
+ lang = None
384
+
385
+ call_kwargs = {}
386
+
387
+ # Dynamic token budget based on audio length
388
+ token_budget = min(120, int(duration_s * 20))
389
+ if duration_s < 2.0:
390
+ token_budget = 60
391
+ if duration_s < 1.0:
392
+ token_budget = 36
393
+
394
+ if lang:
395
+ call_kwargs["generate_kwargs"] = {
396
+ "language": lang,
397
+ "task": "transcribe",
398
+ "max_new_tokens": token_budget,
399
+ "temperature": 0.0,
400
+ "num_beams": 1,
401
+ "compression_ratio_threshold": 2.4,
402
+ "logprob_threshold": -1.0,
403
+ "no_speech_threshold": 0.6,
404
+ "no_repeat_ngram_size": 3,
405
+ }
406
+
407
+ print(f">>> Transkribiere mit Whisper-{model_size}...")
408
+ start_time = time.time()
409
+
410
+ result = asr({"array": data, "sampling_rate": sr}, **call_kwargs)
411
+
412
+ processing_time = time.time() - start_time
413
+
414
+ text = result.get("text", "") if isinstance(result, dict) else str(result)
415
+ text = text.strip()
416
+
417
+ # Sửa lỗi domain terms
418
+ text = fix_domain_terms(text)
419
+
420
+ print(f">>> Transkription ({processing_time:.2f}s): {text}")
421
+ return text
422
+
423
+ except Exception as e:
424
+ print(f">>> Transkriptionsfehler: {e}")
425
  return ""
426
 
427
+ # ========================================================
428
+ # TEXT-TO-SPEECH (TTS)
429
+ # ========================================================
430
+ def synthesize_speech(text: str) -> Optional[Tuple[int, np.ndarray]]:
431
+ """
432
+ Chuyển text sang speech
433
+ """
434
+ if not text or not text.strip() or not TTS_ENABLED:
435
+ return None
436
+
437
+ try:
438
+ tts = get_tts_pipeline()
439
+
440
+ # TTS inference
441
+ out = tts(text)
442
+
443
+ # Extract audio data
444
+ audio = np.array(out["audio"], dtype=np.float32)
445
+ sr = out.get("sampling_rate", 16000)
446
+
447
+ # Ensure valid sample rate
448
+ if sr is None or sr <= 0 or sr > 65535:
449
+ sr = 16000
450
+
451
+ # Ensure mono
452
+ if audio.ndim > 1:
453
+ audio = audio.squeeze()
454
+ if audio.ndim > 1:
455
+ audio = audio[:, 0]
456
+
457
+ # Apply processing
458
+ try:
459
+ audio = butter_highpass_filter(audio, cutoff=60, fs=sr)
460
+ except:
461
+ pass
462
+
463
+ # Normalize
464
+ max_val = np.max(np.abs(audio))
465
+ if max_val > 0:
466
+ audio = audio / max_val
467
+
468
+ # Apply fade
469
+ audio = apply_fade(audio, sr)
470
+
471
+ # Convert to int16
472
+ audio_int16 = np.clip(audio * 32767, -32768, 32767).astype(np.int16)
473
+
474
+ return (sr, audio_int16)
475
+
476
+ except Exception as e:
477
+ print(f">>> TTS Fehler: {e}")
478
+ return None
479
 
480
+ # ========================================================
481
+ # DOMAIN-SPECIFIC TEXT PROCESSING
482
+ # ========================================================
483
+ def fix_domain_terms(text: str) -> str:
484
+ """
485
+ Sửa lỗi các thuật ngữ chuyên ngành
486
+ """
487
+ if not text:
488
+ return text
489
+
490
+ # Common mis-transcriptions in German academic/legal context
491
+ correction_pairs = [
492
+ (r"\bbriefe\s*um\b", "prüfung"),
493
+ (r"\bbrieft\s*um\b", "prüfung"),
494
+ (r"\bbriefung\b", "prüfung"),
495
+ (r"\bpruefung\b", "prüfung"),
496
+ (r"\bhochschule\s*gesetz\b", "hochschulgesetz"),
497
+ (r"\bmodule\b", "modul"),
498
+ (r"\bklausuren\b", "klausur"),
499
+ (r"\bimmatrikulations\b", "immatrikulation"),
500
+ (r"\bexmatrikulations\b", "exmatrikulation"),
501
  ]
502
+
503
+ for pattern, replacement in correction_pairs:
504
+ text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
505
+
506
+ # Vocabulary matching for domain terms
507
+ domain_vocabulary = [
508
+ "prüfung", "prüfungsordnung", "hochschulgesetz", "modul", "klausur",
509
+ "immatrikulation", "exmatrikulation", "anmeldung", "wiederholung",
510
+ "noten", "semester", "vorlesung", "übung", "praktikum",
511
+ "bachelor", "master", "promotion", "habilitation"
512
+ ]
513
+
514
+ # Simple word-by-word correction
515
  tokens = text.split()
516
+ corrected_tokens = []
517
+
518
+ for token in tokens:
519
+ # Check if token is likely a domain term
520
+ if len(token) > 3: # Only check longer tokens
521
+ matches = difflib.get_close_matches(
522
+ token.lower(),
523
+ domain_vocabulary,
524
+ n=1,
525
+ cutoff=0.8
526
+ )
527
+ if matches:
528
+ corrected_tokens.append(matches[0])
529
+ else:
530
+ corrected_tokens.append(token)
531
+ else:
532
+ corrected_tokens.append(token)
533
+
534
+ return " ".join(corrected_tokens)
535
+
536
+ # ========================================================
537
+ # BENCHMARKING UTILITIES
538
+ # ========================================================
539
+ def benchmark_transcription(
540
+ audio_path: str,
541
+ models: list = ["tiny", "base", "small", "medium"]
542
+ ) -> Dict[str, Dict[str, Any]]:
543
+ """
544
+ Benchmark các model Whisper khác nhau
545
+ """
546
+ results = {}
547
+
548
+ for model_size in models:
549
  try:
550
+ print(f"\n>>> Benchmarking Whisper-{model_size}...")
551
+
552
+ start_time = time.time()
553
+ text = transcribe_audio(audio_path, model_size=model_size)
554
+ processing_time = time.time() - start_time
555
+
556
+ # Đánh giá chất lượng (đơn giản)
557
+ quality_score = estimate_transcription_quality(text)
558
+
559
+ results[model_size] = {
560
+ "text": text,
561
+ "time": processing_time,
562
+ "quality_score": quality_score,
563
+ "word_count": len(text.split()),
564
+ "chars_per_second": len(text) / max(processing_time, 0.001)
565
+ }
566
+
567
+ print(f" Time: {processing_time:.2f}s, Quality: {quality_score:.2f}")
568
+
569
  except Exception as e:
570
+ print(f" Error: {e}")
571
+ results[model_size] = {"error": str(e)}
572
+
573
+ return results
574
+
575
+ def estimate_transcription_quality(text: str) -> float:
576
+ """
577
+ Ước tính chất lượng transcription dựa trên các heuristic
578
+ """
579
+ if not text:
580
+ return 0.0
581
+
582
+ score = 0.0
583
+
584
+ # Length-based score
585
+ word_count = len(text.split())
586
+ if word_count > 3:
587
+ score += 0.3
588
+
589
+ # Domain terms presence
590
+ domain_terms = ["prüfung", "hochschul", "gesetz", "ordnung", "modul"]
591
+ found_terms = sum(1 for term in domain_terms if term in text.lower())
592
+ score += min(0.3, found_terms * 0.1)
593
+
594
+ # Grammar/syntax indicators (German)
595
+ # Check for capital nouns, common sentence endings
596
+ if any(marker in text for marker in [". ", "? ", "! ", ", "]):
597
+ score += 0.2
598
+
599
+ # Word length consistency
600
+ words = text.split()
601
+ avg_word_len = np.mean([len(w) for w in words]) if words else 0
602
+ if 4 <= avg_word_len <= 10:
603
+ score += 0.2
604
+
605
+ return min(1.0, score)
606
+
607
+ # ========================================================
608
+ # MAIN EXPORT
609
+ # ========================================================
610
+ __all__ = [
611
+ 'transcribe_audio',
612
+ 'transcribe_with_groq',
613
+ 'synthesize_speech',
614
+ 'detect_voice_activity',
615
+ 'benchmark_transcription',
616
+ 'fix_domain_terms',
617
+ 'TranscriptionResult',
618
+ 'VADResult'
619
+ ]
620
+
621
+ if __name__ == "__main__":
622
+ # Test functionality
623
+ print("Speech IO Module - Enhanced Version")
624
+ print(f"Whisper Model: {WHISPER_MODEL}")
625
+ print(f"Groq Enabled: {USE_GROQ}")
626
+ print(f"VAD Enabled: {ENABLE_VAD}")