"""Flask service for a Hugging Face Space: voice assistant pipeline.

Accepts raw audio, transcribes it (STT), generates a short Vietnamese
reply (LLM), synthesizes speech for the reply (TTS), and serves the
resulting WAV files. All model calls go through the public HF
Inference API using a bearer token.

Configuration (set in Space Settings -> Secrets/Vars):
    HF_TOKEN      -- HF API token (required)
    HF_STT_MODEL  -- speech-to-text model id
    HF_LLM_MODEL  -- text-generation model id
    HF_TTS_MODEL  -- text-to-speech model id
    PORT          -- listen port (default 7860)
"""

import os
import time
import uuid
from pathlib import Path

import requests
from flask import Flask, jsonify, request, send_file

app = Flask(__name__)

# --- Config via environment variables in Space settings ---
HF_TOKEN = os.getenv("HF_TOKEN", "")  # required
HF_STT_MODEL = os.getenv("HF_STT_MODEL", "openai/whisper-small")  # speech-to-text
HF_LLM_MODEL = os.getenv("HF_LLM_MODEL", "google/flan-t5-large")  # text generation
HF_TTS_MODEL = os.getenv("HF_TTS_MODEL", "espnet/kan-bayashi_ljspeech_vits")  # text-to-speech (example)

# Scratch directory for uploaded audio and generated TTS files.
# NOTE(review): files accumulate here forever — consider periodic cleanup.
TMP = Path("tmp")
TMP.mkdir(exist_ok=True)

if not HF_TOKEN:
    print("WARNING: HF_TOKEN not set. Set it in Space Settings -> Secrets/Vars.")

HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}


def hf_inference_raw(model_id: str, input_data, headers_extra=None, as_json=True, timeout=120):
    """POST *input_data* to the HF Inference API for *model_id*.

    Bytes payloads (e.g. audio) are sent as the raw request body;
    anything else is JSON-encoded.

    Returns parsed JSON when *as_json* is true (falling back to the
    response text when the body is not valid JSON), otherwise the raw
    response bytes.

    Raises requests.HTTPError on a non-2xx status, and the usual
    requests exceptions on network failure/timeout.
    """
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = HEADERS.copy()
    if headers_extra:
        headers.update(headers_extra)
    if isinstance(input_data, (bytes, bytearray)):
        resp = requests.post(url, headers=headers, data=input_data, timeout=timeout)
    else:
        resp = requests.post(url, headers=headers, json=input_data, timeout=timeout)
    resp.raise_for_status()
    if not as_json:
        return resp.content
    try:
        return resp.json()
    except ValueError:
        # Non-JSON 2xx body (some endpoints return plain text).
        return resp.text


def stt_from_wav_bytes(wav_bytes: bytes) -> str:
    """Transcribe WAV audio bytes via the configured STT model.

    Returns the transcript, or "" on any failure (best-effort: the
    caller substitutes a placeholder message).
    """
    # Many whisper-based HF endpoints accept audio/wav or audio/mpeg
    try:
        out = hf_inference_raw(
            HF_STT_MODEL,
            wav_bytes,
            headers_extra={"Content-Type": "audio/wav"},
            as_json=True,
            timeout=300,  # STT on long clips can be slow; allow extra time
        )
        if isinstance(out, dict) and "text" in out:
            return out["text"]
        if isinstance(out, str):
            return out
        # fallback: unknown response shape — stringify so the caller
        # still gets something to show
        return str(out)
    except Exception as e:
        print("STT error:", e)
        return ""


def llm_query(prompt: str) -> str:
    """Send *prompt* to the configured text-generation model.

    Normalizes the common HF response shapes (list of dicts with
    "generated_text", bare dict, bare string) to a plain string.
    Returns a canned Vietnamese apology on any failure.
    """
    payload = {"inputs": prompt}
    try:
        out = hf_inference_raw(HF_LLM_MODEL, payload, as_json=True, timeout=120)
        # handle common shapes
        if isinstance(out, list) and len(out) > 0:
            first = out[0]
            if isinstance(first, dict) and "generated_text" in first:
                return first["generated_text"]
            return str(first)
        if isinstance(out, dict):
            if "generated_text" in out:
                return out["generated_text"]
            if "text" in out:
                return out["text"]
            return str(out)
        if isinstance(out, str):
            return out
        return str(out)
    except Exception as e:
        print("LLM error:", e)
        return "Xin lỗi, tôi gặp lỗi khi trả lời."


def tts_to_wav_bytes(text: str) -> bytes:
    """Synthesize *text* with the configured TTS model.

    Returns the raw audio bytes, or b"" on any failure (best-effort).
    """
    payload = {"inputs": text}
    try:
        # Request raw bytes; reuse the shared helper instead of a
        # hand-rolled duplicate POST.
        return hf_inference_raw(HF_TTS_MODEL, payload, as_json=False, timeout=120)
    except Exception as e:
        # requests.HTTPError carries the failed response; surface its
        # body for debugging when present.
        resp = getattr(e, "response", None)
        print("TTS error:", e, resp.text if resp is not None else None)
        return b""


@app.route("/health")
def health():
    """Liveness probe."""
    return jsonify({"ok": True})


@app.route("/upload_audio", methods=["POST"])
def upload_audio():
    """Full pipeline: raw WAV body -> STT -> LLM -> TTS.

    Responds with {"text": transcript, "reply": llm_reply,
    "tts_file": filename-or-""}; the TTS file is retrievable via
    /tts/<fname>.
    """
    try:
        audio_bytes = request.get_data()
        if not audio_bytes:
            return jsonify({"error": "no audio data"}), 400
        # Persist the upload (timestamp + random suffix avoids collisions).
        fname = f"{int(time.time())}_{uuid.uuid4().hex}.wav"
        path = TMP / fname
        path.write_bytes(audio_bytes)
        # 1) STT
        text = stt_from_wav_bytes(audio_bytes)
        if not text:
            text = "(Không nhận được lời nói)"
        # 2) LLM
        prompt = f"Bạn là trợ lý. Người dùng nói: '{text}'. Trả lời ngắn gọn bằng tiếng Việt."
        reply = llm_query(prompt)
        # 3) TTS (best-effort; empty filename signals "no audio reply")
        tts_bytes = tts_to_wav_bytes(reply)
        tts_fname = ""
        if tts_bytes:
            tts_fname = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / tts_fname).write_bytes(tts_bytes)
        return jsonify({"text": text, "reply": reply, "tts_file": tts_fname})
    except Exception as e:
        print("upload_audio error:", e)
        return jsonify({"error": str(e)}), 500


@app.route("/ask", methods=["POST"])
def ask():
    """Text-only pipeline: JSON {"text": ...} -> LLM -> TTS."""
    try:
        data = request.get_json(force=True)
        text = data.get("text", "")
        if not text:
            return jsonify({"error": "no text"}), 400
        reply = llm_query(text)
        tts_bytes = tts_to_wav_bytes(reply)
        tts_fname = ""
        if tts_bytes:
            tts_fname = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / tts_fname).write_bytes(tts_bytes)
        return jsonify({"text": reply, "tts_file": tts_fname})
    except Exception as e:
        print("ask error:", e)
        return jsonify({"error": str(e)}), 500


# BUG FIX: the route previously declared no URL variable ("/tts/") while
# the view required `fname`, so every request 500'd with a TypeError.
@app.route("/tts/<fname>", methods=["GET"])
def serve_tts(fname):
    """Serve a previously generated TTS WAV file from TMP."""
    # Reject any path components so clients cannot escape TMP
    # (path-traversal hardening: fname comes straight from the URL).
    if os.path.basename(fname) != fname or fname in ("", ".", ".."):
        return "Not found", 404
    p = TMP / fname
    if not p.exists():
        return "Not found", 404
    return send_file(str(p), mimetype="audio/wav", as_attachment=False)


if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, threaded=True)