import hmac
import json
import os
import tempfile
import threading
import time
from collections import deque

from elevenlabs import VoiceSettings, save
from elevenlabs.client import ElevenLabs
from faster_whisper import WhisperModel
from flask import Flask, jsonify, request, send_file
from google import genai
from google.genai import types
from waitress import serve

app = Flask(__name__)

# -------------------------
# Config
# -------------------------
MODEL = os.environ.get("GEMINI_MODEL", "gemini-3-flash-preview")
THINKING_LEVEL = os.environ.get("GEMINI_THINKING_LEVEL", "HIGH")

SYSTEM_PROMPT = (
    "You should respond like Andy Warhol.\n"
    "Respond in 1-3 sentences and less than 200 characters.\n"
    "You should say uh 0-2 times per response, it can be in different parts of the response.\n"
    "Don't repeat yourself too much.\n"
)

# -------------------------
# Auth
# -------------------------
API_PASSWORD = os.environ.get("API_PASSWORD", "").strip()


def _require_auth():
    """Check the shared secret sent by the client.

    The client must send the secret in the ``X-API-PASSWORD`` header.

    Returns:
        ``None`` when the request is authorized; otherwise a
        ``(json_response, status)`` tuple the endpoint should return as-is
        (500 when the server has no secret configured, 401 on a missing or
        wrong password).
    """
    if not API_PASSWORD:
        # If you forget to set the secret, fail closed.
        return jsonify({"error": "Server missing API_PASSWORD secret"}), 500

    provided = (request.headers.get("X-API-PASSWORD") or "").strip()
    # Constant-time comparison so response timing does not reveal how much of
    # the secret matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not provided or not hmac.compare_digest(
        provided.encode("utf-8"), API_PASSWORD.encode("utf-8")
    ):
        return jsonify({"error": "Unauthorized"}), 401
    return None


# STT (we chose base.en)
WHISPER_MODEL_NAME = os.environ.get("WHISPER_MODEL", "base.en")
WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cpu")
WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE", "int8")
WHISPER_LANGUAGE = os.environ.get("WHISPER_LANGUAGE", "en")

# ElevenLabs
ELEVEN_API_KEY = os.environ.get("ELEVEN_API_KEY")
ELEVEN_VOICE_ID = os.environ.get("ELEVEN_VOICE_ID", "kYLLcRUC2uzrEp0Jr2HT")
ELEVEN_MODEL_ID = os.environ.get("ELEVEN_MODEL_ID", "eleven_multilingual_v2")
ELEVEN_OUTPUT_FORMAT = os.environ.get("ELEVEN_OUTPUT_FORMAT", "mp3_44100_128")

# Gemini client (expects GEMINI_API_KEY set as a HF Space Secret)
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

# Eleven client (None when no API key is configured)
eleven = ElevenLabs(api_key=ELEVEN_API_KEY) if ELEVEN_API_KEY else None

# ---- Memory (global, RAM-only, survives refresh, resets on Space restart) ----
MAX_MESSAGES = 20
HISTORY = deque(maxlen=MAX_MESSAGES)

# ---- Whisper model (lazy init) ----
_whisper_model = None
# Guards the lazy load: the app is served by waitress, which handles requests
# on worker threads, so two first requests could otherwise both load the model.
_whisper_lock = threading.Lock()


# -------------------------
# Helpers
# -------------------------
def _client_ip() -> str:
    """Return the ``x-forwarded-for`` header value, else the remote address."""
    return request.headers.get("x-forwarded-for", request.remote_addr or "unknown")


def _json_safe(value):
    """Return *value* reduced to plain JSON-compatible data.

    Anything ``json`` cannot encode is stringified; if the value still cannot
    be encoded (e.g. unsupported dict keys) its ``repr`` is returned instead.
    """
    try:
        return json.loads(json.dumps(value, default=str))
    except (TypeError, ValueError):
        return repr(value)


def _err_details(e: Exception) -> dict:
    """Collect debuggable, JSON-serializable details about an exception.

    Always includes the exception type name and ``repr``; additionally copies
    the ``status_code``, ``body``, ``message``, ``response`` and ``details``
    attributes when the exception exposes them.

    The result is passed straight to ``jsonify`` by the endpoints, so every
    value is normalized to JSON-safe data here; a raw attribute such as an
    HTTP response object would otherwise make the error handler itself fail.
    """
    d = {"type": type(e).__name__, "repr": repr(e)}
    missing = object()
    for k in ("status_code", "body", "message", "response", "details"):
        try:
            value = getattr(e, k, missing)
        except Exception:
            # A misbehaving property must not mask the original error.
            continue
        if value is not missing:
            d[k] = _json_safe(value)
    return d


def _get_whisper_model() -> WhisperModel:
    """Return the shared Whisper model, loading it on first use."""
    global _whisper_model
    if _whisper_model is None:
        with _whisper_lock:
            # Re-check under the lock: another thread may have finished
            # loading while this one was waiting.
            if _whisper_model is None:
                print(
                    f"[whisper] loading model={WHISPER_MODEL_NAME} "
                    f"device={WHISPER_DEVICE} compute_type={WHISPER_COMPUTE_TYPE}"
                )
                _whisper_model = WhisperModel(
                    WHISPER_MODEL_NAME,
                    device=WHISPER_DEVICE,
                    compute_type=WHISPER_COMPUTE_TYPE,
                )
                print("[whisper] loaded")
    return _whisper_model
def _clean_reply(text: str) -> str:
    """Tidy a model reply so it ends like a finished sentence.

    Strips surrounding whitespace, removes a trailing ellipsis/comma run, and
    appends a period when the text does not already end in ``.``, ``?`` or
    ``!``. Returns an empty string for empty/``None`` input.
    """
    t = (text or "").strip()
    if not t:
        return t
    if t.endswith(("...", "…", ",")):
        t = t.rstrip(".,… ,").strip()
    if t and t[-1] not in ".?!":
        t += "."
    return t


def _gemini_config() -> types.GenerateContentConfig:
    """Build the Gemini request config: system prompt, thinking level, safety."""
    harm_categories = (
        types.HarmCategory.HARM_CATEGORY_HARASSMENT,
        types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    return types.GenerateContentConfig(
        system_instruction=[types.Part.from_text(text=SYSTEM_PROMPT)],
        thinking_config=types.ThinkingConfig(thinking_level=THINKING_LEVEL),
        # max_output_tokens=256,
        # temperature=0.7,
        safety_settings=[
            types.SafetySetting(
                category=category,
                threshold=types.HarmBlockThreshold.OFF,
            )
            for category in harm_categories
        ],
    )


def llm_chat(user_text: str) -> str:
    """Send *user_text* to Gemini with the shared history and return the reply.

    The user turn and the model turn are added to ``HISTORY`` only after a
    successful call, so a failed request leaves the conversation memory
    exactly as it was (appending first would evict the oldest turn from the
    bounded deque, which a later rollback cannot restore).

    Raises:
        ValueError: if *user_text* is empty after stripping.
        RuntimeError: if the model returns no usable text.
        Exception: whatever the Gemini client raises is propagated.
    """
    user_text = (user_text or "").strip()
    if not user_text:
        raise ValueError("Missing 'text'")

    user_turn = types.Content(role="user", parts=[types.Part.from_text(text=user_text)])

    # Same window the bounded deque would hold after appending the user turn.
    contents = [*HISTORY, user_turn]
    if HISTORY.maxlen:
        contents = contents[-HISTORY.maxlen:]

    resp = client.models.generate_content(
        model=MODEL,
        contents=contents,
        config=_gemini_config(),
    )
    reply_text = _clean_reply(resp.text)
    if not reply_text:
        # Do not store or speak an empty reply.
        raise RuntimeError("Gemini returned an empty reply")

    HISTORY.append(user_turn)
    HISTORY.append(types.Content(role="model", parts=[types.Part.from_text(text=reply_text)]))
    return reply_text


def _ms_since(t0: float) -> int:
    """Return whole milliseconds elapsed since the ``time.time()`` value *t0*."""
    return int((time.time() - t0) * 1000)


def _stamp() -> str:
    """Return the current local time formatted for log lines."""
    return time.strftime('%Y-%m-%d %H:%M:%S')


def _remove_quietly(path) -> None:
    """Best-effort delete of a temp file; never raises on filesystem errors."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as err:
        # Cleanup must not turn a finished response into an error; just log.
        print(f"[cleanup] could not remove {path!r}: {err!r}")


def _request_text() -> str:
    """Return the stripped ``text`` field of the JSON body, or ``""``.

    A body that is not a JSON object, or a ``text`` value that is not a
    string, is treated as missing instead of raising.
    """
    data = request.get_json(silent=True)
    raw = data.get("text") if isinstance(data, dict) else None
    return raw.strip() if isinstance(raw, str) else ""


def _send_mp3(mp3_path: str, timing_headers: dict):
    """Build the audio/mpeg response for *mp3_path* with timing headers set."""
    resp = send_file(
        mp3_path,
        mimetype="audio/mpeg",
        as_attachment=False,
        download_name="andy.mp3",
        conditional=False,
    )
    for header, ms in timing_headers.items():
        resp.headers[header] = str(ms)
    return resp


def _tts_to_mp3_file(text: str) -> tuple[str, int]:
    """Synthesize *text* with ElevenLabs into a temporary mp3 file.

    Returns:
        (mp3_path, tts_ms). The caller owns the file and must delete it.

    Raises:
        RuntimeError: if no ElevenLabs client is configured.
        Exception: any synthesis/write failure; the temp file is removed
            before the error propagates.
    """
    if eleven is None:
        raise RuntimeError("Server missing ELEVEN_API_KEY")

    t0 = time.time()
    audio_stream = eleven.text_to_speech.convert(
        text=text,
        voice_id=ELEVEN_VOICE_ID,
        model_id=ELEVEN_MODEL_ID,
        output_format=ELEVEN_OUTPUT_FORMAT,
        voice_settings=VoiceSettings(
            speed=0.8,
        ),
    )

    # Reserve a path, then close the handle before writing to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_out:
        mp3_path = tmp_out.name
    try:
        save(audio_stream, mp3_path)
    except Exception:
        # The path never reaches the caller on failure, so clean up here.
        _remove_quietly(mp3_path)
        raise

    tts_ms = _ms_since(t0)
    return mp3_path, tts_ms


# -------------------------
# Endpoints
# -------------------------
@app.get("/health")
def health():
    """Report configuration and memory usage (no authentication)."""
    print(f"[/health] {_stamp()} ip={_client_ip()} mem={len(HISTORY)}/{MAX_MESSAGES}")
    return jsonify({
        "ok": True,
        "model": MODEL,
        "thinking_level": THINKING_LEVEL,
        "memory_messages": len(HISTORY),
        "max_messages": MAX_MESSAGES,
        "whisper_model": WHISPER_MODEL_NAME,
        "whisper_device": WHISPER_DEVICE,
        "whisper_compute_type": WHISPER_COMPUTE_TYPE,
        "eleven_ok": bool(ELEVEN_API_KEY),
        "eleven_voice_id": ELEVEN_VOICE_ID,
        "eleven_model_id": ELEVEN_MODEL_ID,
        "eleven_output_format": ELEVEN_OUTPUT_FORMAT,
    })


@app.post("/v1/chat")
def chat_text():
    """
    JSON body: { "text": "hello" }
    Returns: JSON with the model reply, model name, memory size and timing.
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp

    t0 = time.time()
    ip = _client_ip()
    user_text = _request_text()

    print(f"[/v1/chat] START {_stamp()} ip={ip} mem_before={len(HISTORY)}/{MAX_MESSAGES}")

    if not user_text:
        print(f"[/v1/chat] ERROR missing text ip={ip}")
        return jsonify({"error": "Missing 'text'"}), 400

    print(f"[/v1/chat] user_text_len={len(user_text)} user_text={user_text!r}")

    try:
        reply_text = llm_chat(user_text)
    except Exception as e:
        dt_ms = _ms_since(t0)
        print("Gemini error:", repr(e))
        print(f"[/v1/chat] FAIL ip={ip} total_ms={dt_ms} mem_now={len(HISTORY)}/{MAX_MESSAGES}")
        return jsonify({"error": "Gemini call failed"}), 500

    dt_ms = _ms_since(t0)
    print(f"[/v1/chat] bot_reply={reply_text!r}")
    print(f"[/v1/chat] END ip={ip} total_ms={dt_ms} mem_now={len(HISTORY)}/{MAX_MESSAGES}")
    return jsonify({
        "input": user_text,
        "reply_text": reply_text,
        "model": MODEL,
        "memory_messages": len(HISTORY),
        "total_ms": dt_ms,
    })


@app.post("/v1/tts")
def tts_only():
    """
    JSON body: { "text": "hello" }
    Returns: audio/mpeg (mp3)
    Timing headers: X-TTS-MS, X-TOTAL-MS
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp

    ip = _client_ip()
    t0 = time.time()
    text = _request_text()

    print(f"[/v1/tts] START {_stamp()} ip={ip} text_len={len(text)}")

    if not text:
        return jsonify({"error": "Missing 'text'"}), 400

    mp3_path = None
    try:
        mp3_path, tts_ms = _tts_to_mp3_file(text)
        total_ms = _ms_since(t0)
        print(f"[/v1/tts] OK tts_ms={tts_ms} total_ms={total_ms}")
        return _send_mp3(mp3_path, {"X-TTS-MS": tts_ms, "X-TOTAL-MS": total_ms})
    except Exception as e:
        details = _err_details(e)
        total_ms = _ms_since(t0)
        print(f"[/v1/tts] FAIL total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
        return jsonify({"error": "ElevenLabs TTS failed", "details": details, "total_ms": total_ms}), 502
    finally:
        # NOTE(review): the file is removed before the response body is
        # streamed; this presumably relies on send_file having already opened
        # it (fine with POSIX unlink semantics) - confirm for other platforms.
        _remove_quietly(mp3_path)


@app.post("/v1/utterance")
def utterance_audio_to_audio():
    """
    Accepts: multipart/form-data with field "audio" containing a .wav file
    Returns: audio/mpeg (mp3)
    Timing headers: X-STT-MS, X-LLM-MS, X-TTS-MS, X-TOTAL-MS
    """
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp

    t0 = time.time()
    ip = _client_ip()
    print(f"[/v1/utterance] START {_stamp()} ip={ip}")

    if eleven is None:
        print("[/v1/utterance] ERROR missing ELEVEN_API_KEY")
        return jsonify({"error": "Server missing ELEVEN_API_KEY"}), 500

    if "audio" not in request.files:
        print(f"[/v1/utterance] ERROR missing file field 'audio' ip={ip}")
        return jsonify({"error": "Missing file field 'audio'"}), 400

    f = request.files["audio"]
    filename = (f.filename or "").strip() or "audio.wav"
    if not filename.lower().endswith(".wav"):
        print(f"[/v1/utterance] ERROR non-wav filename={filename!r} ip={ip}")
        return jsonify({"error": "Please upload a .wav file"}), 400

    wav_path = None
    mp3_path = None
    stt_ms = llm_ms = tts_ms = 0
    transcript = ""
    reply_text = ""

    try:
        # Create and fill the upload temp file inside the try so it is
        # cleaned up (and reported as JSON) even when saving fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_in:
            wav_path = tmp_in.name
        f.save(wav_path)

        # ---- STT ----
        t_stt = time.time()
        model = _get_whisper_model()
        segments, _info = model.transcribe(
            wav_path,
            language=WHISPER_LANGUAGE,
            vad_filter=True,
            beam_size=1,
        )
        transcript = "".join(seg.text for seg in segments).strip()
        stt_ms = _ms_since(t_stt)

        print(f"[/v1/utterance] transcript_len={len(transcript)} stt_ms={stt_ms}")
        print(f"[/v1/utterance] transcript={transcript!r}")

        if not transcript:
            total_ms = _ms_since(t0)
            print(f"[/v1/utterance] EMPTY transcript total_ms={total_ms}")
            # Reported with status 200 and an "error" field, not as a failure.
            return jsonify({"error": "Empty transcript", "stt_ms": stt_ms, "total_ms": total_ms}), 200

        # ---- LLM ----
        t_llm = time.time()
        reply_text = llm_chat(transcript)
        llm_ms = _ms_since(t_llm)

        print(f"[/v1/utterance] reply_len={len(reply_text)} llm_ms={llm_ms}")
        print(f"[/v1/utterance] bot_reply={reply_text!r}")

        # ---- TTS ----
        try:
            mp3_path, tts_ms = _tts_to_mp3_file(reply_text)
        except Exception as e:
            details = _err_details(e)
            total_ms = _ms_since(t0)
            print(f"[/v1/utterance] TTS FAIL total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
            # Still return the text results so the client can show them.
            return jsonify({
                "error": "ElevenLabs TTS failed",
                "details": details,
                "transcript": transcript,
                "reply_text": reply_text,
                "stt_ms": stt_ms,
                "llm_ms": llm_ms,
                "total_ms": total_ms,
            }), 502

        total_ms = _ms_since(t0)
        print(f"[/v1/utterance] tts_ms={tts_ms} total_ms={total_ms}")
        print(f"[/v1/utterance] END ip={ip}")

        return _send_mp3(mp3_path, {
            "X-STT-MS": stt_ms,
            "X-LLM-MS": llm_ms,
            "X-TTS-MS": tts_ms,
            "X-TOTAL-MS": total_ms,
        })

    except Exception as e:
        total_ms = _ms_since(t0)
        details = _err_details(e)
        print(f"[/v1/utterance] FAIL ip={ip} total_ms={total_ms} details={json.dumps(details, default=str)[:2000]}")
        return jsonify({"error": "Utterance pipeline failed", "details": details, "total_ms": total_ms}), 500

    finally:
        _remove_quietly(wav_path)
        _remove_quietly(mp3_path)


@app.post("/v1/reset")
def reset():
    """Clear the shared conversation memory."""
    auth_resp = _require_auth()
    if auth_resp:
        return auth_resp

    ip = _client_ip()
    print(f"[/v1/reset] {_stamp()} ip={ip} clearing mem (was {len(HISTORY)}/{MAX_MESSAGES})")
    HISTORY.clear()
    return jsonify({"ok": True, "memory_messages": 0})


# -------------------------
# Startup
# -------------------------
if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    print(f"[startup] model={MODEL} thinking_level={THINKING_LEVEL} max_messages={MAX_MESSAGES} port={port}")
    print(f"[startup] whisper_model={WHISPER_MODEL_NAME} device={WHISPER_DEVICE} compute={WHISPER_COMPUTE_TYPE}")
    print(f"[startup] eleven_ok={bool(ELEVEN_API_KEY)} voice={ELEVEN_VOICE_ID} model={ELEVEN_MODEL_ID} out={ELEVEN_OUTPUT_FORMAT}")
    serve(app, host="0.0.0.0", port=port)