Nguyen5 commited on
Commit
6bb0f73
·
1 Parent(s): cedda96
Files changed (2) hide show
  1. app.py +204 -114
  2. speech_io.py +239 -43
app.py CHANGED
@@ -1,7 +1,11 @@
1
- # app.py – Prüfungsrechts-Chatbot (RAG + Sprache, UI kiểu ChatGPT)
2
- #
 
 
3
  import os
4
- from dataclasses import dataclass
 
 
5
  import gradio as gr
6
  from gradio_pdf import PDF
7
 
@@ -13,7 +17,7 @@ from llm import load_llm
13
  from rag_pipeline import answer
14
  from speech_io import transcribe_audio, synthesize_speech
15
 
16
- ASR_LANGUAGE_HINT = os.getenv("ASR_LANGUAGE", "de") # set to "auto" for detection, or e.g. "en"
17
 
18
  # =====================================================
19
  # INITIALISIERUNG (global)
@@ -39,6 +43,7 @@ pdf_meta = next(d.metadata for d in docs if d.metadata.get("type") == "pdf")
39
  hg_meta = next(d.metadata for d in docs if d.metadata.get("type") == "hg")
40
  hg_url = hg_meta.get("viewer_url")
41
 
 
42
  # =====================================================
43
  # Quellen formatieren – Markdown für Chat
44
  # =====================================================
@@ -56,144 +61,221 @@ def format_sources(src):
56
 
57
  return "\n".join(out)
58
 
 
59
  # =====================================================
60
- # CORE CHAT-FUNKTION (Text + separates Mikro-Audio)
61
  # =====================================================
62
  @dataclass
63
  class AppState:
64
- conversation: list
65
- recording_state: str
66
- mode: str
67
- last_record_path: str | None
68
- status_text: str
69
 
70
- def chat_fn(text_input, audio_path, history, state: AppState, lang_sel):
 
 
 
 
 
 
 
 
71
  """
72
- text_input: Textbox-Inhalt (str)
73
- audio_path: Pfad zu WAV/FLAC vom Mikro (gr.Audio, type="filepath")
74
- history: Liste von OpenAI-ähnlichen Messages (role, content)
 
75
  """
76
  text = (text_input or "").strip()
77
 
78
- if (not text) and audio_path:
79
- state.recording_state = "processing"
80
- state.last_record_path = audio_path
81
- spoken = transcribe_audio(audio_path, language=lang_sel)
82
- text = spoken.strip()
83
- state.status_text = "✅ Verarbeitung abgeschlossen"
84
 
85
  if not text:
86
- # Nichts zu tun
87
- return history, "", None, "", "Bereit"
88
 
89
- # 2) RAG-Antwort berechnen
90
  ans, sources = answer(text, retriever, llm)
91
  bot_msg = ans + format_sources(sources)
92
 
93
- # 3) History aktualisieren (ChatGPT-Style)
94
- history = history + [
95
- {"role": "user", "content": text},
96
- {"role": "assistant", "content": bot_msg},
97
- ]
 
 
 
 
 
 
 
 
 
 
98
 
99
- state.conversation = history
100
- state.status_text = "Bereit"
101
- return history, "", None, text, state.status_text
102
 
103
  # =====================================================
104
  # LAST ANSWER → TTS (für Button "Antwort erneut vorlesen")
105
  # =====================================================
106
- def read_last_answer(history):
107
  if not history:
108
  return None
 
 
 
 
 
109
 
110
- for msg in reversed(history):
111
- if msg.get("role") == "assistant":
112
- return synthesize_speech(msg.get("content", ""))
113
-
114
- return None
115
 
116
  # =====================================================
117
- # UIGRADIO
118
  # =====================================================
119
- with gr.Blocks(title="Prüfungsrechts-Chatbot (RAG + Sprache)") as demo:
120
- # Leichtes Styling: zentriert, schmale Breite, kompakte Input-Zeile (ChatGPT-like)
121
- gr.HTML(
122
- """
123
- <style>
124
- html, body {height: auto !important; overflow-y: auto !important;}
125
- .gradio-container {max-width: 960px; margin: 0 auto; padding: 12px;}
126
- #chat-wrap {position: relative;}
127
- #chat-input-row {transform: translateY(-28px); margin-bottom: -28px;}
128
-
129
- /* ChatGPT-like Bottom Bar */
130
- #chat-input-row {
131
- align-items: center;
132
- gap: 8px;
133
- padding: 8px 12px;
134
- border: 1px solid rgba(0,0,0,0.08);
135
- border-radius: 9999px;
136
- background: var(--background-primary);
137
- box-shadow: 0 2px 6px rgba(0,0,0,0.06);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- /* Textbox inside pill */
141
- #chat-textbox textarea {
142
- min-height: 42px;
143
- max-height: 120px;
144
- border: none !important;
145
- background: transparent !important;
146
- box-shadow: none !important;
147
- resize: none;
148
- padding-left: 0;
149
- }
150
 
151
- /* Icon buttons (plus, mic, send) */
152
- .icon-btn, .compact-btn {
153
- width: 32px;
154
- height: 32px;
155
- border-radius: 9999px !important;
156
- display: inline-flex;
157
- align-items: center;
158
- justify-content: center;
159
- border: 1px solid rgba(0,0,0,0.08) !important;
160
- background: #f7f7f8 !important;
161
- box-shadow: none !important;
162
- }
163
- .send-btn {
164
- background: #111 !important;
165
- color: #fff !important;
166
- border-color: #111 !important;
167
- }
168
- /* Make audio mic compact and borderless */
169
- #chat-audio {min-width: 32px; border: none !important; background: transparent !important;}
170
- #chat-audio .wrap, #chat-audio .audio-wrap, #chat-audio .audio-controls {max-width: 32px;}
171
- #chat-textbox textarea {border: none !important; outline: none !important;}
172
- @media (max-width: 768px) { #chat-input-row {transform: none; margin-bottom: 0;} }
173
- </style>
174
- """
175
- )
176
  gr.Markdown("# 🧑‍⚖️ Prüfungsrechts-Chatbot")
177
  gr.Markdown(
178
  "Dieser Chatbot beantwortet Fragen **ausschließlich** aus der "
179
  "Prüfungsordnung (PDF) und dem Hochschulgesetz NRW. "
180
- "Du kannst Text eingeben oder direkt ins Mikrofon sprechen."
 
181
  )
182
 
183
- # Einspaltiges Layout, alles untereinander (verhindert abgeschnittene Bereiche)
184
  with gr.Column(elem_id="chat-wrap"):
185
  chatbot = gr.Chatbot(
186
  label="Chat",
187
- height=280,
188
  )
189
 
 
 
 
190
  # Eingabezeile à la ChatGPT: Plus + Text + Mikro + Senden
191
  with gr.Row(elem_id="chat-input-row"):
192
- attach_btn = gr.UploadButton("+", file_types=["file"], file_count="multiple", elem_classes=["icon-btn"], scale=1)
 
 
 
 
 
 
193
  chat_text = gr.Textbox(
194
  elem_id="chat-textbox",
195
  label=None,
196
- placeholder="Stelle irgendeine Frage",
197
  lines=1,
198
  max_lines=6,
199
  autofocus=True,
@@ -205,35 +287,43 @@ with gr.Blocks(title="Prüfungsrechts-Chatbot (RAG + Sprache)") as demo:
205
  sources=["microphone"],
206
  type="filepath",
207
  format="wav",
208
- streaming=True,
209
  interactive=True,
210
  scale=1,
211
  show_label=False,
212
  )
213
- send_btn = gr.Button("➤", elem_classes=["compact-btn", "send-btn"], scale=1)
 
 
 
 
214
 
215
- # Senden bei Enter
216
  chat_text.submit(
217
  chat_fn,
218
- [chat_text, chat_audio, chatbot],
219
- [chatbot, chat_text, chat_audio],
220
  )
221
- def transcribe_to_textbox(audio_path):
222
- return transcribe_audio(audio_path, language=ASR_LANGUAGE_HINT)
223
  chat_audio.change(
224
- transcribe_to_textbox,
225
- [chat_audio],
226
- [chat_text],
227
- )
228
- chat_audio.stream(
229
- transcribe_to_textbox,
230
- [chat_audio],
231
- [chat_text],
232
  )
 
233
  send_btn.click(
234
  chat_fn,
235
- [chat_text, chat_audio, chatbot],
236
- [chatbot, chat_text, chat_audio],
 
 
 
 
 
 
 
 
 
237
  )
238
 
239
  # Quellen & Dokumente kompakt unterhalb
 
1
+ # app.py – Prüfungsrechts-Chatbot (RAG + Sprache, UI kiểu ChatGPT + VAD)
2
+
3
+ from __future__ import annotations
4
+
5
  import os
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, List
8
+
9
  import gradio as gr
10
  from gradio_pdf import PDF
11
 
 
17
  from rag_pipeline import answer
18
  from speech_io import transcribe_audio, synthesize_speech
19
 
20
+ ASR_LANGUAGE_HINT = os.getenv("ASR_LANGUAGE", "de") # "auto" = Auto-Detect
21
 
22
  # =====================================================
23
  # INITIALISIERUNG (global)
 
43
  hg_meta = next(d.metadata for d in docs if d.metadata.get("type") == "hg")
44
  hg_url = hg_meta.get("viewer_url")
45
 
46
+
47
  # =====================================================
48
  # Quellen formatieren – Markdown für Chat
49
  # =====================================================
 
61
 
62
  return "\n".join(out)
63
 
64
+
65
  # =====================================================
66
+ # State Management (wie Gradio Guide)
67
  # =====================================================
68
  @dataclass
69
  class AppState:
70
+ conversation: list = field(default_factory=list) # LLM-History (role/content)
71
+ stopped: bool = False
72
+ model_outs: Any = None
73
+
 
74
 
75
+ # =====================================================
76
+ # CORE CHAT-FUNKTION (Text + Mikro)
77
+ # =====================================================
78
+ def chat_fn(
79
+ text_input: str,
80
+ audio_path: str,
81
+ history: List[List[str]],
82
+ state: AppState,
83
+ ):
84
  """
85
+ text_input: Textbox-Inhalt
86
+ audio_path: Pfad zur Audiodatei aus gr.Audio (type="filepath")
87
+ history: Chatbot-Verlauf [[user, bot], ...]
88
+ state: AppState (Gradio State)
89
  """
90
  text = (text_input or "").strip()
91
 
92
+ # Nur Audio erst transkribieren
93
+ if audio_path and not text:
94
+ spoken = transcribe_audio(audio_path, language=ASR_LANGUAGE_HINT)
95
+ text = (spoken or "").strip()
 
 
96
 
97
  if not text:
98
+ # nichts zu tun
99
+ return history, state, "", None
100
 
101
+ # RAG-Antwort
102
  ans, sources = answer(text, retriever, llm)
103
  bot_msg = ans + format_sources(sources)
104
 
105
+ # State.conversation im LLM-Format (für spätere Erweiterungen)
106
+ state.conversation.append({"role": "user", "content": text})
107
+ state.conversation.append({"role": "assistant", "content": bot_msg})
108
+
109
+ # Chatbot-History im klassischen Gradio-Format
110
+ if history is None:
111
+ history = []
112
+ history = history + [[text, bot_msg]]
113
+
114
+ # Optional: hier könnte synthesize_speech(bot_msg) aufgerufen werden,
115
+ # wenn du die Antwort automatisch vorlesen lassen willst.
116
+ # tts_audio = synthesize_speech(bot_msg)
117
+
118
+ # Text- und Audioeingabe leeren
119
+ return history, state, "", None
120
 
 
 
 
121
 
122
  # =====================================================
123
  # LAST ANSWER → TTS (für Button "Antwort erneut vorlesen")
124
  # =====================================================
125
+ def read_last_answer(history: List[List[str]]):
126
  if not history:
127
  return None
128
+ last_pair = history[-1]
129
+ if len(last_pair) < 2:
130
+ return None
131
+ bot_text = last_pair[1]
132
+ return synthesize_speech(bot_text)
133
 
 
 
 
 
 
134
 
135
  # =====================================================
136
+ # CSS + JS (VAD) nach Gradio Guide adaptiert
137
  # =====================================================
138
+ CUSTOM_STYLE_AND_VAD = """
139
+ <style>
140
+ html, body {height: auto !important; overflow-y: auto !important;}
141
+ .gradio-container {max-width: 960px; margin: 0 auto; padding: 12px;}
142
+ #chat-wrap {position: relative;}
143
+ #chat-input-row {transform: translateY(-28px); margin-bottom: -28px;}
144
+
145
+ /* ChatGPT-like Bottom Bar */
146
+ #chat-input-row {
147
+ align-items: center;
148
+ gap: 8px;
149
+ padding: 8px 12px;
150
+ border: 1px solid rgba(0,0,0,0.08);
151
+ border-radius: 9999px;
152
+ background: var(--background-primary);
153
+ box-shadow: 0 2px 6px rgba(0,0,0,0.06);
154
+ }
155
+
156
+ /* Textbox inside pill */
157
+ #chat-textbox textarea {
158
+ min-height: 42px;
159
+ max-height: 120px;
160
+ border: none !important;
161
+ background: transparent !important;
162
+ box-shadow: none !important;
163
+ resize: none;
164
+ padding-left: 0;
165
+ }
166
+
167
+ /* Icon buttons (plus, mic, send) */
168
+ .icon-btn, .compact-btn {
169
+ width: 32px;
170
+ height: 32px;
171
+ border-radius: 9999px !important;
172
+ display: inline-flex;
173
+ align-items: center;
174
+ justify-content: center;
175
+ border: 1px solid rgba(0,0,0,0.08) !important;
176
+ background: #f7f7f8 !important;
177
+ box-shadow: none !important;
178
+ }
179
+ .send-btn {
180
+ background: #111 !important;
181
+ color: #fff !important;
182
+ border-color: #111 !important;
183
+ }
184
+
185
+ /* Make audio mic compact and borderless */
186
+ #chat-audio {min-width: 32px; border: none !important; background: transparent !important;}
187
+ #chat-audio .wrap, #chat-audio .audio-wrap, #chat-audio .audio-controls {max-width: 32px;}
188
+ #chat-textbox textarea {border: none !important; outline: none !important;}
189
+ @media (max-width: 768px) { #chat-input-row {transform: none; margin-bottom: 0;} }
190
+ </style>
191
+
192
+ <script>
193
+ /*
194
+ * Voice Activity Detection (VAD) nach Gradio Guide:
195
+ * Nutzt @ricky0123/vad-web, um automatisch auf die
196
+ * .record-button / .stop-button der Audio-Komponente zu klicken.
197
+ */
198
+ async function init_vad() {
199
+ try {
200
+ const script1 = document.createElement("script");
201
+ script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.14.0/dist/ort.js";
202
+ document.head.appendChild(script1);
203
+
204
+ const script2 = document.createElement("script");
205
+ script2.onload = async () => {
206
+ console.log("VAD JS geladen");
207
+ const recordButton = document.querySelector('.record-button');
208
+ if (recordButton) {
209
+ recordButton.textContent = "Just start talking";
210
+ }
211
+ const myvad = await vad.MicVAD.new({
212
+ onSpeechStart: () => {
213
+ const record = document.querySelector('.record-button');
214
+ const player = document.querySelector('#streaming-out');
215
+ if (record && (!player || player.paused)) {
216
+ console.log("VAD: speech start → record.click()");
217
+ record.click();
218
+ }
219
+ },
220
+ onSpeechEnd: (audio) => {
221
+ const stop = document.querySelector('.stop-button');
222
+ if (stop) {
223
+ console.log("VAD: speech end → stop.click()");
224
+ stop.click();
225
+ }
226
  }
227
+ });
228
+ myvad.start();
229
+ };
230
+ script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.7/dist/bundle.min.js";
231
+ document.head.appendChild(script2);
232
+ } catch (e) {
233
+ console.log("VAD init Fehler:", e);
234
+ }
235
+ }
236
+ if (typeof window !== "undefined") {
237
+ window.addEventListener("load", init_vad);
238
+ }
239
+ </script>
240
+ """
241
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ # =====================================================
244
+ # UI – GRADIO (ChatGPT-artig + VAD)
245
+ # =====================================================
246
+ with gr.Blocks(title="Prüfungsrechts-Chatbot (RAG + Sprache)") as demo:
247
+ gr.HTML(CUSTOM_STYLE_AND_VAD)
248
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  gr.Markdown("# 🧑‍⚖️ Prüfungsrechts-Chatbot")
250
  gr.Markdown(
251
  "Dieser Chatbot beantwortet Fragen **ausschließlich** aus der "
252
  "Prüfungsordnung (PDF) und dem Hochschulgesetz NRW. "
253
+ "Du kannst Text eingeben oder einfach anfangen zu sprechen"
254
+ "die Aufnahme startet/stopt automatisch (Voice Activity Detection)."
255
  )
256
 
 
257
  with gr.Column(elem_id="chat-wrap"):
258
  chatbot = gr.Chatbot(
259
  label="Chat",
260
+ height=380,
261
  )
262
 
263
+ # globaler State für Konversation usw.
264
+ state = gr.State(value=AppState())
265
+
266
  # Eingabezeile à la ChatGPT: Plus + Text + Mikro + Senden
267
  with gr.Row(elem_id="chat-input-row"):
268
+ attach_btn = gr.UploadButton(
269
+ "+",
270
+ file_types=["file"],
271
+ file_count="multiple",
272
+ elem_classes=["icon-btn"],
273
+ scale=1,
274
+ )
275
  chat_text = gr.Textbox(
276
  elem_id="chat-textbox",
277
  label=None,
278
+ placeholder="Stelle irgendeine Frage oder sprich einfach los …",
279
  lines=1,
280
  max_lines=6,
281
  autofocus=True,
 
287
  sources=["microphone"],
288
  type="filepath",
289
  format="wav",
290
+ streaming=False, # wichtig: record/stop Buttons für VAD
291
  interactive=True,
292
  scale=1,
293
  show_label=False,
294
  )
295
+ send_btn = gr.Button(
296
+ "➤",
297
+ elem_classes=["compact-btn", "send-btn"],
298
+ scale=1,
299
+ )
300
 
301
+ # Senden bei Enter (Text)
302
  chat_text.submit(
303
  chat_fn,
304
+ [chat_text, chat_audio, chatbot, state],
305
+ [chatbot, state, chat_text, chat_audio],
306
  )
307
+ # Audio-Stop (manuell oder durch VAD) → ganze Pipeline
 
308
  chat_audio.change(
309
+ chat_fn,
310
+ [chat_text, chat_audio, chatbot, state],
311
+ [chatbot, state, chat_text, chat_audio],
 
 
 
 
 
312
  )
313
+ # Senden-Button
314
  send_btn.click(
315
  chat_fn,
316
+ [chat_text, chat_audio, chatbot, state],
317
+ [chatbot, state, chat_text, chat_audio],
318
+ )
319
+
320
+ # Optional: Button "Antwort erneut vorlesen"
321
+ voice_out = gr.Audio(label="Vorgelesene Antwort", type="numpy", interactive=False, elem_id="streaming-out")
322
+ read_btn = gr.Button("🔁 Antwort erneut vorlesen")
323
+ read_btn.click(
324
+ read_last_answer,
325
+ [chatbot],
326
+ [voice_out],
327
  )
328
 
329
  # Quellen & Dokumente kompakt unterhalb
speech_io.py CHANGED
@@ -2,14 +2,20 @@
2
  speech_io.py
3
 
4
  Sprachbasierte Ein-/Ausgabe:
5
- - Speech-to-Text (STT) mit Whisper (transformers.pipeline)
 
 
6
  - Text-to-Speech (TTS) mit MMS-TTS Deutsch
7
 
8
- Dieses File ist 100% stabil für HuggingFace Spaces.
9
  """
10
 
 
 
11
  import os
12
- from typing import Optional, Tuple
 
 
13
  import numpy as np
14
  import soundfile as sf
15
  from scipy.signal import butter, filtfilt, resample
@@ -17,26 +23,43 @@ from transformers import pipeline
17
  import re
18
  import difflib
19
 
20
- # Modelle (per ENV konfigurierbar)
21
- # Standard jetzt "tiny" für schnellere Verarbeitung; setze ASR_MODEL_ID env auf small/base/medium für mehr Genauigkeit
 
 
 
 
 
 
 
 
 
22
  ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny")
 
 
 
 
 
23
  TTS_MODEL_ID = os.getenv("TTS_MODEL_ID", "facebook/mms-tts-deu")
24
- ASR_DEFAULT_LANGUAGE = os.getenv("ASR_LANGUAGE", "de") # "auto" um Auto-Detect zu erzwingen
25
  TTS_ENABLED = os.getenv("TTS_ENABLED", "1").lower() not in ("0", "false", "no")
26
  ASR_PROMPT = os.getenv("ASR_PROMPT", "Dies ist ein Diktat in deutscher Sprache.")
27
  ASR_MAX_DURATION_S = int(os.getenv("ASR_MAX_DURATION_S", "30"))
28
 
29
  _asr = None
30
  _tts = None
 
 
31
 
32
  # ========================================================
33
- # STT PIPELINE
34
  # ========================================================
35
 
36
  def get_asr_pipeline():
 
37
  global _asr
38
  if _asr is None:
39
- print(f">>> Lade ASR Modell: {ASR_MODEL_ID}")
40
  _asr = pipeline(
41
  task="automatic-speech-recognition",
42
  model=ASR_MODEL_ID,
@@ -47,9 +70,6 @@ def get_asr_pipeline():
47
  )
48
  return _asr
49
 
50
- # ========================================================
51
- # TTS PIPELINE
52
- # ========================================================
53
 
54
  def get_tts_pipeline():
55
  global _tts
@@ -61,8 +81,26 @@ def get_tts_pipeline():
61
  )
62
  return _tts
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # ========================================================
65
- # AUDIO FILTER – Noise Reduction + Highpass
66
  # ========================================================
67
 
68
  def butter_highpass_filter(data, cutoff=60, fs=16000, order=4):
@@ -71,6 +109,7 @@ def butter_highpass_filter(data, cutoff=60, fs=16000, order=4):
71
  b, a = butter(order, norm_cutoff, btype="high")
72
  return filtfilt(b, a, data)
73
 
 
74
  def apply_fade(audio, sr, duration_ms=10):
75
  fade_samples = int(sr * duration_ms / 1000)
76
 
@@ -85,22 +124,90 @@ def apply_fade(audio, sr, duration_ms=10):
85
 
86
  return audio
87
 
 
88
  # ========================================================
89
- # SPEECH-TO-TEXT (STT)
90
  # ========================================================
91
 
92
- def transcribe_audio(audio_path: str, language: Optional[str] = None, max_duration_s: int = ASR_MAX_DURATION_S) -> str:
93
  """
94
- audio_path: path zu WAV-Datei (von gr.Audio type="filepath")
95
- language: Optional Sprachcode (z.B. "de", "en"); None = Auto-Detect
96
- max_duration_s: begrenzt Audiolänge für schnellere Verarbeitung
97
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
 
 
 
 
 
99
  if audio_path is None or not os.path.exists(audio_path):
100
  print(">>> Kein Audio gefunden.")
101
  return ""
102
 
103
- # WAV einlesen (soundfile garantiert PCM korrekt)
104
  data, sr = sf.read(audio_path, always_2d=False)
105
 
106
  if data is None or data.size == 0:
@@ -114,19 +221,18 @@ def transcribe_audio(audio_path: str, language: Optional[str] = None, max_durati
114
  data = np.clip(data, -1.0, 1.0)
115
  try:
116
  data = butter_highpass_filter(data, cutoff=60, fs=sr)
117
- except:
118
  pass
 
119
  m = np.max(np.abs(data))
120
  if m > 0:
121
  data = data / m
122
 
123
- # sehr leise Aufnahmen filtern, damit nicht nur Rauschen erkannt wird
124
  rms = float(np.sqrt(np.mean(data ** 2)))
125
  if rms < 5e-5:
126
  print(">>> Audio zu leise, breche ab.")
127
  return ""
128
 
129
- # bei zu hoher Samplingrate auf 16 kHz runterskalieren (schneller, kleiner)
130
  TARGET_SR = 16000
131
  if sr != TARGET_SR:
132
  target_len = int(len(data) * TARGET_SR / sr)
@@ -135,15 +241,16 @@ def transcribe_audio(audio_path: str, language: Optional[str] = None, max_durati
135
 
136
  idx = np.where(np.abs(data) > 0.02)[0]
137
  if idx.size:
138
- data = data[idx[0]:idx[-1]+1]
139
 
140
- # Stats zum Debuggen
141
  duration_s = len(data) / sr if sr else 0
142
  rms = float(np.sqrt(np.mean(data ** 2)))
143
  peak = float(np.max(np.abs(data))) if data.size else 0.0
144
- print(f">>> Audio stats – sr: {sr}, len: {len(data)}, dur: {duration_s:.2f}s, rms: {rms:.6f}, peak: {peak:.6f}")
 
 
 
145
 
146
- # sehr leise / sehr kurze Aufnahmen filtern, damit nicht nur Rauschen erkannt wird
147
  if duration_s < 0.3:
148
  print(">>> Audio zu kurz, breche ab.")
149
  return ""
@@ -151,22 +258,20 @@ def transcribe_audio(audio_path: str, language: Optional[str] = None, max_durati
151
  print(">>> Audio zu leise, breche ab.")
152
  return ""
153
 
154
- # Whisper > max_duration_s vermeiden
155
  MAX_SAMPLES = sr * max_duration_s
156
  if len(data) > MAX_SAMPLES:
157
  data = data[:MAX_SAMPLES]
158
 
159
  asr = get_asr_pipeline()
160
 
161
- print(">>> Transkribiere Audio...")
162
  lang = language
163
  if not lang and ASR_DEFAULT_LANGUAGE and ASR_DEFAULT_LANGUAGE.lower() != "auto":
164
  lang = ASR_DEFAULT_LANGUAGE
165
  if isinstance(lang, str) and lang.lower() == "auto":
166
  lang = None
167
 
168
- call_kwargs = {}
169
- # Dynamische Dekodier-Settings: kurze Clips -> kleinere Token-Budgets gegen Halluzination
170
  token_budget = 120
171
  if duration_s < 2.0:
172
  token_budget = 60
@@ -184,6 +289,7 @@ def transcribe_audio(audio_path: str, language: Optional[str] = None, max_durati
184
  "logprob_threshold": -1.0,
185
  "no_speech_threshold": 0.6,
186
  "no_repeat_ngram_size": 3,
 
187
  }
188
 
189
  result = asr({"array": data, "sampling_rate": sr}, **call_kwargs)
@@ -193,6 +299,7 @@ def transcribe_audio(audio_path: str, language: Optional[str] = None, max_durati
193
 
194
  text = result.get("text", "") if isinstance(result, dict) else str(result)
195
  text = text.strip()
 
196
  def _fix_domain_terms(s: str) -> str:
197
  pairs = [
198
  (r"\bbriefe\s*um\b", "prüfung"),
@@ -203,8 +310,15 @@ def transcribe_audio(audio_path: str, language: Optional[str] = None, max_durati
203
  for pat, rep in pairs:
204
  s = re.sub(pat, rep, s, flags=re.IGNORECASE)
205
  vocab = [
206
- "prüfung","prüfungsordnung","hochschulgesetz","modul","klausur",
207
- "immatrikulation","exmatrikulation","anmeldung","wiederholung"
 
 
 
 
 
 
 
208
  ]
209
  tokens = s.split()
210
  fixed = []
@@ -212,12 +326,42 @@ def transcribe_audio(audio_path: str, language: Optional[str] = None, max_durati
212
  cand = difflib.get_close_matches(t.lower(), vocab, n=1, cutoff=0.82)
213
  fixed.append(cand[0] if cand else t)
214
  return " ".join(fixed)
 
215
  text = _fix_domain_terms(text)
216
- print("ASR:", text)
217
  return text
218
 
 
219
  # ========================================================
220
- # TEXT-TO-SPEECH (TTS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  # ========================================================
222
 
223
  def synthesize_speech(text: str):
@@ -227,36 +371,88 @@ def synthesize_speech(text: str):
227
  tts = get_tts_pipeline()
228
  out = tts(text)
229
 
230
- # rohes Audio from MMS (float32 [-1, 1])
231
  audio = np.array(out["audio"], dtype=np.float32)
232
  sr = out.get("sampling_rate", 16000)
233
 
234
- # ===== FIX sample_rate =====
235
  if sr is None or sr <= 0 or sr > 65535:
236
  sr = 16000
237
 
238
- # ===== Mono erzwingen =====
239
  if audio.ndim > 1:
240
  audio = audio.squeeze()
241
  if audio.ndim > 1:
242
  audio = audio[:, 0]
243
 
244
- # ===== Noise reduction =====
245
  try:
246
  audio = butter_highpass_filter(audio, cutoff=60, fs=sr)
247
- except:
248
  pass
249
 
250
- # ===== Normalize =====
251
  max_val = np.max(np.abs(audio))
252
  if max_val > 0:
253
  audio = audio / max_val
254
 
255
- # ===== Fade gegen pop =====
256
  audio = apply_fade(audio, sr)
257
 
258
- # ===== int16 =====
259
  audio_int16 = np.clip(audio * 32767, -32768, 32767).astype(np.int16)
260
 
261
- # Rückgabe: (sr, np.int16 array)
262
- return (sr, audio_int16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  speech_io.py
3
 
4
  Sprachbasierte Ein-/Ausgabe:
5
+ - Speech-to-Text (STT) mit Whisper:
6
+ * lokal über transformers.pipeline
7
+ * optional über Groq Whisper (whisper-large-v3[-turbo])
8
  - Text-to-Speech (TTS) mit MMS-TTS Deutsch
9
 
10
+ Dieses File ist stabil für HuggingFace Spaces.
11
  """
12
 
13
+ from __future__ import annotations
14
+
15
  import os
16
+ import time
17
+ from typing import Optional, Tuple, List, Dict, Any
18
+
19
  import numpy as np
20
  import soundfile as sf
21
  from scipy.signal import butter, filtfilt, resample
 
23
  import re
24
  import difflib
25
 
26
+ # Groq ist optional: nur genutzt, wenn installiert + API-Key gesetzt
27
+ try:
28
+ from groq import Groq # type: ignore
29
+ except Exception: # Modul evtl. nicht installiert
30
+ Groq = None # type: ignore
31
+
32
+ # ============================
33
+ # Konfiguration über ENV
34
+ # ============================
35
+
36
+ # Lokales Whisper (transformers) – Standard tiny für Geschwindigkeit.
37
  ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny")
38
+
39
+ # Optional: Groq Whisper Backend
40
+ USE_GROQ_WHISPER = os.getenv("USE_GROQ_WHISPER", "0").lower() in ("1", "true", "yes")
41
+ GROQ_WHISPER_MODEL = os.getenv("GROQ_WHISPER_MODEL", "whisper-large-v3-turbo")
42
+
43
  TTS_MODEL_ID = os.getenv("TTS_MODEL_ID", "facebook/mms-tts-deu")
44
+ ASR_DEFAULT_LANGUAGE = os.getenv("ASR_LANGUAGE", "de") # "auto" Auto-Detekt
45
  TTS_ENABLED = os.getenv("TTS_ENABLED", "1").lower() not in ("0", "false", "no")
46
  ASR_PROMPT = os.getenv("ASR_PROMPT", "Dies ist ein Diktat in deutscher Sprache.")
47
  ASR_MAX_DURATION_S = int(os.getenv("ASR_MAX_DURATION_S", "30"))
48
 
49
  _asr = None
50
  _tts = None
51
+ _groq_client = None
52
+
53
 
54
  # ========================================================
55
+ # PIPELINES
56
  # ========================================================
57
 
58
  def get_asr_pipeline():
59
+ """Lokales Whisper-Pipeline (transformers)."""
60
  global _asr
61
  if _asr is None:
62
+ print(f">>> Lade lokales ASR Modell: {ASR_MODEL_ID}")
63
  _asr = pipeline(
64
  task="automatic-speech-recognition",
65
  model=ASR_MODEL_ID,
 
70
  )
71
  return _asr
72
 
 
 
 
73
 
74
  def get_tts_pipeline():
75
  global _tts
 
81
  )
82
  return _tts
83
 
84
+
85
+ def get_groq_client():
86
+ """Lazy Init des Groq-Clients – nur wenn USE_GROQ_WHISPER aktiv ist."""
87
+ global _groq_client
88
+ if _groq_client is None:
89
+ if Groq is None:
90
+ raise RuntimeError(
91
+ "Groq Python-Client nicht installiert. "
92
+ "Bitte `pip install groq` und USE_GROQ_WHISPER=1 setzen."
93
+ )
94
+ api_key = os.getenv("GROQ_API_KEY")
95
+ if not api_key:
96
+ raise RuntimeError("GROQ_API_KEY ist nicht gesetzt.")
97
+ _groq_client = Groq(api_key=api_key) # type: ignore
98
+ print(">>> Groq-Client initialisiert.")
99
+ return _groq_client
100
+
101
+
102
  # ========================================================
103
+ # AUDIO FILTER – Noise Reduction + Highpass
104
  # ========================================================
105
 
106
  def butter_highpass_filter(data, cutoff=60, fs=16000, order=4):
 
109
  b, a = butter(order, norm_cutoff, btype="high")
110
  return filtfilt(b, a, data)
111
 
112
+
113
  def apply_fade(audio, sr, duration_ms=10):
114
  fade_samples = int(sr * duration_ms / 1000)
115
 
 
124
 
125
  return audio
126
 
127
+
128
  # ========================================================
129
+ # GROQ-WHISPER HELFER
130
  # ========================================================
131
 
132
+ def _process_groq_whisper_response(completion: Any) -> str:
133
  """
134
+ Auswertung der Groq Whisper-Antwort (verbose_json) analog zum Gradio-Guide:
135
+ - nutzt no_speech_prob, um reines Rauschen zu filtern
 
136
  """
137
+ # completion kann ein Pydantic-Objekt oder ein dict sein
138
+ segments = None
139
+ text = getattr(completion, "text", None)
140
+ if hasattr(completion, "segments"):
141
+ segments = completion.segments
142
+ elif isinstance(completion, dict):
143
+ segments = completion.get("segments", [])
144
+ text = completion.get("text", "")
145
+
146
+ if not segments:
147
+ return ""
148
+
149
+ first = segments[0]
150
+ if isinstance(first, dict):
151
+ no_speech_prob = first.get("no_speech_prob", 0.0)
152
+ else:
153
+ no_speech_prob = getattr(first, "no_speech_prob", 0.0)
154
+
155
+ print("Groq Whisper no_speech_prob:", no_speech_prob)
156
+ if no_speech_prob > 0.7:
157
+ # wahrscheinlich nur Rauschen
158
+ return ""
159
+
160
+ if text is None:
161
+ return ""
162
+ return str(text).strip()
163
+
164
+
165
+ def transcribe_with_groq(audio_path: str, language: Optional[str]) -> str:
166
+ """
167
+ STT über Groq Whisper (whisper-large-v3(-turbo)).
168
+ Erwartet eine Audiodatei (z.B. WAV) von gr.Audio (type='filepath').
169
+ """
170
+ client = get_groq_client()
171
+
172
+ if language:
173
+ lang_param = None if language.lower() == "auto" else language
174
+ else:
175
+ if ASR_DEFAULT_LANGUAGE and ASR_DEFAULT_LANGUAGE.lower() != "auto":
176
+ lang_param = ASR_DEFAULT_LANGUAGE
177
+ else:
178
+ lang_param = None
179
+
180
+ try:
181
+ with open(audio_path, "rb") as audio_file:
182
+ resp = client.audio.transcriptions.with_raw_response.create(
183
+ model=GROQ_WHISPER_MODEL,
184
+ file=("audio.wav", audio_file),
185
+ response_format="verbose_json",
186
+ language=lang_param,
187
+ )
188
+ completion = resp.parse()
189
+ except Exception as e:
190
+ print(f"Groq Whisper Fehler: {e}")
191
+ return ""
192
+
193
+ text = _process_groq_whisper_response(completion)
194
+ print("Groq ASR:", text)
195
+ return text
196
+
197
+
198
+ # ========================================================
199
+ # SPEECH-TO-TEXT (lokal) – wie bisher
200
+ # ========================================================
201
 
202
+ def _transcribe_local_whisper(
203
+ audio_path: str,
204
+ language: Optional[str] = None,
205
+ max_duration_s: int = ASR_MAX_DURATION_S,
206
+ ) -> str:
207
  if audio_path is None or not os.path.exists(audio_path):
208
  print(">>> Kein Audio gefunden.")
209
  return ""
210
 
 
211
  data, sr = sf.read(audio_path, always_2d=False)
212
 
213
  if data is None or data.size == 0:
 
221
  data = np.clip(data, -1.0, 1.0)
222
  try:
223
  data = butter_highpass_filter(data, cutoff=60, fs=sr)
224
+ except Exception:
225
  pass
226
+
227
  m = np.max(np.abs(data))
228
  if m > 0:
229
  data = data / m
230
 
 
231
  rms = float(np.sqrt(np.mean(data ** 2)))
232
  if rms < 5e-5:
233
  print(">>> Audio zu leise, breche ab.")
234
  return ""
235
 
 
236
  TARGET_SR = 16000
237
  if sr != TARGET_SR:
238
  target_len = int(len(data) * TARGET_SR / sr)
 
241
 
242
  idx = np.where(np.abs(data) > 0.02)[0]
243
  if idx.size:
244
+ data = data[idx[0]: idx[-1] + 1]
245
 
 
246
  duration_s = len(data) / sr if sr else 0
247
  rms = float(np.sqrt(np.mean(data ** 2)))
248
  peak = float(np.max(np.abs(data))) if data.size else 0.0
249
+ print(
250
+ f">>> Audio stats – sr: {sr}, len: {len(data)}, "
251
+ f"dur: {duration_s:.2f}s, rms: {rms:.6f}, peak: {peak:.6f}"
252
+ )
253
 
 
254
  if duration_s < 0.3:
255
  print(">>> Audio zu kurz, breche ab.")
256
  return ""
 
258
  print(">>> Audio zu leise, breche ab.")
259
  return ""
260
 
 
261
  MAX_SAMPLES = sr * max_duration_s
262
  if len(data) > MAX_SAMPLES:
263
  data = data[:MAX_SAMPLES]
264
 
265
  asr = get_asr_pipeline()
266
 
267
+ print(">>> Transkribiere Audio (lokal)...")
268
  lang = language
269
  if not lang and ASR_DEFAULT_LANGUAGE and ASR_DEFAULT_LANGUAGE.lower() != "auto":
270
  lang = ASR_DEFAULT_LANGUAGE
271
  if isinstance(lang, str) and lang.lower() == "auto":
272
  lang = None
273
 
274
+ call_kwargs: Dict[str, Any] = {}
 
275
  token_budget = 120
276
  if duration_s < 2.0:
277
  token_budget = 60
 
289
  "logprob_threshold": -1.0,
290
  "no_speech_threshold": 0.6,
291
  "no_repeat_ngram_size": 3,
292
+ "prompt": ASR_PROMPT,
293
  }
294
 
295
  result = asr({"array": data, "sampling_rate": sr}, **call_kwargs)
 
299
 
300
  text = result.get("text", "") if isinstance(result, dict) else str(result)
301
  text = text.strip()
302
+
303
  def _fix_domain_terms(s: str) -> str:
304
  pairs = [
305
  (r"\bbriefe\s*um\b", "prüfung"),
 
310
  for pat, rep in pairs:
311
  s = re.sub(pat, rep, s, flags=re.IGNORECASE)
312
  vocab = [
313
+ "prüfung",
314
+ "prüfungsordnung",
315
+ "hochschulgesetz",
316
+ "modul",
317
+ "klausur",
318
+ "immatrikulation",
319
+ "exmatrikulation",
320
+ "anmeldung",
321
+ "wiederholung",
322
  ]
323
  tokens = s.split()
324
  fixed = []
 
326
  cand = difflib.get_close_matches(t.lower(), vocab, n=1, cutoff=0.82)
327
  fixed.append(cand[0] if cand else t)
328
  return " ".join(fixed)
329
+
330
  text = _fix_domain_terms(text)
331
+ print("ASR (lokal):", text)
332
  return text
333
 
334
+
335
  # ========================================================
336
+ # Public STT-Funktion – wählt Backend (lokal vs Groq)
337
+ # ========================================================
338
+
339
def transcribe_audio(
    audio_path: str,
    language: Optional[str] = None,
    max_duration_s: int = ASR_MAX_DURATION_S,
) -> str:
    """
    High-level STT entry point that picks the backend.

    Backend selection:
    - ``USE_GROQ_WHISPER`` truthy -> Groq Whisper (presumably only set
      when a GROQ_API_KEY is configured -- decided at module init)
    - otherwise -> local Whisper via transformers

    Returns the transcript, or "" for a missing audio path.
    """
    if not audio_path:
        return ""

    if USE_GROQ_WHISPER:
        try:
            return transcribe_with_groq(audio_path, language)
        except Exception as e:
            # Remote backend failed unexpectedly -> fall through to the
            # local model instead of surfacing the error to the UI.
            print(f">>> Groq Whisper Fehler, fallback auf lokal: {e}")

    return _transcribe_local_whisper(audio_path, language, max_duration_s)
361
+
362
+
363
+ # ========================================================
364
+ # TEXT-TO-SPEECH (TTS)
365
  # ========================================================
366
 
367
  def synthesize_speech(text: str):
 
371
  tts = get_tts_pipeline()
372
  out = tts(text)
373
 
 
374
  audio = np.array(out["audio"], dtype=np.float32)
375
  sr = out.get("sampling_rate", 16000)
376
 
 
377
  if sr is None or sr <= 0 or sr > 65535:
378
  sr = 16000
379
 
 
380
  if audio.ndim > 1:
381
  audio = audio.squeeze()
382
  if audio.ndim > 1:
383
  audio = audio[:, 0]
384
 
 
385
  try:
386
  audio = butter_highpass_filter(audio, cutoff=60, fs=sr)
387
+ except Exception:
388
  pass
389
 
 
390
  max_val = np.max(np.abs(audio))
391
  if max_val > 0:
392
  audio = audio / max_val
393
 
 
394
  audio = apply_fade(audio, sr)
395
 
 
396
  audio_int16 = np.clip(audio * 32767, -32768, 32767).astype(np.int16)
397
 
398
+ return sr, audio_int16
399
+
400
+
401
+ # ========================================================
402
+ # SIMPLE BENCHMARK-FUNKTION FÜR WHISPER-MODELLE
403
+ # ========================================================
404
+
405
def benchmark_asr_models(
    audio_path: str,
    local_models: Optional[List[str]] = None,
    groq_models: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Any]]:
    """
    Simple benchmark routine for Whisper models.

    Measures wall-clock time and transcript length for each requested
    model. NOT executed automatically in the Space -- call it manually.

    Note: for local models the measured time includes (re-)loading the
    model, because the cached pipeline is reset for every candidate.

    The pre-benchmark configuration (``ASR_MODEL_ID`` env var, cached
    ``_asr`` pipeline, ``GROQ_WHISPER_MODEL``) is restored afterwards so
    benchmarking does not permanently switch the running app to the last
    benchmarked model.

    Example (local):
        benchmark_asr_models("sample.wav",
            local_models=["openai/whisper-tiny","openai/whisper-small"])

    Example (Groq, requires GROQ_API_KEY):
        benchmark_asr_models("sample.wav",
            groq_models=["whisper-large-v3-turbo","whisper-large-v3"])
    """
    global _asr, GROQ_WHISPER_MODEL

    results: Dict[str, Dict[str, Any]] = {}

    # Snapshot module/process state so it can be restored in `finally`.
    prev_env_model = os.environ.get("ASR_MODEL_ID")
    prev_asr = _asr
    prev_groq_model = GROQ_WHISPER_MODEL

    try:
        if local_models:
            for mid in local_models:
                t0 = time.perf_counter()
                # Drop the cached pipeline so the new model id is picked up
                # on the next get_asr_pipeline() call.
                _asr = None
                os.environ["ASR_MODEL_ID"] = mid
                text = _transcribe_local_whisper(audio_path, language=None)
                dt = time.perf_counter() - t0
                results[f"local::{mid}"] = {
                    "seconds": dt,
                    "chars": len(text),
                    "text_sample": text[:120],
                }

        if groq_models:
            if Groq is None or not os.getenv("GROQ_API_KEY"):
                print(">>> Groq Benchmark übersprungen (kein Client/API-Key).")
            else:
                for mid in groq_models:
                    t0 = time.perf_counter()
                    GROQ_WHISPER_MODEL = mid
                    text = transcribe_with_groq(audio_path, language=None)
                    dt = time.perf_counter() - t0
                    results[f"groq::{mid}"] = {
                        "seconds": dt,
                        "chars": len(text),
                        "text_sample": text[:120],
                    }
    finally:
        # Restore the pre-benchmark configuration.
        if prev_env_model is None:
            os.environ.pop("ASR_MODEL_ID", None)
        else:
            os.environ["ASR_MODEL_ID"] = prev_env_model
        _asr = prev_asr
        GROQ_WHISPER_MODEL = prev_groq_model

    for name, info in results.items():
        print(f"[{name}] {info['seconds']:.2f}s – {info['chars']} chars")

    return results