Update app.py
app.py CHANGED
--- a/app.py
+++ b/app.py
@@ -1,5 +1,6 @@
 import os
 import time
+import tempfile
 from collections import deque
 from flask import Flask, request, jsonify
 from waitress import serve
@@ -7,6 +8,8 @@ from waitress import serve
 from google import genai
 from google.genai import types
 
+from faster_whisper import WhisperModel
+
 app = Flask(__name__)
 
 # -------------------------
@@ -20,6 +23,12 @@ SYSTEM_PROMPT = (
     "Respond in 1-3 sentences and less than 300 characters."
 )
 
+# STT config (we chose base.en)
+WHISPER_MODEL_NAME = os.environ.get("WHISPER_MODEL", "base.en")
+WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cpu")
+WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE", "int8")
+WHISPER_LANGUAGE = os.environ.get("WHISPER_LANGUAGE", "en")
+
 # Gemini client (expects GEMINI_API_KEY set as a HF Space Secret)
 client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
 
@@ -27,15 +36,31 @@ client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
 MAX_MESSAGES = 20  # user+assistant messages combined
 HISTORY = deque(maxlen=MAX_MESSAGES)  # holds types.Content objects
 
+# ---- Whisper model (lazy init) ----
+_whisper_model = None
+
 
 def _client_ip() -> str:
-    # HF may proxy requests; this is best-effort
     return request.headers.get("x-forwarded-for", request.remote_addr or "unknown")
 
 
+def _get_whisper_model() -> WhisperModel:
+    global _whisper_model
+    if _whisper_model is None:
+        print(
+            f"[whisper] loading model={WHISPER_MODEL_NAME} "
+            f"device={WHISPER_DEVICE} compute_type={WHISPER_COMPUTE_TYPE}"
+        )
+        _whisper_model = WhisperModel(
+            WHISPER_MODEL_NAME,
+            device=WHISPER_DEVICE,
+            compute_type=WHISPER_COMPUTE_TYPE,
+        )
+        print("[whisper] loaded")
+    return _whisper_model
+
+
 def _gemini_config() -> types.GenerateContentConfig:
-    # NOTE: Setting thresholds to OFF is permissive and may not be honored for all content;
-    # some protections are not adjustable.
     return types.GenerateContentConfig(
         system_instruction=[types.Part.from_text(text=SYSTEM_PROMPT)],
         thinking_config=types.ThinkingConfig(thinking_level=THINKING_LEVEL),
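The lazy loader above defers the faster-whisper download/load to the first request, so Space startup stays fast. A quick way to sanity-check the same settings outside Flask (a minimal sketch; `sample.wav` is a hypothetical local recording):

```python
# Standalone check of the faster-whisper settings used in this commit.
# "sample.wav" is a hypothetical local recording.
from faster_whisper import WhisperModel

model = WhisperModel("base.en", device="cpu", compute_type="int8")
segments, info = model.transcribe("sample.wav", language="en", vad_filter=True, beam_size=1)
print(f"language={info.language} p={info.language_probability:.2f}")
print("".join(seg.text for seg in segments).strip())
```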
@@ -61,21 +86,11 @@ def _gemini_config() -> types.GenerateContentConfig:
 
 
 def llm_chat(user_text: str) -> str:
-    """
-    Updates global HISTORY (user + model), calls Gemini, returns model reply text.
-    Rolls back the last user message if Gemini call fails.
-    """
     user_text = (user_text or "").strip()
     if not user_text:
         raise ValueError("Missing 'text'")
 
-
-    HISTORY.append(
-        types.Content(
-            role="user",
-            parts=[types.Part.from_text(text=user_text)],
-        )
-    )
+    HISTORY.append(types.Content(role="user", parts=[types.Part.from_text(text=user_text)]))
 
     try:
         resp = client.models.generate_content(
@@ -85,18 +100,10 @@ def llm_chat(user_text: str) -> str:
         )
         reply_text = (resp.text or "").strip()
 
-
-        HISTORY.append(
-            types.Content(
-                role="model",
-                parts=[types.Part.from_text(text=reply_text)],
-            )
-        )
-
+        HISTORY.append(types.Content(role="model", parts=[types.Part.from_text(text=reply_text)]))
         return reply_text
 
     except Exception:
-        # Roll back last user message on failure
         if len(HISTORY) > 0 and getattr(HISTORY[-1], "role", None) == "user":
             HISTORY.pop()
         raise
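The flattened appends rely on `deque(maxlen=MAX_MESSAGES)` evicting the oldest message automatically, and the `except` branch pops the unanswered user turn so a failed Gemini call leaves no dangling message. A small illustration of both behaviors (plain `(role, text)` tuples stand in for `types.Content`; `maxlen` is shrunk to 4 to make the eviction visible):

```python
from collections import deque

history = deque(maxlen=4)
for i in range(3):
    history.append(("user", f"q{i}"))
    history.append(("model", f"a{i}"))
# Only the last two exchanges survive the maxlen bound:
assert list(history) == [("user", "q1"), ("model", "a1"),
                         ("user", "q2"), ("model", "a2")]

# Rollback guard: a user turn appended before a failed call is removed again.
history.append(("user", "q3"))  # this append also evicted ("user", "q1")
if history and history[-1][0] == "user":
    history.pop()
assert history[-1] == ("model", "a2")  # the evicted oldest turn is gone for good
```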
@@ -114,6 +121,9 @@ def health():
         "thinking_level": THINKING_LEVEL,
         "memory_messages": len(HISTORY),
         "max_messages": MAX_MESSAGES,
+        "whisper_model": WHISPER_MODEL_NAME,
+        "whisper_device": WHISPER_DEVICE,
+        "whisper_compute_type": WHISPER_COMPUTE_TYPE,
     })
 
 
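With these fields the health endpoint now reports the STT configuration alongside the Gemini settings. A quick probe (assumptions: the route is `/health`, which this diff only shows as `def health():`, and the server is local on the default port 7860):

```python
# Assumes a local server on the default port and a "/health" route.
import requests

info = requests.get("http://localhost:7860/health", timeout=10).json()
print(info.get("whisper_model"), info.get("whisper_device"), info.get("whisper_compute_type"))
```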
@@ -158,14 +168,62 @@ def chat_text():
 
 
 @app.post("/v1/utterance")
-def
+def utterance_to_text():
     """
-
-
+    Accepts: multipart/form-data with field "audio" containing a .wav file
+    Returns: JSON { "text": "<transcript>", "total_ms": <int> }
     """
+    t0 = time.time()
     ip = _client_ip()
-
-
+
+    print(f"[/v1/utterance] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip}")
+
+    if "audio" not in request.files:
+        print(f"[/v1/utterance] ERROR missing file field 'audio' ip={ip}")
+        return jsonify({"error": "Missing file field 'audio'"}), 400
+
+    f = request.files["audio"]
+    filename = (f.filename or "").strip() or "audio.wav"
+
+    if not filename.lower().endswith(".wav"):
+        print(f"[/v1/utterance] ERROR non-wav filename={filename!r} ip={ip}")
+        return jsonify({"error": "Please upload a .wav file"}), 400
+
+    print(f"[/v1/utterance] received filename={filename!r} content_type={f.content_type!r}")
+
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+        tmp_path = tmp.name
+        f.save(tmp_path)
+
+    try:
+        model = _get_whisper_model()
+
+        segments, info = model.transcribe(
+            tmp_path,
+            language=WHISPER_LANGUAGE,
+            vad_filter=True,
+            beam_size=1,  # fast
+        )
+
+        text = "".join(seg.text for seg in segments).strip()
+
+        dt_ms = int((time.time() - t0) * 1000)
+        print(f"[/v1/utterance] transcript_len={len(text)} total_ms={dt_ms}")
+        print(f"[/v1/utterance] transcript={text!r}")
+
+        return jsonify({"text": text, "total_ms": dt_ms})
+
+    except Exception as e:
+        dt_ms = int((time.time() - t0) * 1000)
+        print("Whisper error:", repr(e))
+        print(f"[/v1/utterance] FAIL ip={ip} total_ms={dt_ms}")
+        return jsonify({"error": "STT failed"}), 500
+
+    finally:
+        try:
+            os.remove(tmp_path)
+        except Exception:
+            pass
 
 
 @app.post("/v1/reset")
@@ -182,4 +240,5 @@ def reset():
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", "7860"))
     print(f"[startup] model={MODEL} thinking_level={THINKING_LEVEL} max_messages={MAX_MESSAGES} port={port}")
+    print(f"[startup] whisper_model={WHISPER_MODEL_NAME} device={WHISPER_DEVICE} compute={WHISPER_COMPUTE_TYPE}")
     serve(app, host="0.0.0.0", port=port)
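Once deployed, the new endpoint can be exercised with a plain multipart POST. A minimal client sketch (`BASE_URL` and `utterance.wav` are placeholders; the field name `audio`, the `.wav` requirement, and the response keys come from the handler above):

```python
import requests

BASE_URL = "http://localhost:7860"  # placeholder; point at your Space

with open("utterance.wav", "rb") as fh:  # placeholder input file
    resp = requests.post(
        f"{BASE_URL}/v1/utterance",
        files={"audio": ("utterance.wav", fh, "audio/wav")},
        timeout=120,
    )
resp.raise_for_status()
payload = resp.json()
print(payload["text"], f"({payload['total_ms']} ms)")
```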