File size: 5,266 Bytes
6c2a9cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from flask import Flask, request, jsonify, send_file
import os, uuid, time
from pathlib import Path
import requests

app = Flask(__name__)

# --- Config via environment variables in Space settings ---
# HF_TOKEN is the Hugging Face API token used for every Inference API call.
HF_TOKEN = os.getenv("HF_TOKEN", "")        # required
HF_STT_MODEL = os.getenv("HF_STT_MODEL", "openai/whisper-small")   # speech-to-text
HF_LLM_MODEL = os.getenv("HF_LLM_MODEL", "google/flan-t5-large")   # text generation
HF_TTS_MODEL = os.getenv("HF_TTS_MODEL", "espnet/kan-bayashi_ljspeech_vits")  # text-to-speech (example)

# Scratch directory (relative to CWD) for uploaded audio and generated TTS files.
TMP = Path("tmp")
TMP.mkdir(exist_ok=True)

# Warn at startup rather than fail: requests will still be sent, but the HF
# API will reject them without a token.
if not HF_TOKEN:
    print("WARNING: HF_TOKEN not set. Set it in Space Settings -> Secrets/Vars.")

# Base auth headers shared by every Inference API request.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}

def hf_inference_raw(model_id: str, input_data, headers_extra=None, as_json=True, timeout=120):
    """POST *input_data* to the HF Inference API for *model_id*.

    Bytes/bytearray payloads are sent as the raw request body; anything
    else is serialized as JSON.  When *as_json* is true, returns the
    parsed JSON response (falling back to the response text if parsing
    fails); otherwise returns the raw response bytes.

    Raises requests.HTTPError on a non-2xx status.
    """
    endpoint = f"https://api-inference.huggingface.co/models/{model_id}"
    req_headers = dict(HEADERS)
    if headers_extra:
        req_headers.update(headers_extra)

    # Raw bytes go in the body as-is; everything else as a JSON payload.
    if isinstance(input_data, (bytes, bytearray)):
        kwargs = {"data": input_data}
    else:
        kwargs = {"json": input_data}
    resp = requests.post(endpoint, headers=req_headers, timeout=timeout, **kwargs)
    resp.raise_for_status()

    if not as_json:
        return resp.content
    try:
        return resp.json()
    except Exception:
        # Some endpoints return plain text even with a JSON request.
        return resp.text

def stt_from_wav_bytes(wav_bytes: bytes):
    """Transcribe WAV audio via the configured speech-to-text model.

    Returns the recognized text, or "" if the request fails for any
    reason (the error is logged, never raised).
    """
    # Many whisper-based HF endpoints accept audio/wav or audio/mpeg
    try:
        result = hf_inference_raw(
            HF_STT_MODEL,
            wav_bytes,
            headers_extra={"Content-Type": "audio/wav"},
            as_json=True,
            timeout=300,
        )
        if isinstance(result, dict) and "text" in result:
            return result["text"]
        # fallback
        return result if isinstance(result, str) else str(result)
    except Exception as e:
        print("STT error:", e)
        return ""

def llm_query(prompt: str):
    """Ask the configured LLM to complete *prompt*.

    Normalizes the common HF Inference API response shapes (list of
    dicts, single dict, plain string) down to one string.  On any
    failure, returns a canned Vietnamese apology instead of raising.
    """
    try:
        out = hf_inference_raw(HF_LLM_MODEL, {"inputs": prompt}, as_json=True, timeout=120)
        # handle common shapes
        if isinstance(out, list) and out:
            head = out[0]
            if isinstance(head, dict) and "generated_text" in head:
                return head["generated_text"]
            return str(head)
        if isinstance(out, dict):
            for key in ("generated_text", "text"):
                if key in out:
                    return out[key]
            return str(out)
        return out if isinstance(out, str) else str(out)
    except Exception as e:
        print("LLM error:", e)
        return "Xin lỗi, tôi gặp lỗi khi trả lời."

def tts_to_wav_bytes(text: str):
    """Synthesize *text* to speech via the configured TTS model.

    Returns the raw audio bytes produced by the HF Inference API, or
    b"" on any failure (the error is logged, never raised) so callers
    can degrade gracefully to a text-only reply.
    """
    try:
        # Reuse the shared helper instead of duplicating the HTTP logic;
        # as_json=False returns the raw response bytes (the audio).
        return hf_inference_raw(HF_TTS_MODEL, {"inputs": text}, as_json=False, timeout=120)
    except Exception as e:
        # For HTTPError the server's error body is useful; getattr chains
        # keep this safe when no response object is attached.
        body = getattr(getattr(e, "response", None), "text", None)
        print("TTS error:", e, body)
        return b""

@app.route("/health")
def health():
    """Liveness probe: always reports the service as up."""
    return jsonify(ok=True)

@app.route("/upload_audio", methods=["POST"])
def upload_audio():
    """Full voice pipeline: raw WAV request body -> STT -> LLM -> TTS.

    Responds with JSON containing the transcript, the assistant reply,
    and the generated TTS file name ("" when synthesis failed).
    """
    try:
        raw = request.get_data()
        if not raw:
            return jsonify({"error":"no audio data"}), 400
        # Persist the incoming audio under a unique name.
        upload_name = f"{int(time.time())}_{uuid.uuid4().hex}.wav"
        (TMP / upload_name).write_bytes(raw)

        # 1) STT
        transcript = stt_from_wav_bytes(raw)
        if not transcript:
            transcript = "(Không nhận được lời nói)"

        # 2) LLM
        reply = llm_query(
            f"Bạn là trợ lý. Người dùng nói: '{transcript}'. Trả lời ngắn gọn bằng tiếng Việt."
        )

        # 3) TTS
        audio = tts_to_wav_bytes(reply)
        tts_name = ""
        if audio:
            tts_name = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / tts_name).write_bytes(audio)

        return jsonify({"text": transcript, "reply": reply, "tts_file": tts_name})
    except Exception as e:
        print("upload_audio error:", e)
        return jsonify({"error": str(e)}), 500

@app.route("/ask", methods=["POST"])
def ask():
    """Text-only endpoint: JSON {"text": ...} -> LLM reply (+ optional TTS)."""
    try:
        payload = request.get_json(force=True)
        question = payload.get("text", "")
        if not question:
            return jsonify({"error":"no text"}), 400

        reply = llm_query(question)

        # Synthesize the reply; an empty file name signals TTS failure.
        audio = tts_to_wav_bytes(reply)
        audio_name = ""
        if audio:
            audio_name = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / audio_name).write_bytes(audio)

        return jsonify({"text": reply, "tts_file": audio_name})
    except Exception as e:
        print("ask error:", e)
        return jsonify({"error": str(e)}), 500

@app.route("/tts/<fname>", methods=["GET"])
def serve_tts(fname):
    """Serve a previously generated TTS file from the tmp directory.

    *fname* comes straight from the URL, so reject anything that is not
    a plain file name to prevent path traversal outside TMP.  Flask's
    default <fname> converter blocks "/", but not "..", backslashes on
    Windows, or other escapes — validate explicitly.
    """
    if not fname or fname != os.path.basename(fname) or fname in (".", ".."):
        return "Not found", 404
    p = (TMP / fname).resolve()
    # Double-check the resolved path still lives inside TMP.
    if TMP.resolve() not in p.parents:
        return "Not found", 404
    if not p.exists():
        return "Not found", 404
    return send_file(str(p), mimetype="audio/wav", as_attachment=False)

if __name__ == "__main__":
    # Default to 7860, the port Hugging Face Spaces expects.
    listen_port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, threaded=True)