tonyassi committed on
Commit
0d48af6
·
verified ·
1 Parent(s): 8a1e13a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -28
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import time
 
3
  from collections import deque
4
  from flask import Flask, request, jsonify
5
  from waitress import serve
@@ -7,6 +8,8 @@ from waitress import serve
7
  from google import genai
8
  from google.genai import types
9
 
 
 
10
  app = Flask(__name__)
11
 
12
  # -------------------------
@@ -20,6 +23,12 @@ SYSTEM_PROMPT = (
20
  "Respond in 1-3 sentences and less than 300 characters."
21
  )
22
 
 
 
 
 
 
 
23
  # Gemini client (expects GEMINI_API_KEY set as a HF Space Secret)
24
  client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
25
 
@@ -27,15 +36,31 @@ client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
27
  MAX_MESSAGES = 20 # user+assistant messages combined
28
  HISTORY = deque(maxlen=MAX_MESSAGES) # holds types.Content objects
29
 
 
 
 
30
 
31
  def _client_ip() -> str:
32
- # HF may proxy requests; this is best-effort
33
  return request.headers.get("x-forwarded-for", request.remote_addr or "unknown")
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def _gemini_config() -> types.GenerateContentConfig:
37
- # NOTE: Setting thresholds to OFF is permissive and may not be honored for all content;
38
- # some protections are not adjustable.
39
  return types.GenerateContentConfig(
40
  system_instruction=[types.Part.from_text(text=SYSTEM_PROMPT)],
41
  thinking_config=types.ThinkingConfig(thinking_level=THINKING_LEVEL),
@@ -61,21 +86,11 @@ def _gemini_config() -> types.GenerateContentConfig:
61
 
62
 
63
  def llm_chat(user_text: str) -> str:
64
- """
65
- Updates global HISTORY (user + model), calls Gemini, returns model reply text.
66
- Rolls back the last user message if Gemini call fails.
67
- """
68
  user_text = (user_text or "").strip()
69
  if not user_text:
70
  raise ValueError("Missing 'text'")
71
 
72
- # Add user message to memory
73
- HISTORY.append(
74
- types.Content(
75
- role="user",
76
- parts=[types.Part.from_text(text=user_text)],
77
- )
78
- )
79
 
80
  try:
81
  resp = client.models.generate_content(
@@ -85,18 +100,10 @@ def llm_chat(user_text: str) -> str:
85
  )
86
  reply_text = (resp.text or "").strip()
87
 
88
- # Add assistant message to memory
89
- HISTORY.append(
90
- types.Content(
91
- role="model",
92
- parts=[types.Part.from_text(text=reply_text)],
93
- )
94
- )
95
-
96
  return reply_text
97
 
98
  except Exception:
99
- # Roll back last user message on failure
100
  if len(HISTORY) > 0 and getattr(HISTORY[-1], "role", None) == "user":
101
  HISTORY.pop()
102
  raise
@@ -114,6 +121,9 @@ def health():
114
  "thinking_level": THINKING_LEVEL,
115
  "memory_messages": len(HISTORY),
116
  "max_messages": MAX_MESSAGES,
 
 
 
117
  })
118
 
119
 
@@ -158,14 +168,62 @@ def chat_text():
158
 
159
 
160
  @app.post("/v1/utterance")
161
- def chat_audio():
162
  """
163
- Audio endpoint (placeholder for now).
164
- Later: accept audio (multipart/form-data), run STT -> llm_chat -> TTS -> return audio.
165
  """
 
166
  ip = _client_ip()
167
- print(f"[/v1/utterance] HIT {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip} (not implemented)")
168
- return jsonify({"error": "Not implemented yet"}), 501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  @app.post("/v1/reset")
@@ -182,4 +240,5 @@ def reset():
182
  if __name__ == "__main__":
183
  port = int(os.environ.get("PORT", "7860"))
184
  print(f"[startup] model={MODEL} thinking_level={THINKING_LEVEL} max_messages={MAX_MESSAGES} port={port}")
 
185
  serve(app, host="0.0.0.0", port=port)
 
1
  import os
2
  import time
3
+ import tempfile
4
  from collections import deque
5
  from flask import Flask, request, jsonify
6
  from waitress import serve
 
8
  from google import genai
9
  from google.genai import types
10
 
11
+ from faster_whisper import WhisperModel
12
+
13
  app = Flask(__name__)
14
 
15
  # -------------------------
 
23
  "Respond in 1-3 sentences and less than 300 characters."
24
  )
25
 
26
# ---- STT configuration (overridable via environment; defaults suit CPU Spaces) ----
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base.en")
WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cpu")
WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "en")
31
+
32
  # Gemini client (expects GEMINI_API_KEY set as a HF Space Secret)
33
  client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
34
 
 
36
  MAX_MESSAGES = 20 # user+assistant messages combined
37
  HISTORY = deque(maxlen=MAX_MESSAGES) # holds types.Content objects
38
 
39
+ # ---- Whisper model (lazy init) ----
40
+ _whisper_model = None
41
+
42
 
43
def _client_ip() -> str:
    """Best-effort caller IP: prefer the proxy-set header, else the socket peer."""
    fallback = request.remote_addr or "unknown"
    return request.headers.get("x-forwarded-for", fallback)
45
 
46
 
47
def _get_whisper_model() -> WhisperModel:
    """Lazily load and cache the process-wide faster-whisper model.

    Loading is deferred until the first transcription request so that
    app startup stays fast; subsequent calls reuse the cached instance.
    """
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model
    print(
        f"[whisper] loading model={WHISPER_MODEL_NAME} "
        f"device={WHISPER_DEVICE} compute_type={WHISPER_COMPUTE_TYPE}"
    )
    _whisper_model = WhisperModel(
        WHISPER_MODEL_NAME,
        device=WHISPER_DEVICE,
        compute_type=WHISPER_COMPUTE_TYPE,
    )
    print("[whisper] loaded")
    return _whisper_model
61
+
62
+
63
  def _gemini_config() -> types.GenerateContentConfig:
 
 
64
  return types.GenerateContentConfig(
65
  system_instruction=[types.Part.from_text(text=SYSTEM_PROMPT)],
66
  thinking_config=types.ThinkingConfig(thinking_level=THINKING_LEVEL),
 
86
 
87
 
88
  def llm_chat(user_text: str) -> str:
 
 
 
 
89
  user_text = (user_text or "").strip()
90
  if not user_text:
91
  raise ValueError("Missing 'text'")
92
 
93
+ HISTORY.append(types.Content(role="user", parts=[types.Part.from_text(text=user_text)]))
 
 
 
 
 
 
94
 
95
  try:
96
  resp = client.models.generate_content(
 
100
  )
101
  reply_text = (resp.text or "").strip()
102
 
103
+ HISTORY.append(types.Content(role="model", parts=[types.Part.from_text(text=reply_text)]))
 
 
 
 
 
 
 
104
  return reply_text
105
 
106
  except Exception:
 
107
  if len(HISTORY) > 0 and getattr(HISTORY[-1], "role", None) == "user":
108
  HISTORY.pop()
109
  raise
 
121
  "thinking_level": THINKING_LEVEL,
122
  "memory_messages": len(HISTORY),
123
  "max_messages": MAX_MESSAGES,
124
+ "whisper_model": WHISPER_MODEL_NAME,
125
+ "whisper_device": WHISPER_DEVICE,
126
+ "whisper_compute_type": WHISPER_COMPUTE_TYPE,
127
  })
128
 
129
 
 
168
 
169
 
170
@app.post("/v1/utterance")
def utterance_to_text():
    """
    Accepts: multipart/form-data with field "audio" containing a .wav file
    Returns: JSON { "text": "<transcript>", "total_ms": <int> }
    """
    started = time.time()
    ip = _client_ip()

    print(f"[/v1/utterance] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip}")

    # Guard clauses: reject requests without a usable .wav upload.
    upload = request.files.get("audio")
    if upload is None:
        print(f"[/v1/utterance] ERROR missing file field 'audio' ip={ip}")
        return jsonify({"error": "Missing file field 'audio'"}), 400

    filename = (upload.filename or "").strip() or "audio.wav"
    if not filename.lower().endswith(".wav"):
        print(f"[/v1/utterance] ERROR non-wav filename={filename!r} ip={ip}")
        return jsonify({"error": "Please upload a .wav file"}), 400

    print(f"[/v1/utterance] received filename={filename!r} content_type={upload.content_type!r}")

    # Persist the upload to disk so faster-whisper can read it by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
        upload.save(wav_path)

    try:
        segments, _info = _get_whisper_model().transcribe(
            wav_path,
            language=WHISPER_LANGUAGE,
            vad_filter=True,
            beam_size=1,  # fast
        )
        transcript = "".join(segment.text for segment in segments).strip()

        elapsed_ms = int((time.time() - started) * 1000)
        print(f"[/v1/utterance] transcript_len={len(transcript)} total_ms={elapsed_ms}")
        print(f"[/v1/utterance] transcript={transcript!r}")
        return jsonify({"text": transcript, "total_ms": elapsed_ms})

    except Exception as exc:
        elapsed_ms = int((time.time() - started) * 1000)
        print("Whisper error:", repr(exc))
        print(f"[/v1/utterance] FAIL ip={ip} total_ms={elapsed_ms}")
        return jsonify({"error": "STT failed"}), 500

    finally:
        # Always drop the temp file; a failed delete is non-fatal.
        try:
            os.remove(wav_path)
        except Exception:
            pass
227
 
228
 
229
  @app.post("/v1/reset")
 
240
if __name__ == "__main__":
    # Serve with waitress rather than Flask's dev server; port is env-overridable.
    port = int(os.getenv("PORT", "7860"))
    print(f"[startup] model={MODEL} thinking_level={THINKING_LEVEL} max_messages={MAX_MESSAGES} port={port}")
    print(f"[startup] whisper_model={WHISPER_MODEL_NAME} device={WHISPER_DEVICE} compute={WHISPER_COMPUTE_TYPE}")
    serve(app, host="0.0.0.0", port=port)