Files changed (2) hide show
  1. app.py +151 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import time
import uuid
from pathlib import Path

import requests
from flask import Flask, request, jsonify, send_file

app = Flask(__name__)

# --- Config via environment variables in Space settings ---
HF_TOKEN = os.getenv("HF_TOKEN", "")  # required
HF_STT_MODEL = os.getenv("HF_STT_MODEL", "openai/whisper-small")  # speech-to-text
HF_LLM_MODEL = os.getenv("HF_LLM_MODEL", "google/flan-t5-large")  # text generation
HF_TTS_MODEL = os.getenv("HF_TTS_MODEL", "espnet/kan-bayashi_ljspeech_vits")  # text-to-speech (example)

# Scratch directory for uploaded audio and generated TTS clips.
TMP = Path("tmp")
TMP.mkdir(exist_ok=True)

if not HF_TOKEN:
    print("WARNING: HF_TOKEN not set. Set it in Space Settings -> Secrets/Vars.")

# Sent on every Inference API request.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
22
def hf_inference_raw(model_id: str, input_data, headers_extra=None, as_json=True, timeout=120):
    """POST *input_data* to the Hugging Face Inference API for *model_id*.

    Raw ``bytes``/``bytearray`` input is sent as the request body; anything
    else is JSON-encoded. With *as_json* the parsed JSON is returned (falling
    back to the response text when parsing fails); otherwise raw bytes.
    Raises ``requests.HTTPError`` on non-2xx responses.
    """
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = dict(HEADERS)
    headers.update(headers_extra or {})

    if isinstance(input_data, (bytes, bytearray)):
        resp = requests.post(url, headers=headers, data=input_data, timeout=timeout)
    else:
        resp = requests.post(url, headers=headers, json=input_data, timeout=timeout)
    resp.raise_for_status()

    if not as_json:
        return resp.content
    try:
        return resp.json()
    except Exception:
        # Some endpoints answer with plain text; pass it through unchanged.
        return resp.text
39
+
40
def stt_from_wav_bytes(wav_bytes: bytes):
    """Transcribe WAV audio via the configured STT model; '' on any error."""
    # Many whisper-based HF endpoints accept audio/wav or audio/mpeg
    try:
        result = hf_inference_raw(
            HF_STT_MODEL,
            wav_bytes,
            headers_extra={"Content-Type": "audio/wav"},
            as_json=True,
            timeout=300,
        )
    except Exception as e:
        print("STT error:", e)
        return ""

    if isinstance(result, dict) and "text" in result:
        return result["text"]
    if isinstance(result, str):
        return result
    # fallback: unknown response shape — stringify it
    return str(result)
53
+
54
def llm_query(prompt: str):
    """Send *prompt* to the configured LLM and pull a text reply out of the
    response, whatever common shape the Inference API returns it in."""
    try:
        out = hf_inference_raw(HF_LLM_MODEL, {"inputs": prompt}, as_json=True, timeout=120)

        # list of generations -> first item's generated_text, else stringified
        if isinstance(out, list) and len(out) > 0:
            head = out[0]
            if isinstance(head, dict) and "generated_text" in head:
                return head["generated_text"]
            return str(head)

        # single dict -> try the usual keys in order
        if isinstance(out, dict):
            for key in ("generated_text", "text"):
                if key in out:
                    return out[key]
            return str(out)

        return out if isinstance(out, str) else str(out)
    except Exception as e:
        print("LLM error:", e)
        # Vietnamese: "Sorry, I ran into an error while answering."
        return "Xin lỗi, tôi gặp lỗi khi trả lời."
76
+
77
def tts_to_wav_bytes(text: str):
    """Synthesize *text* with the configured TTS model.

    Returns the raw audio bytes from the Inference API, or b"" on any
    failure — callers treat empty bytes as "no audio available" (best-effort).
    """
    try:
        # Reuse the shared helper (as_json=False -> raw response bytes)
        # instead of duplicating the URL/headers/POST logic here.
        return hf_inference_raw(HF_TTS_MODEL, {"inputs": text}, as_json=False, timeout=120)
    except Exception as e:
        # Include the server's error body only when the exception actually
        # carries a response (the old `getattr(...) and e.response.text`
        # expression printed a bare None otherwise).
        resp = getattr(e, "response", None)
        detail = getattr(resp, "text", "")
        print("TTS error:", e, detail)
        return b""
88
+
89
@app.route("/health")
def health():
    """Liveness probe: always reports the service as up."""
    return jsonify({"ok": True})
92
+
93
@app.route("/upload_audio", methods=["POST"])
def upload_audio():
    """Accept raw audio bytes in the request body and run STT -> LLM -> TTS.

    Returns JSON: {"text": transcript, "reply": LLM answer,
    "tts_file": generated file name or "" when synthesis failed}.
    """
    try:
        audio_bytes = request.get_data()
        if not audio_bytes:
            return jsonify({"error":"no audio data"}), 400

        # Keep a copy of the upload in the scratch directory.
        upload_name = f"{int(time.time())}_{uuid.uuid4().hex}.wav"
        (TMP / upload_name).write_bytes(audio_bytes)

        # 1) speech-to-text (empty transcript -> placeholder text)
        text = stt_from_wav_bytes(audio_bytes) or "(Không nhận được lời nói)"

        # 2) LLM reply
        prompt = f"Bạn là trợ lý. Người dùng nói: '{text}'. Trả lời ngắn gọn bằng tiếng Việt."
        reply = llm_query(prompt)

        # 3) text-to-speech (best-effort; tts_file stays "" on failure)
        tts_fname = ""
        tts_bytes = tts_to_wav_bytes(reply)
        if tts_bytes:
            tts_fname = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / tts_fname).write_bytes(tts_bytes)

        return jsonify({"text": text, "reply": reply, "tts_file": tts_fname})
    except Exception as e:
        print("upload_audio error:", e)
        return jsonify({"error": str(e)}), 500
123
+
124
@app.route("/ask", methods=["POST"])
def ask():
    """Text-only endpoint: JSON {"text": ...} -> LLM reply plus optional TTS file."""
    try:
        payload = request.get_json(force=True)
        question = payload.get("text", "")
        if not question:
            return jsonify({"error":"no text"}), 400

        reply = llm_query(question)

        # Synthesize the reply; leave tts_file empty when TTS fails.
        tts_fname = ""
        audio = tts_to_wav_bytes(reply)
        if audio:
            tts_fname = f"tts_{int(time.time())}_{uuid.uuid4().hex}.wav"
            (TMP / tts_fname).write_bytes(audio)

        return jsonify({"text": reply, "tts_file": tts_fname})
    except Exception as e:
        print("ask error:", e)
        return jsonify({"error": str(e)}), 500
141
+
142
@app.route("/tts/<fname>", methods=["GET"])
def serve_tts(fname):
    """Serve a previously generated TTS clip from the scratch directory.

    Security: *fname* is client-supplied, so resolve the joined path and
    require that it lives directly inside TMP before touching the
    filesystem (defense in depth on top of Flask's default <fname>
    converter, which already rejects '/').
    """
    tmp_root = TMP.resolve()
    target = (TMP / fname).resolve()
    # All generated files are written directly under TMP, so anything whose
    # resolved parent differs (e.g. via '..') is rejected as not found.
    if target.parent != tmp_root or not target.exists():
        return "Not found", 404
    return send_file(str(target), mimetype="audio/wav", as_attachment=False)
148
+
149
if __name__ == "__main__":
    # Hugging Face Spaces inject the listen port via $PORT (default 7860).
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")), threaded=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ flask
2
+ requests