| from flask import Flask, request, jsonify, send_file |
| import os, uuid, time |
| from pathlib import Path |
| import requests |
|
|
app = Flask(__name__)

# Hugging Face Inference API configuration; every value can be overridden
# from the environment (Space Settings -> Secrets/Vars).
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_STT_MODEL = os.getenv("HF_STT_MODEL", "openai/whisper-small")
HF_LLM_MODEL = os.getenv("HF_LLM_MODEL", "google/flan-t5-large")
HF_TTS_MODEL = os.getenv("HF_TTS_MODEL", "espnet/kan-bayashi_ljspeech_vits")

# Scratch directory holding uploaded audio and generated TTS wav files.
TMP = Path("tmp")
TMP.mkdir(exist_ok=True)

if not HF_TOKEN:
    print("WARNING: HF_TOKEN not set. Set it in Space Settings -> Secrets/Vars.")

# Default auth header sent on every Inference API request.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
|
|
def hf_inference_raw(model_id: str, input_data, headers_extra=None, as_json=True, timeout=120):
    """POST ``input_data`` to the HF Inference API endpoint for ``model_id``.

    Bytes-like payloads are sent as the raw request body; anything else is
    serialized as JSON. When ``as_json`` is true the parsed JSON is returned
    (falling back to the response text if parsing fails); otherwise the raw
    response bytes are returned. Raises ``requests.HTTPError`` on non-2xx.
    """
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    req_headers = dict(HEADERS)
    req_headers.update(headers_extra or {})
    # Raw bytes go in the body as-is; everything else is JSON-encoded.
    if isinstance(input_data, (bytes, bytearray)):
        payload_kwargs = {"data": input_data}
    else:
        payload_kwargs = {"json": input_data}
    resp = requests.post(url, headers=req_headers, timeout=timeout, **payload_kwargs)
    resp.raise_for_status()
    if not as_json:
        return resp.content
    try:
        return resp.json()
    except Exception:
        # Some models answer with plain text even when JSON was expected.
        return resp.text
|
|
def stt_from_wav_bytes(wav_bytes: bytes):
    """Transcribe WAV audio bytes via the configured STT model.

    Best-effort: returns the recognized text, or "" if the request fails.
    """
    try:
        result = hf_inference_raw(
            HF_STT_MODEL,
            wav_bytes,
            headers_extra={"Content-Type": "audio/wav"},
            as_json=True,
            timeout=300,
        )
    except Exception as e:
        print("STT error:", e)
        return ""
    # Normalize the API's possible response shapes to a plain string.
    if isinstance(result, str):
        return result
    if isinstance(result, dict) and "text" in result:
        return result["text"]
    return str(result)
|
|
def llm_query(prompt: str):
    """Ask the configured text-generation model to answer ``prompt``.

    Normalizes the several response shapes the Inference API may return
    (list of dicts, dict, bare string) into one string. On any error a
    canned Vietnamese apology is returned instead of raising.
    """
    try:
        out = hf_inference_raw(HF_LLM_MODEL, {"inputs": prompt}, as_json=True, timeout=120)
    except Exception as e:
        print("LLM error:", e)
        return "Xin lỗi, tôi gặp lỗi khi trả lời."
    if isinstance(out, list):
        if not out:
            return str(out)
        head = out[0]
        if isinstance(head, dict) and "generated_text" in head:
            return head["generated_text"]
        return str(head)
    if isinstance(out, dict):
        for key in ("generated_text", "text"):
            if key in out:
                return out[key]
        return str(out)
    return out if isinstance(out, str) else str(out)
|
|
def tts_to_wav_bytes(text: str):
    """Synthesize speech for ``text`` via the configured TTS model.

    Best-effort: returns the raw WAV bytes from the API, or b"" if the
    request fails.
    """
    try:
        # Route through the shared helper (consistent with the STT/LLM
        # paths) instead of duplicating the request code; as_json=False
        # yields the raw audio bytes.
        return hf_inference_raw(HF_TTS_MODEL, {"inputs": text}, as_json=False, timeout=120)
    except Exception as e:
        # requests.HTTPError carries the failed response; other exception
        # types may not, so guard the attribute access explicitly.
        resp = getattr(e, "response", None)
        print("TTS error:", e, resp.text if resp is not None else None)
        return b""
|
|
@app.route("/health")
def health():
    """Liveness probe: always reports the service as up."""
    status = {"ok": True}
    return jsonify(status)
|
|
@app.route("/upload_audio", methods=["POST"])
def upload_audio():
    """Full voice pipeline: raw WAV body -> STT -> LLM reply -> TTS file.

    Responds with JSON holding the transcription, the model's reply, and
    the generated TTS filename ("" when synthesis failed). Errors come
    back as {"error": ...} with a 4xx/5xx status.
    """
    try:
        audio_bytes = request.get_data()
        if not audio_bytes:
            return jsonify({"error": "no audio data"}), 400

        # Keep a copy of the upload on disk (debugging/inspection).
        upload_name = f"{int(time.time())}_{uuid.uuid4().hex}.wav"
        (TMP / upload_name).write_bytes(audio_bytes)

        # Speech-to-text, with a placeholder when nothing was recognized.
        text = stt_from_wav_bytes(audio_bytes) or "(Không nhận được lời nói)"

        # Ask the LLM for a short Vietnamese answer.
        prompt = f"Bạn là trợ lý. Người dùng nói: '{text}'. Trả lời ngắn gọn bằng tiếng Việt."
        reply = llm_query(prompt)

        # Optionally synthesize the reply; empty filename signals failure.
        tts_fname = ""
        audio_reply = tts_to_wav_bytes(reply)
        if audio_reply:
            tts_fname = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / tts_fname).write_bytes(audio_reply)

        return jsonify({"text": text, "reply": reply, "tts_file": tts_fname})
    except Exception as e:
        print("upload_audio error:", e)
        return jsonify({"error": str(e)}), 500
|
|
@app.route("/ask", methods=["POST"])
def ask():
    """Text-only endpoint: JSON {"text": ...} -> LLM reply (+ optional TTS file)."""
    try:
        payload = request.get_json(force=True)
        question = payload.get("text", "")
        if not question:
            return jsonify({"error": "no text"}), 400

        reply = llm_query(question)

        # Synthesize the reply when possible; "" means no audio available.
        tts_fname = ""
        audio = tts_to_wav_bytes(reply)
        if audio:
            tts_fname = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / tts_fname).write_bytes(audio)
        return jsonify({"text": reply, "tts_file": tts_fname})
    except Exception as e:
        print("ask error:", e)
        return jsonify({"error": str(e)}), 500
|
|
@app.route("/tts/<fname>", methods=["GET"])
def serve_tts(fname):
    """Serve a previously generated TTS wav file from the scratch directory.

    Only bare filenames are accepted: any value whose basename differs from
    the raw name (e.g. "..", absolute paths, or names containing a path
    separator after decoding) is rejected, so a crafted URL cannot read
    files outside TMP.
    """
    # Path-traversal guard: the requested name must be a plain basename.
    if not fname or fname != Path(fname).name or fname in (".", ".."):
        return "Not found", 404
    p = TMP / fname
    if not p.exists():
        return "Not found", 404
    return send_file(str(p), mimetype="audio/wav", as_attachment=False)
|
|
if __name__ == "__main__":
    # Hugging Face Spaces conventionally exposes port 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, threaded=True)
|
|