# andy / app.py
# tonyassi's picture
# Update app.py
# e7f0c14 verified
import hmac
import json
import os
import tempfile
import threading
import time
from collections import deque

from elevenlabs import save, VoiceSettings
from elevenlabs.client import ElevenLabs
from faster_whisper import WhisperModel
from flask import Flask, request, jsonify, send_file
from google import genai
from google.genai import types
from waitress import serve
app = Flask(import_name=__name__)

# -------------------------
# Config
# -------------------------
# Gemini model id and thinking level; both can be overridden via env vars.
MODEL = os.getenv("GEMINI_MODEL", "gemini-3-flash-preview")
THINKING_LEVEL = os.getenv("GEMINI_THINKING_LEVEL", "HIGH")

# Persona instructions, passed to Gemini as the system instruction.
# Each rule is one line terminated by "\n".
SYSTEM_PROMPT = "".join(
    rule + "\n"
    for rule in (
        "You should respond like Andy Warhol.",
        "Respond in 1-3 sentences and less than 200 characters.",
        "You should say uh 0-2 times per response, it can be in different parts of the response.",
        "Don't repeat yourself too much.",
    )
)
# -------------------------
# Auth
# -------------------------
# Shared secret every protected endpoint requires; empty means "not configured".
API_PASSWORD = os.environ.get("API_PASSWORD", "").strip()


def _require_auth():
    """
    Require a shared secret from the client.

    Client must send header: X-API-PASSWORD: <secret>

    Returns:
        None when the request is authorized. Otherwise a ``(response, status)``
        tuple the view should return unchanged:
        500 when the server has no API_PASSWORD configured (fail closed),
        401 when the header is missing or does not match.
    """
    if not API_PASSWORD:
        # If you forget to set the secret, fail closed.
        return jsonify({"error": "Server missing API_PASSWORD secret"}), 500
    provided = (request.headers.get("X-API-PASSWORD") or "").strip()
    # Constant-time comparison so response timing does not reveal how much of
    # the secret a guess got right. Compare bytes: compare_digest raises
    # TypeError for str arguments that contain non-ASCII characters.
    if not provided or not hmac.compare_digest(
        provided.encode("utf-8"), API_PASSWORD.encode("utf-8")
    ):
        return jsonify({"error": "Unauthorized"}), 401
    return None
# STT (we chose base.en)
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL", "base.en")
WHISPER_DEVICE = os.getenv("WHISPER_DEVICE", "cpu")
WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "en")

# ElevenLabs (the key is optional at import time; TTS calls fail without it)
ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY")
ELEVEN_VOICE_ID = os.getenv("ELEVEN_VOICE_ID", "kYLLcRUC2uzrEp0Jr2HT")
ELEVEN_MODEL_ID = os.getenv("ELEVEN_MODEL_ID", "eleven_multilingual_v2")
ELEVEN_OUTPUT_FORMAT = os.getenv("ELEVEN_OUTPUT_FORMAT", "mp3_44100_128")

# Gemini client (expects GEMINI_API_KEY set as a HF Space Secret)
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

# Eleven client: left as None when no API key is configured.
eleven = None
if ELEVEN_API_KEY:
    eleven = ElevenLabs(api_key=ELEVEN_API_KEY)

# ---- Memory (global, RAM-only, survives refresh, resets on Space restart) ----
# Bounded transcript: once MAX_MESSAGES is reached, each append drops the oldest.
MAX_MESSAGES = 20
HISTORY = deque(maxlen=MAX_MESSAGES)

# ---- Whisper model (lazy init) ----
# Filled in on first use by _get_whisper_model().
_whisper_model = None
# -------------------------
# Helpers
# -------------------------
def _client_ip() -> str:
    """
    Best-effort client address, used only in log lines.

    Prefers the X-Forwarded-For header verbatim (it may hold a comma-separated
    chain and is client-controllable), otherwise the socket peer address,
    otherwise "unknown".
    """
    peer = request.remote_addr or "unknown"
    return request.headers.get("x-forwarded-for", peer)
def _err_details(e: Exception) -> dict:
d = {"type": type(e).__name__, "repr": repr(e)}
for k in ["status_code", "body", "message", "response", "details"]:
if hasattr(e, k):
try:
d[k] = getattr(e, k)
except Exception:
pass
return d
def _get_whisper_model() -> WhisperModel:
    """
    Return the process-wide faster-whisper model, loading it on first call.

    The model is built from WHISPER_MODEL_NAME / WHISPER_DEVICE /
    WHISPER_COMPUTE_TYPE and cached in the module global ``_whisper_model``.
    """
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model
    print(
        f"[whisper] loading model={WHISPER_MODEL_NAME} "
        f"device={WHISPER_DEVICE} compute_type={WHISPER_COMPUTE_TYPE}"
    )
    _whisper_model = WhisperModel(
        WHISPER_MODEL_NAME,
        compute_type=WHISPER_COMPUTE_TYPE,
        device=WHISPER_DEVICE,
    )
    print("[whisper] loaded")
    return _whisper_model
def _clean_reply(text: str) -> str:
t = (text or "").strip()
if not t:
return t
if t.endswith(("...", "…", ",")):
t = t.rstrip(".,… ,").strip()
if t and t[-1] not in ".?!":
t += "."
return t
def _gemini_config() -> types.GenerateContentConfig:
    """
    Build the generation config used for every Gemini call.

    Sets SYSTEM_PROMPT as the system instruction and THINKING_LEVEL as the
    thinking level. It also sets the block threshold of the four harm
    categories listed below to OFF.
    """
    # max_output_tokens / temperature are intentionally not set (earlier
    # experiments used 256 / 0.7); add them to the config below to tune.
    unfiltered = (
        types.HarmCategory.HARM_CATEGORY_HARASSMENT,
        types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    safety = [
        types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
        for category in unfiltered
    ]
    persona = [types.Part.from_text(text=SYSTEM_PROMPT)]
    thinking = types.ThinkingConfig(thinking_level=THINKING_LEVEL)
    return types.GenerateContentConfig(
        system_instruction=persona,
        thinking_config=thinking,
        safety_settings=safety,
    )
# Serializes conversational turns on the shared HISTORY. Waitress serves
# requests on several threads. Without this lock, two concurrent turns would
# interleave in the transcript, and the failure path could drop the *other*
# request's user message.
_HISTORY_LOCK = threading.Lock()


def llm_chat(user_text: str) -> str:
    """
    Run one conversational turn against Gemini using the shared HISTORY.

    The user message is appended to HISTORY and the bounded history is sent to
    MODEL. The cleaned reply is then appended as a "model" turn and returned.

    Args:
        user_text: The user's message; surrounding whitespace is stripped.

    Returns:
        The cleaned, non-empty reply text (see ``_clean_reply``).

    Raises:
        ValueError: if ``user_text`` is empty or whitespace-only.
        RuntimeError: if Gemini returns no usable text.
        Exception: any error from the Gemini client is re-raised.
        On every error the user message this call appended is removed again.
        A message evicted by that append (deque maxlen) is not restored.
    """
    user_text = (user_text or "").strip()
    if not user_text:
        raise ValueError("Missing 'text'")
    user_msg = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])
    with _HISTORY_LOCK:
        HISTORY.append(user_msg)
        try:
            contents = list(HISTORY)
            # maxlen eviction drops the oldest message, which can leave a
            # "model" turn at the front; open the request on a user turn.
            while contents and contents[0].role != "user":
                contents.pop(0)
            resp = client.models.generate_content(
                model=MODEL,
                contents=contents,
                config=_gemini_config(),
            )
            reply_text = _clean_reply(resp.text)
            if not reply_text:
                # Storing an empty model turn would resend it on every later
                # request, and callers would hand an empty string to TTS.
                raise RuntimeError("Gemini returned an empty reply")
            HISTORY.append(types.Content(role="model", parts=[types.Part.from_text(text=reply_text)]))
            return reply_text
        except Exception:
            # Remove exactly the message this call added (identity check),
            # unless something else (e.g. /v1/reset) already removed it.
            if HISTORY and HISTORY[-1] is user_msg:
                HISTORY.pop()
            raise
def _tts_to_mp3_file(text: str) -> tuple[str, int]:
    """
    Synthesize ``text`` with ElevenLabs and write the audio to a temp file.

    Uses ELEVEN_VOICE_ID / ELEVEN_MODEL_ID / ELEVEN_OUTPUT_FORMAT with a fixed
    speaking speed of 0.8. The file is created with ``delete=False``, so the
    caller owns it and must remove it.

    Returns: (mp3_path, tts_ms). tts_ms covers the API call plus the write.
    Raises: RuntimeError if no ElevenLabs client is configured. Also raises any
    error from the ElevenLabs SDK or from writing the file; in that case the
    temp file is removed before re-raising.
    """
    if eleven is None:
        raise RuntimeError("Server missing ELEVEN_API_KEY")
    t0 = time.time()
    audio_stream = eleven.text_to_speech.convert(
        text=text,
        voice_id=ELEVEN_VOICE_ID,
        model_id=ELEVEN_MODEL_ID,
        output_format=ELEVEN_OUTPUT_FORMAT,
        voice_settings=VoiceSettings(
            speed=0.8,
        ),
    )
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_out:
        mp3_path = tmp_out.name
    try:
        # The SDK result is consumed here, so API errors may surface here too.
        save(audio_stream, mp3_path)
    except BaseException:
        # Callers only receive the path on success; clean up or it leaks.
        try:
            os.remove(mp3_path)
        except OSError:
            pass
        raise
    tts_ms = int((time.time() - t0) * 1000)
    return mp3_path, tts_ms
# -------------------------
# Endpoints
# -------------------------
@app.get("/health")
def health():
    """Unauthenticated health check reporting model, memory and TTS config."""
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    mem = len(HISTORY)
    print(f"[/health] {stamp} ip={_client_ip()} mem={mem}/{MAX_MESSAGES}")
    payload = {
        "ok": True,
        "model": MODEL,
        "thinking_level": THINKING_LEVEL,
        "memory_messages": mem,
        "max_messages": MAX_MESSAGES,
        "whisper_model": WHISPER_MODEL_NAME,
        "whisper_device": WHISPER_DEVICE,
        "whisper_compute_type": WHISPER_COMPUTE_TYPE,
        "eleven_ok": bool(ELEVEN_API_KEY),
        "eleven_voice_id": ELEVEN_VOICE_ID,
        "eleven_model_id": ELEVEN_MODEL_ID,
        "eleven_output_format": ELEVEN_OUTPUT_FORMAT,
    }
    return jsonify(payload)
@app.post("/v1/chat")
def chat_text():
    """
    Text chat: JSON body ``{"text": "..."}`` in, persona reply as JSON out.

    Responses:
        200 ``{input, reply_text, model, memory_messages, total_ms}``
        400 ``{"error": "Missing 'text'"}`` when ``text`` is absent, empty, or
            not a string, or the body is not a JSON object
        500 ``{"error": "Gemini call failed"}`` when the LLM turn fails
        401/500 from the shared-secret check
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp
    t0 = time.time()
    ip = _client_ip()
    data = request.get_json(silent=True)
    # Only a JSON object with a string "text" counts. Any other body (a list,
    # a number, "text": 5, ...) is treated as missing text, rather than
    # crashing on .get()/.strip() with an unhandled 500.
    raw_text = data.get("text") if isinstance(data, dict) else None
    user_text = raw_text.strip() if isinstance(raw_text, str) else ""
    print(f"[/v1/chat] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip} mem_before={len(HISTORY)}/{MAX_MESSAGES}")
    if not user_text:
        print(f"[/v1/chat] ERROR missing text ip={ip}")
        return jsonify({"error": "Missing 'text'"}), 400
    print(f"[/v1/chat] user_text_len={len(user_text)} user_text={user_text!r}")
    try:
        reply_text = llm_chat(user_text)
    except Exception as e:
        dt_ms = int((time.time() - t0) * 1000)
        # Same diagnostic format as the other endpoints' FAIL lines.
        details = _err_details(e)
        print(f"[/v1/chat] FAIL ip={ip} total_ms={dt_ms} mem_now={len(HISTORY)}/{MAX_MESSAGES} details={json.dumps(details, default=str)[:2000]}")
        return jsonify({"error": "Gemini call failed"}), 500
    dt_ms = int((time.time() - t0) * 1000)
    print(f"[/v1/chat] bot_reply={reply_text!r}")
    print(f"[/v1/chat] END ip={ip} total_ms={dt_ms} mem_now={len(HISTORY)}/{MAX_MESSAGES}")
    return jsonify({
        "input": user_text,
        "reply_text": reply_text,
        "model": MODEL,
        "memory_messages": len(HISTORY),
        "total_ms": dt_ms,
    })
@app.post("/v1/tts")
def tts_only():
    """
    JSON body: { "text": "hello" }
    Returns: audio/mpeg (mp3)
    Timing headers:
        X-TTS-MS, X-TOTAL-MS

    Errors:
        400 ``{"error": "Missing 'text'"}`` when ``text`` is absent, empty, or
            not a string, or the body is not a JSON object
        502 ``{"error", "details", "total_ms"}`` when synthesis fails
        401/500 from the shared-secret check
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp
    ip = _client_ip()
    t0 = time.time()
    data = request.get_json(silent=True)
    # Treat a non-object body or a non-string "text" as missing text instead
    # of crashing on .get()/.strip().
    raw_text = data.get("text") if isinstance(data, dict) else None
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    print(f"[/v1/tts] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip} text_len={len(text)}")
    if not text:
        return jsonify({"error": "Missing 'text'"}), 400
    mp3_path = None
    try:
        mp3_path, tts_ms = _tts_to_mp3_file(text)
        total_ms = int((time.time() - t0) * 1000)
        print(f"[/v1/tts] OK tts_ms={tts_ms} total_ms={total_ms}")
        resp = send_file(
            mp3_path,
            mimetype="audio/mpeg",
            as_attachment=False,
            download_name="andy.mp3",
            conditional=False,
        )
        resp.headers["X-TTS-MS"] = str(tts_ms)
        resp.headers["X-TOTAL-MS"] = str(total_ms)
        return resp
    except Exception as e:
        details = _err_details(e)
        total_ms = int((time.time() - t0) * 1000)
        print(f"[/v1/tts] FAIL total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
        return jsonify({"error": "ElevenLabs TTS failed", "details": details, "total_ms": total_ms}), 502
    finally:
        # NOTE(review): the temp file is unlinked before the response body is
        # streamed. This presumably relies on send_file having already opened
        # the path and on POSIX unlink-while-open semantics; confirm it holds
        # if X-Sendfile or a non-POSIX host is ever used.
        if mp3_path:
            try:
                os.remove(mp3_path)
            except Exception:
                pass
@app.post("/v1/utterance")
def utterance_audio_to_audio():
    """
    Voice round-trip: speech in, persona speech out.

    Accepts: multipart/form-data with field "audio" containing a .wav file
    Returns: audio/mpeg (mp3)
    Timing headers:
        X-STT-MS, X-LLM-MS, X-TTS-MS, X-TOTAL-MS

    Pipeline: faster-whisper transcription -> llm_chat (shared HISTORY) ->
    ElevenLabs TTS. Other responses:
        200 JSON ``{"error": "Empty transcript", ...}`` when no speech is found
        400 when the "audio" field is missing or the filename is not .wav
        500 when ELEVEN_API_KEY is not configured, or any other step fails
        502 JSON incl. transcript/reply_text when only the TTS step fails
        401/500 from the shared-secret check
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp
    t0 = time.time()
    ip = _client_ip()
    print(f"[/v1/utterance] START {time.strftime('%Y-%m-%d %H:%M:%S')} ip={ip}")
    if eleven is None:
        print("[/v1/utterance] ERROR missing ELEVEN_API_KEY")
        return jsonify({"error": "Server missing ELEVEN_API_KEY"}), 500
    if "audio" not in request.files:
        print(f"[/v1/utterance] ERROR missing file field 'audio' ip={ip}")
        return jsonify({"error": "Missing file field 'audio'"}), 400
    f = request.files["audio"]
    filename = (f.filename or "").strip() or "audio.wav"
    # Only the filename is checked here; the file content is not validated.
    if not filename.lower().endswith(".wav"):
        print(f"[/v1/utterance] ERROR non-wav filename={filename!r} ip={ip}")
        return jsonify({"error": "Please upload a .wav file"}), 400
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_in:
        wav_path = tmp_in.name
    mp3_path = None
    try:
        # Written inside the try so the finally below still removes the temp
        # file (and the error is reported as JSON) if saving the upload fails.
        f.save(wav_path)

        # ---- STT ----
        t_stt = time.time()
        model = _get_whisper_model()
        segments, _info = model.transcribe(
            wav_path,
            language=WHISPER_LANGUAGE,
            vad_filter=True,
            beam_size=1,
        )
        transcript = "".join(seg.text for seg in segments).strip()
        stt_ms = int((time.time() - t_stt) * 1000)
        print(f"[/v1/utterance] transcript_len={len(transcript)} stt_ms={stt_ms}")
        print(f"[/v1/utterance] transcript={transcript!r}")
        if not transcript:
            total_ms = int((time.time() - t0) * 1000)
            print(f"[/v1/utterance] EMPTY transcript total_ms={total_ms}")
            return jsonify({"error": "Empty transcript", "stt_ms": stt_ms, "total_ms": total_ms}), 200

        # ---- LLM ----
        t_llm = time.time()
        reply_text = llm_chat(transcript)
        llm_ms = int((time.time() - t_llm) * 1000)
        print(f"[/v1/utterance] reply_len={len(reply_text)} llm_ms={llm_ms}")
        print(f"[/v1/utterance] bot_reply={reply_text!r}")

        # ---- TTS ----
        try:
            mp3_path, tts_ms = _tts_to_mp3_file(reply_text)
        except Exception as e:
            details = _err_details(e)
            total_ms = int((time.time() - t0) * 1000)
            print(f"[/v1/utterance] TTS FAIL total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
            return jsonify({
                "error": "ElevenLabs TTS failed",
                "details": details,
                "transcript": transcript,
                "reply_text": reply_text,
                "stt_ms": stt_ms,
                "llm_ms": llm_ms,
                "total_ms": total_ms,
            }), 502

        total_ms = int((time.time() - t0) * 1000)
        print(f"[/v1/utterance] tts_ms={tts_ms} total_ms={total_ms}")
        print(f"[/v1/utterance] END ip={ip}")
        resp = send_file(
            mp3_path,
            mimetype="audio/mpeg",
            as_attachment=False,
            download_name="andy.mp3",
            conditional=False,
        )
        resp.headers["X-STT-MS"] = str(stt_ms)
        resp.headers["X-LLM-MS"] = str(llm_ms)
        resp.headers["X-TTS-MS"] = str(tts_ms)
        resp.headers["X-TOTAL-MS"] = str(total_ms)
        return resp
    except Exception as e:
        total_ms = int((time.time() - t0) * 1000)
        details = _err_details(e)
        print(f"[/v1/utterance] FAIL ip={ip} total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
        return jsonify({"error": "Utterance pipeline failed", "details": details, "total_ms": total_ms}), 500
    finally:
        # Best-effort cleanup of both temp files; removal errors are ignored.
        try:
            os.remove(wav_path)
        except Exception:
            pass
        if mp3_path:
            try:
                os.remove(mp3_path)
            except Exception:
                pass
@app.post("/v1/reset")
def reset():
    """Forget the shared conversation history (requires the shared secret)."""
    denied = _require_auth()
    if denied:
        return denied
    caller = _client_ip()
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    previous = len(HISTORY)
    print(f"[/v1/reset] {stamp} ip={caller} clearing mem (was {previous}/{MAX_MESSAGES})")
    HISTORY.clear()
    return jsonify({"ok": True, "memory_messages": 0})
# -------------------------
# Startup
# -------------------------
if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    banner = (
        f"[startup] model={MODEL} thinking_level={THINKING_LEVEL} max_messages={MAX_MESSAGES} port={port}",
        f"[startup] whisper_model={WHISPER_MODEL_NAME} device={WHISPER_DEVICE} compute={WHISPER_COMPUTE_TYPE}",
        f"[startup] eleven_ok={bool(ELEVEN_API_KEY)} voice={ELEVEN_VOICE_ID} model={ELEVEN_MODEL_ID} out={ELEVEN_OUTPUT_FORMAT}",
    )
    for banner_line in banner:
        print(banner_line)
    # Listen on all interfaces; PORT defaults to 7860.
    serve(app, host="0.0.0.0", port=port)