from flask import Flask, request, jsonify, send_file
import os, uuid, time
from pathlib import Path
import requests
app = Flask(__name__)

# --- Config via environment variables in Space settings ---
HF_TOKEN = os.getenv("HF_TOKEN", "")  # required
HF_STT_MODEL = os.getenv("HF_STT_MODEL", "openai/whisper-small")  # speech-to-text
HF_LLM_MODEL = os.getenv("HF_LLM_MODEL", "google/flan-t5-large")  # text generation
HF_TTS_MODEL = os.getenv("HF_TTS_MODEL", "espnet/kan-bayashi_ljspeech_vits")  # text-to-speech (example)

# Scratch directory for uploaded audio and generated TTS files.
TMP = Path("tmp")
TMP.mkdir(exist_ok=True)

# Warn early rather than fail: without a token the API calls will likely be rejected.
if not HF_TOKEN:
    print("WARNING: HF_TOKEN not set. Set it in Space Settings -> Secrets/Vars.")

# Shared auth header for every HF Inference API call below.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
def hf_inference_raw(model_id: str, input_data, headers_extra=None, as_json=True, timeout=120):
    """POST ``input_data`` to the HF Inference API for ``model_id``.

    Bytes/bytearray payloads are sent as a raw request body; anything else
    is serialized as JSON. Raises for non-2xx responses. When ``as_json``
    is true, returns the parsed JSON (falling back to the response text if
    parsing fails); otherwise returns the raw response bytes.
    """
    endpoint = f"https://api-inference.huggingface.co/models/{model_id}"
    request_headers = dict(HEADERS)
    if headers_extra:
        request_headers.update(headers_extra)
    if isinstance(input_data, (bytes, bytearray)):
        body_kwargs = {"data": input_data}
    else:
        body_kwargs = {"json": input_data}
    response = requests.post(endpoint, headers=request_headers, timeout=timeout, **body_kwargs)
    response.raise_for_status()
    if not as_json:
        return response.content
    try:
        return response.json()
    except Exception:
        return response.text
def stt_from_wav_bytes(wav_bytes: bytes):
    """Transcribe WAV audio bytes with the configured HF STT model.

    Returns the transcript text, or "" if the request fails.
    """
    # Many whisper-based HF endpoints accept audio/wav or audio/mpeg.
    try:
        result = hf_inference_raw(
            HF_STT_MODEL,
            wav_bytes,
            headers_extra={"Content-Type": "audio/wav"},
            as_json=True,
            timeout=300,
        )
    except Exception as e:
        print("STT error:", e)
        return ""
    if isinstance(result, dict) and "text" in result:
        return result["text"]
    if isinstance(result, str):
        return result
    # fallback: stringify whatever shape came back
    return str(result)
def llm_query(prompt: str):
    """Send ``prompt`` to the configured HF text-generation model.

    Normalizes the common response shapes (list of dicts, dict, plain
    string) down to a single reply string. Returns a fixed Vietnamese
    apology message if the request fails.
    """
    try:
        out = hf_inference_raw(HF_LLM_MODEL, {"inputs": prompt}, as_json=True, timeout=120)
    except Exception as e:
        print("LLM error:", e)
        return "Xin lỗi, tôi gặp lỗi khi trả lời."
    # Handle the common response shapes.
    if isinstance(out, list) and len(out) > 0:
        head = out[0]
        if isinstance(head, dict) and "generated_text" in head:
            return head["generated_text"]
        return str(head)
    if isinstance(out, dict):
        for key in ("generated_text", "text"):
            if key in out:
                return out[key]
        return str(out)
    if isinstance(out, str):
        return out
    return str(out)
def tts_to_wav_bytes(text: str):
    """Synthesize speech for ``text`` with the configured HF TTS model.

    Returns the raw audio bytes (typically WAV) on success, or b"" on
    any failure.
    """
    try:
        # Reuse the shared request helper for consistency with STT/LLM;
        # as_json=False returns the raw audio bytes from the response.
        return hf_inference_raw(HF_TTS_MODEL, {"inputs": text}, as_json=False, timeout=120)
    except Exception as e:
        # Include the server's error body only when the exception actually
        # carries a response (requests.HTTPError does). The old code printed
        # a literal None for exceptions without one.
        detail = getattr(getattr(e, "response", None), "text", "")
        print("TTS error:", e, detail)
        return b""
@app.route("/health")
def health():
    """Liveness probe: always answers with an OK JSON payload."""
    status = {"ok": True}
    return jsonify(status)
@app.route("/upload_audio", methods=["POST"])
def upload_audio():
    """Full voice pipeline: raw WAV body -> STT -> LLM reply -> TTS.

    Returns JSON with the transcript (``text``), the generated reply
    (``reply``), and the filename of the synthesized audio (``tts_file``,
    empty string when TTS produced nothing). 400 on empty body, 500 on
    unexpected errors.
    """
    try:
        wav_data = request.get_data()
        if not wav_data:
            return jsonify({"error": "no audio data"}), 400
        # Keep a copy of the incoming audio in tmp/.
        incoming_name = f"{int(time.time())}_{uuid.uuid4().hex}.wav"
        (TMP / incoming_name).write_bytes(wav_data)
        # 1) Speech-to-text (with a placeholder when nothing was recognized).
        transcript = stt_from_wav_bytes(wav_data) or "(Không nhận được lời nói)"
        # 2) LLM reply.
        prompt = f"Bạn là trợ lý. Người dùng nói: '{transcript}'. Trả lời ngắn gọn bằng tiếng Việt."
        reply = llm_query(prompt)
        # 3) Text-to-speech; only persist a file when synthesis succeeded.
        audio = tts_to_wav_bytes(reply)
        audio_name = ""
        if audio:
            audio_name = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / audio_name).write_bytes(audio)
        return jsonify({"text": transcript, "reply": reply, "tts_file": audio_name})
    except Exception as e:
        print("upload_audio error:", e)
        return jsonify({"error": str(e)}), 500
@app.route("/ask", methods=["POST"])
def ask():
    """Text-only pipeline: JSON {"text": ...} -> LLM reply -> TTS.

    Returns JSON with the reply (``text``) and the synthesized audio
    filename (``tts_file``, empty when TTS failed). 400 on missing text,
    500 on unexpected errors.
    """
    try:
        payload = request.get_json(force=True)
        question = payload.get("text", "")
        if not question:
            return jsonify({"error": "no text"}), 400
        reply = llm_query(question)
        speech = tts_to_wav_bytes(reply)
        speech_name = ""
        if speech:
            speech_name = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / speech_name).write_bytes(speech)
        return jsonify({"text": reply, "tts_file": speech_name})
    except Exception as e:
        print("ask error:", e)
        return jsonify({"error": str(e)}), 500
@app.route("/tts/<fname>", methods=["GET"])
def serve_tts(fname):
    """Serve a previously generated TTS file from tmp/ as audio/wav.

    ``fname`` comes straight from the URL (untrusted input): resolve the
    joined path and reject anything that does not land as a regular file
    directly under the tmp directory (path-traversal hardening, e.g. "..").
    """
    base = TMP.resolve()
    p = (TMP / fname).resolve()
    if base not in p.parents or not p.is_file():
        return "Not found", 404
    return send_file(str(p), mimetype="audio/wav", as_attachment=False)
if __name__ == "__main__":
    # PORT is injected by the hosting environment; 7860 is the HF Spaces default.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, threaded=True)