|
|
import os |
|
|
import time |
|
|
import json |
|
|
import tempfile |
|
|
from collections import deque |
|
|
|
|
|
from flask import Flask, request, jsonify, send_file |
|
|
from waitress import serve |
|
|
|
|
|
from google import genai |
|
|
from google.genai import types |
|
|
|
|
|
from faster_whisper import WhisperModel |
|
|
|
|
|
from elevenlabs.client import ElevenLabs |
|
|
from elevenlabs import save, VoiceSettings |
|
|
|
|
|
|
|
|
# WSGI application object; the route decorators further down register their
# handlers on it, and the __main__ block hands it to the server.
app = Flask(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gemini model selection; both values can be overridden via the environment.
MODEL = os.getenv("GEMINI_MODEL", "gemini-3-flash-preview")
THINKING_LEVEL = os.getenv("GEMINI_THINKING_LEVEL", "HIGH")

# Persona rules sent to the model as its system instruction, one rule per line.
SYSTEM_PROMPT = "\n".join(
    [
        "You should respond like Andy Warhol.",
        "Respond in 1-3 sentences and less than 200 characters.",
        "You should say uh 0-2 times per response, it can be in different parts of the response.",
        "Don't repeat yourself too much.",
        "",  # trailing empty item keeps the newline after the last rule
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared secret clients must present; empty string when the env var is unset.
API_PASSWORD = os.environ.get("API_PASSWORD", "").strip()


def _require_auth():
    """
    Require a shared secret from the client.

    Client must send header: X-API-PASSWORD: <secret>

    Returns:
        None when the request is authorized. Otherwise a ``(response, status)``
        tuple for the route to return unchanged: 500 when the server has no
        secret configured, 401 when the header is missing or does not match.
    """
    import hmac  # local import: only this function needs it

    if not API_PASSWORD:
        # Fail closed: with no secret configured, nobody is let through.
        return jsonify({"error": "Server missing API_PASSWORD secret"}), 500

    provided = (request.headers.get("X-API-PASSWORD") or "").strip()

    # Constant-time comparison so response timing does not reveal how much of
    # the secret matched. Comparing bytes also permits non-ASCII secrets,
    # which compare_digest rejects when given str arguments.
    authorized = bool(provided) and hmac.compare_digest(
        provided.encode("utf-8"), API_PASSWORD.encode("utf-8")
    )
    if not authorized:
        return jsonify({"error": "Unauthorized"}), 401

    return None
|
|
|
|
|
|
|
|
|
|
|
# Speech-to-text (faster-whisper) settings, each overridable via the environment.
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base.en")
WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cpu")
WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "en")

# Text-to-speech (ElevenLabs) settings. The API key has no default, so it is
# None when the variable is not set.
ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY")
ELEVEN_VOICE_ID = os.getenv("ELEVEN_VOICE_ID", "kYLLcRUC2uzrEp0Jr2HT")
ELEVEN_MODEL_ID = os.getenv("ELEVEN_MODEL_ID", "eleven_multilingual_v2")
ELEVEN_OUTPUT_FORMAT = os.getenv("ELEVEN_OUTPUT_FORMAT", "mp3_44100_128")
|
|
|
|
|
|
|
|
# Gemini client; the key is read from GEMINI_API_KEY when the module loads.
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

# ElevenLabs client, or None when no API key is configured.
if ELEVEN_API_KEY:
    eleven = ElevenLabs(api_key=ELEVEN_API_KEY)
else:
    eleven = None

# Conversation memory: a single module-level deque used by every request.
# Being bounded, it drops its oldest entry whenever a new one is appended
# at capacity.
MAX_MESSAGES = 20
HISTORY = deque(maxlen=MAX_MESSAGES)

# Cache slot for the Whisper model; populated on first use by
# _get_whisper_model().
_whisper_model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _client_ip() -> str:
    """Return a caller identifier for log lines.

    The raw ``x-forwarded-for`` header is used verbatim when the request
    carries one; otherwise the peer address, or "unknown" if that is unset.
    """
    fallback = request.remote_addr or "unknown"
    return request.headers.get("x-forwarded-for", fallback)
|
|
|
|
|
|
|
|
def _err_details(e: Exception) -> dict:
    """Summarize an exception as a JSON-serializable dict.

    Always includes ``type`` (the class name) and ``repr``. Any of the
    optional attributes ``status_code``, ``body``, ``message``, ``response``
    and ``details`` that can be read from the exception are copied too.

    Values that the ``json`` module cannot encode (for example an HTTP
    response object) are replaced by their ``str()`` form, so callers can pass
    the result straight to ``jsonify`` without the error handler itself
    failing. Values that are already encodable are returned unchanged.
    """

    def _json_safe(value):
        # Probe-encode rather than type-check so nested dicts/lists are kept.
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return str(value)
        return value

    details = {"type": type(e).__name__, "repr": repr(e)}
    for attr in ("status_code", "body", "message", "response", "details"):
        try:
            value = getattr(e, attr)
        except Exception:
            # Attribute is absent, or a property raised while being read;
            # either way it is skipped instead of masking the original error.
            continue
        details[attr] = _json_safe(value)
    return details
|
|
|
|
|
|
|
|
def _get_whisper_model() -> WhisperModel:
    """Return the Whisper model, constructing it on the first call.

    The instance is kept in the module-level ``_whisper_model`` so later calls
    reuse it instead of loading the model again.
    """
    global _whisper_model

    if _whisper_model is not None:
        return _whisper_model

    print(
        f"[whisper] loading model={WHISPER_MODEL_NAME} device={WHISPER_DEVICE} compute_type={WHISPER_COMPUTE_TYPE}"
    )
    _whisper_model = WhisperModel(
        WHISPER_MODEL_NAME,
        device=WHISPER_DEVICE,
        compute_type=WHISPER_COMPUTE_TYPE,
    )
    print("[whisper] loaded")
    return _whisper_model
|
|
|
|
|
|
|
|
def _clean_reply(text: str) -> str:
    """Tidy a model reply so it ends like a finished sentence.

    Surrounding whitespace is removed. A reply that trails off with "...",
    "…" or "," has its trailing dots, commas, ellipses and spaces removed.
    If what remains does not end in ".", "?" or "!", a period is appended.
    Empty or ``None`` input yields an empty string.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return cleaned

    trails_off = cleaned.endswith(("...", "…", ","))
    if trails_off:
        cleaned = cleaned.rstrip(" .,…").strip()

    if cleaned and not cleaned.endswith((".", "?", "!")):
        cleaned = cleaned + "."
    return cleaned
|
|
|
|
|
|
|
|
def _gemini_config() -> types.GenerateContentConfig:
    """Build the generation config used for every Gemini request.

    It carries the persona system prompt, the configured thinking level, and
    a safety setting with threshold OFF for each of the four harm categories
    listed below.
    """
    categories = (
        types.HarmCategory.HARM_CATEGORY_HARASSMENT,
        types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    safety_off = [
        types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
        for category in categories
    ]
    return types.GenerateContentConfig(
        system_instruction=[types.Part.from_text(text=SYSTEM_PROMPT)],
        thinking_config=types.ThinkingConfig(thinking_level=THINKING_LEVEL),
        safety_settings=safety_off,
    )
|
|
|
|
|
|
|
|
def llm_chat(user_text: str) -> str:
    """Send one user turn to Gemini and return the cleaned reply.

    The user turn is appended to the shared ``HISTORY`` before the call and
    the model's reply after it, so later calls see the whole exchange.

    Args:
        user_text: What the user said; surrounding whitespace is ignored.

    Returns:
        The reply text after ``_clean_reply``; never empty.

    Raises:
        ValueError: ``user_text`` is empty or whitespace only.
        RuntimeError: The model produced no usable text.
        Exception: Anything raised by the Gemini client is re-raised.

    On any failure the user turn added by this call is taken back out of
    ``HISTORY`` so a failed request leaves no half-finished exchange behind.
    """
    user_text = (user_text or "").strip()
    if not user_text:
        raise ValueError("Missing 'text'")

    user_turn = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
    HISTORY.append(user_turn)

    try:
        resp = client.models.generate_content(
            model=MODEL,
            contents=list(HISTORY),
            config=_gemini_config(),
        )
        reply_text = _clean_reply(resp.text)

        # Treat an empty reply as a failure: storing it would put an empty
        # model turn into HISTORY that is re-sent on every later request, and
        # callers would receive "" as if it were a successful answer.
        if not reply_text:
            raise RuntimeError("Gemini returned an empty reply")

        HISTORY.append(types.Content(role="model", parts=[types.Part.from_text(text=reply_text)]))
        return reply_text

    except Exception:
        # Roll back only the turn this call appended (identity check), never
        # some other user turn that happens to be last.
        if HISTORY and HISTORY[-1] is user_turn:
            HISTORY.pop()
        raise
|
|
|
|
|
|
|
|
def _tts_to_mp3_file(text: str) -> tuple[str, int]:
    """
    Synthesize ``text`` with ElevenLabs and write it to a temporary mp3 file.

    Returns: (mp3_path, tts_ms)
        ``mp3_path`` is a temp file the caller is responsible for deleting;
        ``tts_ms`` is the elapsed time in milliseconds, including the write.

    Raises exception on failure.
        RuntimeError when no ElevenLabs client is configured; otherwise
        whatever the ElevenLabs call or the file write raises. The temporary
        file is removed before the exception propagates.
    """
    if eleven is None:
        raise RuntimeError("Server missing ELEVEN_API_KEY")

    t0 = time.time()

    audio_stream = eleven.text_to_speech.convert(
        text=text,
        voice_id=ELEVEN_VOICE_ID,
        model_id=ELEVEN_MODEL_ID,
        output_format=ELEVEN_OUTPUT_FORMAT,
        voice_settings=VoiceSettings(
            speed=0.8,
        ),
    )

    # Reserve a path only; delete=False keeps the file after the handle closes.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_out:
        mp3_path = tmp_out.name

    try:
        # NOTE(review): presumably the audio is pulled from the API while
        # `save` iterates the stream, so request errors may surface here.
        save(audio_stream, mp3_path)
    except BaseException:
        # The path is never handed to the caller on failure, so this is the
        # only place the temp file can be cleaned up.
        try:
            os.remove(mp3_path)
        except OSError:
            pass
        raise

    tts_ms = int((time.time() - t0) * 1000)
    return mp3_path, tts_ms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report liveness plus the active configuration as JSON.

    No authentication is required. The response exposes model and voice
    settings and whether an ElevenLabs key is configured, not the key itself.
    """
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[/health] {stamp} ip={_client_ip()} mem={len(HISTORY)}/{MAX_MESSAGES}")

    status = {
        "ok": True,
        "model": MODEL,
        "thinking_level": THINKING_LEVEL,
        "memory_messages": len(HISTORY),
        "max_messages": MAX_MESSAGES,
        "whisper_model": WHISPER_MODEL_NAME,
        "whisper_device": WHISPER_DEVICE,
        "whisper_compute_type": WHISPER_COMPUTE_TYPE,
        "eleven_ok": bool(ELEVEN_API_KEY),
        "eleven_voice_id": ELEVEN_VOICE_ID,
        "eleven_model_id": ELEVEN_MODEL_ID,
        "eleven_output_format": ELEVEN_OUTPUT_FORMAT,
    }
    return jsonify(status)
|
|
|
|
|
|
|
|
@app.post("/v1/chat")
def chat_text():
    """
    JSON body: { "text": "hello" }
    Returns: JSON with input, reply_text, model, memory_messages, total_ms

    Errors:
      400 when the body is not a JSON object or "text" is missing, empty,
          or not a string
      500 when the Gemini call fails
      plus whatever _require_auth() returns for unauthenticated requests
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp

    t0 = time.time()
    ip = _client_ip()

    # Valid JSON need not be an object, and "text" need not be a string;
    # treat every such shape as missing text instead of crashing on
    # .get()/.strip().
    data = request.get_json(silent=True)
    raw_text = data.get("text") if isinstance(data, dict) else None
    user_text = raw_text.strip() if isinstance(raw_text, str) else ""

    print(f"[/v1/chat] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip} mem_before={len(HISTORY)}/{MAX_MESSAGES}")

    if not user_text:
        print(f"[/v1/chat] ERROR missing text ip={ip}")
        return jsonify({"error": "Missing 'text'"}), 400

    print(f"[/v1/chat] user_text_len={len(user_text)} user_text={user_text!r}")

    try:
        reply_text = llm_chat(user_text)
        dt_ms = int((time.time() - t0) * 1000)

        print(f"[/v1/chat] bot_reply={reply_text!r}")
        print(f"[/v1/chat] END ip={ip} total_ms={dt_ms} mem_now={len(HISTORY)}/{MAX_MESSAGES}")

        return jsonify({
            "input": user_text,
            "reply_text": reply_text,
            "model": MODEL,
            "memory_messages": len(HISTORY),
            "total_ms": dt_ms,
        })

    except Exception as e:
        # Route boundary: log the cause, return a generic message to the client.
        dt_ms = int((time.time() - t0) * 1000)
        print("Gemini error:", repr(e))
        print(f"[/v1/chat] FAIL ip={ip} total_ms={dt_ms} mem_now={len(HISTORY)}/{MAX_MESSAGES}")
        return jsonify({"error": "Gemini call failed"}), 500
|
|
|
|
|
|
|
|
@app.post("/v1/tts")
def tts_only():
    """
    JSON body: { "text": "hello" }
    Returns: audio/mpeg (mp3)
    Timing headers:
      X-TTS-MS, X-TOTAL-MS

    Errors:
      400 when the body is not a JSON object or "text" is missing, empty,
          or not a string
      502 when speech synthesis fails (JSON body with error details)
      plus whatever _require_auth() returns for unauthenticated requests
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp

    ip = _client_ip()
    t0 = time.time()

    # Valid JSON need not be an object, and "text" need not be a string;
    # treat every such shape as missing text instead of crashing on
    # .get()/.strip().
    data = request.get_json(silent=True)
    raw_text = data.get("text") if isinstance(data, dict) else None
    text = raw_text.strip() if isinstance(raw_text, str) else ""

    print(f"[/v1/tts] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip} text_len={len(text)}")

    if not text:
        return jsonify({"error": "Missing 'text'"}), 400

    mp3_path = None
    try:
        mp3_path, tts_ms = _tts_to_mp3_file(text)
        total_ms = int((time.time() - t0) * 1000)

        print(f"[/v1/tts] OK tts_ms={tts_ms} total_ms={total_ms}")

        resp = send_file(
            mp3_path,
            mimetype="audio/mpeg",
            as_attachment=False,
            download_name="andy.mp3",
            conditional=False,
        )
        resp.headers["X-TTS-MS"] = str(tts_ms)
        resp.headers["X-TOTAL-MS"] = str(total_ms)
        return resp

    except Exception as e:
        details = _err_details(e)
        total_ms = int((time.time() - t0) * 1000)
        print(f"[/v1/tts] FAIL total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
        return jsonify({"error": "ElevenLabs TTS failed", "details": details, "total_ms": total_ms}), 502

    finally:
        # Best-effort removal of the temp mp3 once the response object exists.
        # NOTE(review): relies on send_file having already opened the file -
        # confirm for the deployed platform.
        if mp3_path:
            try:
                os.remove(mp3_path)
            except OSError:
                pass
|
|
|
|
|
|
|
|
@app.post("/v1/utterance")
def utterance_audio_to_audio():
    """
    Accepts: multipart/form-data with field "audio" containing a .wav file
    Returns: audio/mpeg (mp3)

    Pipeline: Whisper transcription -> llm_chat() -> ElevenLabs speech.

    Timing headers:
      X-STT-MS, X-LLM-MS, X-TTS-MS, X-TOTAL-MS

    Errors:
      400 missing "audio" field or a filename that does not end in .wav
      200 with {"error": "Empty transcript"} when nothing was recognized
      502 when speech synthesis fails (transcript and reply are included)
      500 when no ElevenLabs key is configured or any other step fails
      plus whatever _require_auth() returns for unauthenticated requests
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp

    t0 = time.time()
    ip = _client_ip()

    print(f"[/v1/utterance] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip}")

    if eleven is None:
        print("[/v1/utterance] ERROR missing ELEVEN_API_KEY")
        return jsonify({"error": "Server missing ELEVEN_API_KEY"}), 500

    if "audio" not in request.files:
        print(f"[/v1/utterance] ERROR missing file field 'audio' ip={ip}")
        return jsonify({"error": "Missing file field 'audio'"}), 400

    f = request.files["audio"]
    filename = (f.filename or "").strip() or "audio.wav"
    if not filename.lower().endswith(".wav"):
        print(f"[/v1/utterance] ERROR non-wav filename={filename!r} ip={ip}")
        return jsonify({"error": "Please upload a .wav file"}), 400

    # Reserve a temp path only; the upload is written inside the try block so
    # the finally clause removes the file even when saving fails.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_in:
        wav_path = tmp_in.name

    mp3_path = None
    stt_ms = llm_ms = tts_ms = 0
    transcript = ""
    reply_text = ""

    try:
        f.save(wav_path)

        # --- speech to text ---
        t_stt = time.time()
        model = _get_whisper_model()

        segments, _info = model.transcribe(
            wav_path,
            language=WHISPER_LANGUAGE,
            vad_filter=True,
            beam_size=1,
        )
        transcript = "".join(seg.text for seg in segments).strip()
        stt_ms = int((time.time() - t_stt) * 1000)

        print(f"[/v1/utterance] transcript_len={len(transcript)} stt_ms={stt_ms}")
        print(f"[/v1/utterance] transcript={transcript!r}")

        if not transcript:
            total_ms = int((time.time() - t0) * 1000)
            print(f"[/v1/utterance] EMPTY transcript total_ms={total_ms}")
            # Reported with status 200 and an "error" field in the JSON body.
            return jsonify({"error": "Empty transcript", "stt_ms": stt_ms, "total_ms": total_ms}), 200

        # --- language model ---
        t_llm = time.time()
        reply_text = llm_chat(transcript)
        llm_ms = int((time.time() - t_llm) * 1000)

        print(f"[/v1/utterance] reply_len={len(reply_text)} llm_ms={llm_ms}")
        print(f"[/v1/utterance] bot_reply={reply_text!r}")

        # --- text to speech ---
        # Handled separately so the client still receives the transcript and
        # reply text when only synthesis fails.
        try:
            mp3_path, tts_ms = _tts_to_mp3_file(reply_text)
        except Exception as e:
            details = _err_details(e)
            total_ms = int((time.time() - t0) * 1000)
            print(f"[/v1/utterance] TTS FAIL total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
            return jsonify({
                "error": "ElevenLabs TTS failed",
                "details": details,
                "transcript": transcript,
                "reply_text": reply_text,
                "stt_ms": stt_ms,
                "llm_ms": llm_ms,
                "total_ms": total_ms,
            }), 502

        total_ms = int((time.time() - t0) * 1000)
        print(f"[/v1/utterance] tts_ms={tts_ms} total_ms={total_ms}")
        print(f"[/v1/utterance] END ip={ip}")

        resp = send_file(
            mp3_path,
            mimetype="audio/mpeg",
            as_attachment=False,
            download_name="andy.mp3",
            conditional=False,
        )
        resp.headers["X-STT-MS"] = str(stt_ms)
        resp.headers["X-LLM-MS"] = str(llm_ms)
        resp.headers["X-TTS-MS"] = str(tts_ms)
        resp.headers["X-TOTAL-MS"] = str(total_ms)
        return resp

    except Exception as e:
        total_ms = int((time.time() - t0) * 1000)
        details = _err_details(e)
        print(f"[/v1/utterance] FAIL ip={ip} total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
        return jsonify({"error": "Utterance pipeline failed", "details": details, "total_ms": total_ms}), 500

    finally:
        # Best-effort cleanup of both temp files on every exit path.
        try:
            os.remove(wav_path)
        except OSError:
            pass
        if mp3_path:
            try:
                os.remove(mp3_path)
            except OSError:
                pass
|
|
|
|
|
|
|
|
@app.post("/v1/reset")
def reset():
    """Clear the shared conversation history.

    Requires authentication. Responds with ``{"ok": true,
    "memory_messages": 0}`` once the history has been emptied.
    """
    denied = _require_auth()
    if denied:
        return denied

    ip = _client_ip()
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[/v1/reset] {stamp} ip={ip} clearing mem (was {len(HISTORY)}/{MAX_MESSAGES})")

    HISTORY.clear()
    return jsonify({"ok": True, "memory_messages": 0})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print the effective configuration, then serve the
    # app with waitress on all interfaces. PORT defaults to 7860.
    port = int(os.getenv("PORT", "7860"))

    startup_lines = (
        f"[startup] model={MODEL} thinking_level={THINKING_LEVEL} max_messages={MAX_MESSAGES} port={port}",
        f"[startup] whisper_model={WHISPER_MODEL_NAME} device={WHISPER_DEVICE} compute={WHISPER_COMPUTE_TYPE}",
        f"[startup] eleven_ok={bool(ELEVEN_API_KEY)} voice={ELEVEN_VOICE_ID} model={ELEVEN_MODEL_ID} out={ELEVEN_OUTPUT_FORMAT}",
    )
    for line in startup_lines:
        print(line)

    serve(app, host="0.0.0.0", port=port)
|
|
|