# app.py — Hugging Face Space backend (repo: kc, author: kcrobot40, rev 6c2a9cf)
from flask import Flask, request, jsonify, send_file
import os, uuid, time
from pathlib import Path
import requests
app = Flask(__name__)
# --- Config via environment variables in Space settings ---
# Model IDs are overridable per-deployment; defaults are public HF models.
HF_TOKEN = os.getenv("HF_TOKEN", "") # required
HF_STT_MODEL = os.getenv("HF_STT_MODEL", "openai/whisper-small") # speech-to-text
HF_LLM_MODEL = os.getenv("HF_LLM_MODEL", "google/flan-t5-large") # text generation
HF_TTS_MODEL = os.getenv("HF_TTS_MODEL", "espnet/kan-bayashi_ljspeech_vits") # text-to-speech (example)
# Scratch directory for uploaded audio and generated TTS files (served via /tts/<fname>).
TMP = Path("tmp")
TMP.mkdir(exist_ok=True)
# Warn (but keep running) without a token: API calls will be unauthenticated.
if not HF_TOKEN:
    print("WARNING: HF_TOKEN not set. Set it in Space Settings -> Secrets/Vars.")
# Shared auth header for every Inference API request below.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
def hf_inference_raw(model_id: str, input_data, headers_extra=None, as_json=True, timeout=120):
    """POST *input_data* to the HF Inference API endpoint for *model_id*.

    ``bytes``/``bytearray`` payloads (e.g. audio) are sent as the raw request
    body; anything else is JSON-encoded.

    Returns the decoded JSON when *as_json* is true (falling back to the
    response text when the body is not valid JSON), otherwise the raw
    response bytes.  Raises ``requests.HTTPError`` on non-2xx responses.
    """
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = HEADERS.copy()
    if headers_extra:
        headers.update(headers_extra)
    if isinstance(input_data, (bytes, bytearray)):
        resp = requests.post(url, headers=headers, data=input_data, timeout=timeout)
    else:
        resp = requests.post(url, headers=headers, json=input_data, timeout=timeout)
    resp.raise_for_status()
    if not as_json:
        return resp.content
    try:
        return resp.json()
    except ValueError:
        # Narrowed from a bare ``except Exception``: ``Response.json()``
        # raises a ValueError subclass on undecodable bodies; anything else
        # (bugs, interrupts) should propagate rather than be masked as text.
        return resp.text
def stt_from_wav_bytes(wav_bytes: bytes):
    """Transcribe WAV audio via the configured STT model; '' on failure."""
    # Many whisper-based HF endpoints accept audio/wav or audio/mpeg.
    try:
        result = hf_inference_raw(
            HF_STT_MODEL,
            wav_bytes,
            headers_extra={"Content-Type": "audio/wav"},
            as_json=True,
            timeout=300,
        )
    except Exception as e:
        # Best-effort: a failed transcription yields an empty string.
        print("STT error:", e)
        return ""
    if isinstance(result, dict) and "text" in result:
        return result["text"]
    if isinstance(result, str):
        return result
    # fallback: stringify whatever unexpected shape came back
    return str(result)
def llm_query(prompt: str):
    """Send *prompt* to the configured LLM and pull out a plain-text reply."""
    try:
        out = hf_inference_raw(HF_LLM_MODEL, {"inputs": prompt}, as_json=True, timeout=120)
    except Exception as e:
        print("LLM error:", e)
        return "Xin lỗi, tôi gặp lỗi khi trả lời."
    # The API returns several shapes depending on the model; handle each.
    if isinstance(out, list) and out:
        head = out[0]
        if isinstance(head, dict) and "generated_text" in head:
            return head["generated_text"]
        return str(head)
    if isinstance(out, dict):
        for key in ("generated_text", "text"):
            if key in out:
                return out[key]
        return str(out)
    if isinstance(out, str):
        return out
    return str(out)
def tts_to_wav_bytes(text: str):
    """Synthesize *text* with the configured TTS model.

    Returns the raw audio bytes from the Inference API, or ``b""`` on any
    error — TTS is best-effort and must not break the text reply.
    """
    payload = {"inputs": text}
    try:
        # Consistency: delegate to the shared helper (as_json=False returns
        # the raw response bytes) instead of duplicating the POST logic here.
        return hf_inference_raw(HF_TTS_MODEL, payload, as_json=False, timeout=120)
    except Exception as e:
        # requests.HTTPError carries the response; include its body when present.
        print("TTS error:", e, getattr(e, 'response', None) and e.response.text)
        return b""
@app.route("/health")
def health():
    """Liveness probe: always reports the service as up."""
    return jsonify(ok=True)
@app.route("/upload_audio", methods=["POST"])
def upload_audio():
    """Full voice pipeline: raw WAV request body -> STT -> LLM -> TTS.

    Responds with JSON holding the transcript, the assistant reply, and
    (when synthesis succeeded) the TTS filename servable via /tts/<fname>.
    """
    try:
        raw = request.get_data()
        if not raw:
            return jsonify({"error":"no audio data"}), 400
        # Persist the upload; NOTE(review): the saved file is never read
        # back in this handler — presumably kept for debugging.
        incoming = TMP / f"{int(time.time())}_{uuid.uuid4().hex}.wav"
        incoming.write_bytes(raw)
        # 1) speech-to-text (empty transcript gets a placeholder)
        transcript = stt_from_wav_bytes(raw) or "(Không nhận được lời nói)"
        # 2) language model
        prompt = f"Bạn là trợ lý. Người dùng nói: '{transcript}'. Trả lời ngắn gọn bằng tiếng Việt."
        reply = llm_query(prompt)
        # 3) text-to-speech (best-effort; empty filename signals failure)
        audio = tts_to_wav_bytes(reply)
        audio_name = ""
        if audio:
            audio_name = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / audio_name).write_bytes(audio)
        return jsonify({"text": transcript, "reply": reply, "tts_file": audio_name})
    except Exception as e:
        print("upload_audio error:", e)
        return jsonify({"error": str(e)}), 500
@app.route("/ask", methods=["POST"])
def ask():
    """Text-only pipeline: JSON {"text": ...} -> LLM -> optional TTS."""
    try:
        body = request.get_json(force=True)
        question = body.get("text", "")
        if not question:
            return jsonify({"error":"no text"}), 400
        answer = llm_query(question)
        # Best-effort synthesis; empty filename means TTS failed.
        audio = tts_to_wav_bytes(answer)
        audio_name = ""
        if audio:
            audio_name = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / audio_name).write_bytes(audio)
        return jsonify({"text": answer, "tts_file": audio_name})
    except Exception as e:
        print("ask error:", e)
        return jsonify({"error": str(e)}), 500
@app.route("/tts/<fname>", methods=["GET"])
def serve_tts(fname):
    """Serve a previously generated TTS file from the scratch directory.

    Security: *fname* comes straight from the URL, so anything that is not
    a plain filename (path separators, '..', drive prefixes) is rejected
    before it is joined onto TMP — prevents path traversal out of tmp/.
    """
    if not fname or Path(fname).name != fname:
        return "Not found", 404
    p = TMP / fname
    if not p.exists():
        return "Not found", 404
    return send_file(str(p), mimetype="audio/wav", as_attachment=False)
if __name__ == "__main__":
    # Spaces inject the listen port via $PORT; 7860 is the HF default.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")), threaded=True)