DrPie commited on
Commit
2c07152
·
verified ·
1 Parent(s): ce128fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -130
app.py CHANGED
@@ -1,23 +1,21 @@
1
- # =================== #
2
- # Cache + Env setup #
3
- # =================== #
4
- import os, shutil
5
 
6
- # Đặt cache vào /tmp để tránh lỗi permission trên Spaces
7
  os.environ["HF_HOME"] = "/tmp/hf_home"
8
  os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets"
11
  os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
12
  os.environ["HOME"] = "/tmp"
13
- for p in ["/tmp/hf_home","/tmp/hf_cache","/tmp/hf_datasets","/tmp/.cache"]:
14
- os.makedirs(p, exist_ok=True)
 
 
15
  shutil.rmtree("/.cache", ignore_errors=True)
16
 
17
- # =================== #
18
- # Import thư viện #
19
- # =================== #
20
- import gzip, pickle, json, re
21
  from flask import Flask, request, jsonify
22
  from flask_cors import CORS
23
  import numpy as np
@@ -28,9 +26,7 @@ import google.generativeai as genai
28
  from cachetools import TTLCache
29
  from huggingface_hub import login, hf_hub_download
30
 
31
- # ================ #
32
- # Load ENV & HF #
33
- # ================ #
34
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
35
  if HF_TOKEN:
36
  try:
@@ -39,8 +35,9 @@ if HF_TOKEN:
39
  except Exception as e:
40
  print("Warning: HF login failed:", e)
41
  else:
42
- print("Warning: HF_TOKEN not found")
43
 
 
44
  HF_REPO_ID = os.environ.get("HF_REPO_ID", "DrPie/eGoV_Data")
45
  REPO_TYPE = os.environ.get("REPO_TYPE", "dataset")
46
  EMB_MODEL = os.environ.get("EMB_MODEL", "AITeamVN/Vietnamese_Embedding")
@@ -51,147 +48,209 @@ BM25_PREFILTER = int(os.environ.get("BM25_PREFILTER", "200"))
51
  CACHE_TTL = int(os.environ.get("CACHE_TTL", "3600"))
52
  CACHE_MAX = int(os.environ.get("CACHE_MAX", "2000"))
53
 
54
- print("--- KHỞI ĐỘNG MÁY CHỦ CHATBOT (optimized & id_to_record) ---")
55
 
56
- # ================ #
57
- # Download data #
58
- # ================ #
59
- RAW_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="toan_bo_du_lieu_final.json", repo_type=REPO_TYPE)
60
- FAISS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="index.faiss", repo_type=REPO_TYPE)
61
- BM25_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="bm25.pkl.gz", repo_type=REPO_TYPE)
62
- METAS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="metas.pkl.gz", repo_type=REPO_TYPE)
63
- # Load id_to_record.pkl nếu có
64
  try:
65
- ID2REC_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="id_to_record.pkl", repo_type=REPO_TYPE)
66
- with open(ID2REC_PATH,"rb") as f:
67
- id_to_record = pickle.load(f)
 
 
 
 
68
  except Exception as e:
69
- print("⚠️ Không tải được id_to_record.pkl:", e)
70
- id_to_record = {}
71
-
72
- # ================ #
73
- # Load resources #
74
- # ================ #
75
- faiss_index = faiss.read_index(FAISS_PATH)
76
- with gzip.open(BM25_PATH,"rb") as f: bm25 = pickle.load(f)
77
- with gzip.open(METAS_PATH,"rb") as f: metas = pickle.load(f)
78
- if isinstance(metas,dict) and "corpus" in metas:
79
- corpus = metas["corpus"]
80
  else:
81
- corpus = metas
82
 
83
- # Load embedding model
84
- device = os.environ.get("DEVICE","cpu")
 
 
 
 
 
 
 
 
 
85
  embedding_model = SentenceTransformer(EMB_MODEL, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # Load raw_data làm fallback để build procedure_map
88
  try:
89
- with open(RAW_PATH,"r",encoding="utf-8") as f:
90
  raw_data = json.load(f)
91
- procedure_map = {item.get('nguon') or item.get('parent_id') or str(i): item for i,item in enumerate(raw_data)}
92
  except Exception:
93
- procedure_map = {}
94
 
95
- # GenAI init
96
- API_KEY = os.environ.get("GOOGLE_API_KEY")
97
- generation_model = None
98
  if API_KEY:
99
  try:
100
- genai.configure(api_key=API_KEY)
101
  generation_model = genai.GenerativeModel(GENAI_MODEL)
102
  except Exception as e:
103
- print("Warning: cannot init GenAI:", e)
104
-
105
- # =================== #
106
- # Utility / Retrieve #
107
- # =================== #
108
- def minmax_scale(arr):
109
- arr=np.array(arr,dtype="float32")
110
- return np.zeros_like(arr) if len(arr)==0 or np.max(arr)==np.min(arr) else (arr-np.min(arr))/(np.max(arr)-np.min(arr))
111
-
112
- def classify_followup(text:str)->int:
113
- t=text.lower().strip()
114
- if len(t.split())<=4: return 0
115
- if re.search(r"\b(nó|cái này|thế thì|vậy thì)\b",t): return 0
116
- return 1
117
-
118
- def retrieve(query:str, top_k=TOP_K):
119
- qv = embedding_model.encode([query],normalize_embeddings=True).astype("float32")
120
- D,I = faiss_index.search(qv, max(FAISS_CANDIDATES, top_k*5))
121
- vec_idx = I[0].tolist()
122
- vec_scores = (1-D[0]).tolist()
123
- # BM25 prefilter
124
- try:
125
- bm25_scores_all = bm25.get_scores(query.split())
126
- bm25_top_idx = np.argsort(-bm25_scores_all)[:BM25_PREFILTER].tolist()
127
- except Exception:
128
- bm25_top_idx=[]
129
- union_idx = list(dict.fromkeys(vec_idx+bm25_top_idx))
130
- vec_map = {i:s for i,s in zip(vec_idx,vec_scores)}
131
- vec_list=[vec_map.get(i,0.0) for i in union_idx]
132
- bm25_list=[bm25_scores_all[i] if i<len(bm25_scores_all) else 0.0 for i in union_idx]
133
- fused=0.7*minmax_scale(vec_list)+0.3*minmax_scale(bm25_list)
134
- order=np.argsort(-fused)
135
- return [union_idx[i] for i in order[:top_k]]
136
 
137
- def get_full_procedure_text_by_parent(pid):
138
- rec=None
139
- if id_to_record:
140
- rec=id_to_record.get(pid)
141
- if not rec:
142
- rec=procedure_map.get(pid)
143
- if not rec: return "Không tìm thấy thủ tục."
144
- field_map={"ten_thu_tuc":"Tên thủ tục","cach_thuc_thuc_hien":"Cách thức thực hiện","thanh_phan_ho_so":"Thành phần hồ sơ","trinh_tu_thuc_hien":"Trình tự thực hiện","co_quan_thuc_hien":"Cơ quan thực hiện","yeu_cau_dieu_kien":"Yêu cầu, điều kiện","nguon":"Nguồn"}
145
- return "\n\n".join([f"{field_map[k]}:\n{v}" for k,v in rec.items() if k in field_map and v])
146
-
147
- # =================== #
148
- # Flask endpoints #
149
- # =================== #
150
- app=Flask(__name__)
151
  CORS(app)
152
- chat_histories={}
153
- answer_cache = TTLCache(maxsize=CACHE_MAX, ttl=CACHE_TTL)
 
 
 
 
 
 
 
 
154
 
155
- @app.route("/health")
156
- def health(): return {"status":"ok"}
 
157
 
158
- @app.route("/chat",methods=["POST"])
159
  def chat():
160
- data=request.get_json(force=True)
161
- user_query=data.get("question")
162
- sid=data.get("session_id","default")
163
- if not user_query: return jsonify({"error":"No question provided"}),400
164
- if sid not in chat_histories: chat_histories[sid]=[]
165
- hist=chat_histories[sid]
166
- if classify_followup(user_query)==0 and hist:
167
- context=hist[-1].get("context","")
 
 
 
 
 
168
  else:
169
- idxs=retrieve(user_query,TOP_K)
170
  if idxs:
171
- meta=metas[idxs[0]]
172
- pid=meta.get("parent_id") or meta.get("nguon")
173
- context=get_full_procedure_text_by_parent(pid)
174
- else: context=""
175
- history_str="\n".join([f"{m['role']}: {m['content']}" for m in hist])
176
- prompt=f"""Bạn là trợ lý eGov-Bot dịch vụ công Việt Nam.
177
- Trả lời tiếng Việt, chính xác, dựa dữ liệu nếu có.
178
- Nếu thiếu dữ liệu, nói "Mình chưa có thông tin" và đưa link nguồn trong dữ liệu.
179
-
180
  Lịch sử: {history_str}
181
  DỮ LIỆU:
 
182
  {context}
183
-
184
  CÂU HỎI: {user_query}"""
185
- if not generation_model:
186
- return jsonify({"answer":"LLM model chưa sẵn sàng (kiểm tra GOOGLE_API_KEY)."})
187
  try:
188
- resp=generation_model.generate_content(prompt)
189
- ans=getattr(resp,"text",str(resp))
 
 
190
  except Exception as e:
191
- return jsonify({"error":"LLM call failed","detail":str(e)}),200
192
- hist.append({'role':'user','content':user_query})
193
- hist.append({'role':'model','content':ans,'context':context})
194
- return jsonify({"answer":ans})
195
 
196
- if __name__=="__main__":
197
- app.run(host="0.0.0.0",port=int(os.environ.get("PORT",7860)))
 
1
import os
import shutil

# --- Cache + env setup ---
# Must run before any HF/transformers import so the libraries pick up the
# writable /tmp cache locations (Spaces containers deny writes elsewhere).
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
os.environ["HOME"] = "/tmp"
for cache_dir in ("/tmp/hf_home", "/tmp/hf_cache", "/tmp/hf_datasets", "/tmp/.cache"):
    os.makedirs(cache_dir, exist_ok=True)
# Drop any stale root-owned cache left over from a previous build.
shutil.rmtree("/.cache", ignore_errors=True)
16
 
17
# --- Remaining standard-library imports ---
import time
import hashlib
import gzip
import pickle
import json
import traceback
import re
 
 
19
  from flask import Flask, request, jsonify
20
  from flask_cors import CORS
21
  import numpy as np
 
26
  from cachetools import TTLCache
27
  from huggingface_hub import login, hf_hub_download
28
 
29
+ # ---------- HF login ----------
 
 
30
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
31
  if HF_TOKEN:
32
  try:
 
35
  except Exception as e:
36
  print("Warning: HF login failed:", e)
37
  else:
38
+ print("Warning: HF_TOKEN not found - only public repos accessible")
39
 
40
+ # ---------- Config ----------
41
  HF_REPO_ID = os.environ.get("HF_REPO_ID", "DrPie/eGoV_Data")
42
  REPO_TYPE = os.environ.get("REPO_TYPE", "dataset")
43
  EMB_MODEL = os.environ.get("EMB_MODEL", "AITeamVN/Vietnamese_Embedding")
 
48
  CACHE_TTL = int(os.environ.get("CACHE_TTL", "3600"))
49
  CACHE_MAX = int(os.environ.get("CACHE_MAX", "2000"))
50
 
51
+ print("--- KHỞI ĐỘNG MÁY CHỦ CHATBOT (optimized & fixed) ---")
52
 
53
# ---------- Download dataset files from the HF Hub ----------
# FIX: initialize every path name, not just FAISS_PATH. Previously a failed
# download left RAW_PATH / METAS_PATH / BM25_PATH / ID_TO_RECORD_PATH
# undefined, so the loading code below crashed with a NameError instead of a
# meaningful error.
RAW_PATH = FAISS_PATH = METAS_PATH = BM25_PATH = ID_TO_RECORD_PATH = None
try:
    RAW_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="toan_bo_du_lieu_final.json", repo_type=REPO_TYPE)
    FAISS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="index.faiss", repo_type=REPO_TYPE)
    METAS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="metas.pkl.gz", repo_type=REPO_TYPE)
    BM25_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="bm25.pkl.gz", repo_type=REPO_TYPE)
    # Extra lookup table: parent_id -> full procedure record.
    ID_TO_RECORD_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="id_to_record.pkl", repo_type=REPO_TYPE)
    print("✅ Files downloaded or already available.")
except Exception as e:
    # FIX: "TÀI NGUỒN" -> "TÀI NGUYÊN" (typo in the startup error message).
    print("LỖI KHI TẢI TÀI NGUYÊN:", e)

# The FAISS index is mandatory — fail fast if it could not be fetched.
if FAISS_PATH:
    print("Loading FAISS index from:", FAISS_PATH)
    faiss_index = faiss.read_index(FAISS_PATH)
else:
    raise RuntimeError("Không có file FAISS index.")
71
 
72
# ---------- External APIs ----------
# Configure the Gemini client only when a key is present; the /chat endpoint
# degrades gracefully without it.
API_KEY = os.environ.get("GOOGLE_API_KEY")
if API_KEY:
    genai.configure(api_key=API_KEY)
else:
    print("Warning: GOOGLE_API_KEY missing.")
78
+
79
# ---------- Load models & retrieval resources ----------
t0 = time.perf_counter()
device = os.environ.get("DEVICE", "cpu")
print("Loading embedding model:", EMB_MODEL)
embedding_model = SentenceTransformer(EMB_MODEL, device=device)
print("Embedding model loaded.")
print("FAISS index ntotal =", getattr(faiss_index, "ntotal", "unknown"))

# Chunk metadata: either a dict carrying a "corpus" key or the list itself.
with gzip.open(METAS_PATH, "rb") as f:
    metas = pickle.load(f)
corpus = metas["corpus"] if isinstance(metas, dict) and "corpus" in metas else metas
with gzip.open(BM25_PATH, "rb") as f:
    bm25 = pickle.load(f)
metadatas = corpus

# id_to_record maps parent_id -> full procedure record (fast lookup).
# FIX: guard this load. If the id_to_record.pkl download failed above,
# ID_TO_RECORD_PATH is None/undefined and the bare open() killed startup;
# fall back to an empty map like the previous revision did.
try:
    with open(ID_TO_RECORD_PATH, "rb") as f:
        id_to_record = pickle.load(f)
except Exception as e:
    print("Warning: cannot load id_to_record.pkl:", e)
    id_to_record = {}

print("Loaded metas, BM25, id_to_record. corpus size:", len(corpus))
print("Resources load time: %.2fs" % (time.perf_counter() - t0))

# TTL-bounded answer cache.
# NOTE(review): nothing consults this cache yet — presumably intended for
# /chat response caching via cache_key_for_query; confirm before removing.
answer_cache = TTLCache(maxsize=CACHE_MAX, ttl=CACHE_TTL)
101
+
102
+ # ---------- Utility functions ----------
103
+ def _norm_key(s: str) -> str:
104
+ return " ".join(s.lower().strip().split())
105
+
106
def cache_key_for_query(q: str) -> str:
    """Stable cache key: normalized query text plus the retrieval settings,
    so changing the embedding model or k parameters invalidates old entries."""
    fingerprint = f"{_norm_key(q)}|emb={EMB_MODEL}|k={TOP_K}|p={BM25_PREFILTER}"
    return hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()
109
+
110
def minmax_scale(arr):
    """Scale values linearly into [0, 1]; empty or constant input -> zeros."""
    values = np.array(arr, dtype="float32")
    if values.size == 0:
        return np.zeros_like(values)
    lo, hi = np.min(values), np.max(values)
    if hi == lo:
        # Constant vector: no spread to normalize over.
        return np.zeros_like(values)
    return (values - lo) / (hi - lo)
115
+
116
def classify_followup(text: str):
    """Heuristic intent classifier: 0 = follow-up to the previous turn
    (reuse its context), 1 = standalone question (run fresh retrieval)."""
    normalized = text.lower().strip()

    # Pronouns / anaphora that clearly point back at the previous answer.
    anaphora_patterns = [
        r"\b(nó|cái (này|đó|ấy)|thủ tục (này|đó|ấy))\b",
        r"\b(vừa (nói|hỏi)|trước đó|ở trên|phía trên)\b",
        r"\b(tiếp theo|tiếp|còn nữa|ngoài ra)\b",
        r"\b(thế (thì|à)|vậy (thì|à)|như vậy)\b"
    ]
    # Detail questions (cost, time, place) usually refine the prior topic.
    detail_patterns = [
        r"\b(mất bao lâu|thời gian|bao nhiêu tiền|chi phí|phí)\b",
        r"\b(ở đâu|tại đâu|chỗ nào|địa chỉ)\b",
        r"\b(cần (gì|những gì)|yêu cầu|điều kiện)\b"
    ]
    # Naming a concrete public service signals a brand-new query.
    service_patterns = [
        r"\b(làm|cấp|gia hạn|đổi|đăng ký)\s+(căn cước|cmnd|cccd)\b",
        r"\b(làm|cấp|gia hạn|đổi)\s+hộ chiếu\b",
        r"\b(đăng ký)\s+(kết hôn|sinh|tử|hộ khẩu)\b"
    ]

    def _matches(patterns):
        return any(re.search(p, normalized) for p in patterns)

    score = 0
    if _matches(anaphora_patterns):
        score -= 3
    if _matches(detail_patterns):
        score -= 2
    if _matches(service_patterns):
        score += 3
    # Very short utterances lean follow-up.
    if len(normalized.split()) <= 4:
        score -= 1
    return 0 if score < 0 else 1
140
+
141
def retrieve(query: str, top_k=TOP_K):
    """Hybrid retrieval: fuse FAISS vector search with BM25 lexical scores.

    Returns up to top_k corpus indices ranked by a 0.7/0.3 weighted blend of
    min-max-scaled vector similarity and BM25 scores.
    """
    qv = embedding_model.encode([query], convert_to_numpy=True, normalize_embeddings=True).astype("float32")

    # FIX: compute the BM25 scores exactly once. The previous code called
    # bm25.get_scores() a second time, unconditionally, after the guarded
    # try — so when BM25 scoring failed the duplicate call raised uncaught
    # and killed the request (and on success it doubled the work).
    try:
        bm25_scores_all = bm25.get_scores(query.split())
        bm25_top_idx = np.argsort(-bm25_scores_all)[:BM25_PREFILTER].tolist()
    except Exception:
        bm25_scores_all = np.array([], dtype="float32")
        bm25_top_idx = []

    # Over-fetch vector candidates so fusion has room to re-rank.
    k_cand = max(FAISS_CANDIDATES, top_k * 5)
    D, I = faiss_index.search(qv, k_cand)
    vec_idx = I[0].tolist()
    # Index returns distances; 1 - distance is used as a similarity proxy.
    vec_scores = (1 - D[0]).tolist()

    # Union of both candidate sets, de-duplicated and order-preserving.
    union_idx = list(dict.fromkeys(vec_idx + bm25_top_idx))
    vec_map = dict(zip(vec_idx, vec_scores))
    vec_list = [vec_map.get(i, 0.0) for i in union_idx]
    bm25_list = [bm25_scores_all[i] if i < len(bm25_scores_all) else 0.0 for i in union_idx]

    fused = 0.7 * minmax_scale(vec_list) + 0.3 * minmax_scale(bm25_list)
    order = np.argsort(-fused)
    return [union_idx[i] for i in order[:top_k]]
161
+
162
def get_full_procedure_text_by_parent(parent_id):
    """Render the full procedure record for parent_id as labeled sections,
    using the O(1) id_to_record map instead of scanning metadatas."""
    record = id_to_record.get(parent_id)
    if not record:
        return "Không tìm thấy thủ tục."
    # Record keys -> Vietnamese section headings, in display order.
    labels = {
        "ten_thu_tuc": "Tên thủ tục",
        "cach_thuc_thuc_hien": "Cách thức thực hiện",
        "thanh_phan_ho_so": "Thành phần hồ sơ",
        "trinh_tu_thuc_hien": "Trình tự thực hiện",
        "co_quan_thuc_hien": "Cơ quan thực hiện",
        "yeu_cau_dieu_kien": "Yêu cầu, điều kiện",
        "thu_tuc_lien_quan": "Thủ tục liên quan",
        "nguon": "Nguồn"
    }
    sections = []
    for key, value in record.items():
        if value and key in labels:
            sections.append(f"{labels[key]}:\n{str(value).strip()}")
    return "\n\n".join(sections)
179
 
 
180
# Raw dataset kept around as a fallback data source.
# NOTE(review): raw_data is not referenced elsewhere in this revision —
# presumably retained for a future fallback path; confirm before removing.
try:
    with open(RAW_PATH, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
except Exception:
    raw_data = []

# Create the Gemini model handle only when an API key was configured above.
generation_model = None
if API_KEY:
    try:
        generation_model = genai.GenerativeModel(GENAI_MODEL)
    except Exception as e:
        print("Warning: cannot init generation_model:", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
# ---------- Flask ----------
app = Flask(__name__)
CORS(app)
# In-memory per-session conversation history, keyed by session_id.
chat_histories = {}
199
+
200
@app.route("/chat_debug", methods=["POST"])
def chat_debug():
    """Echo the raw request body and headers back — a client-side debugging
    aid for diagnosing malformed /chat payloads."""
    try:
        payload = request.get_data(as_text=True)
        header_map = dict(request.headers)
        return jsonify({"ok": True, "raw_body": payload, "headers": header_map})
    except Exception as exc:
        return jsonify({"ok": False, "error": str(exc), "trace": traceback.format_exc()}), 200
208
 
209
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe; Flask serializes the dict response to JSON."""
    return {"status": "ok"}
212
 
213
@app.route("/chat", methods=["POST"])
def chat():
    """Main Q&A endpoint.

    Expects JSON {"question": str, "session_id": str (optional)}.
    Maintains per-session history in chat_histories; follow-up questions
    reuse the previous turn's retrieved context instead of re-retrieving.
    """
    try:
        data = request.get_json(force=True)
    except Exception as e:
        return jsonify({"error": "cannot parse JSON", "detail": str(e)}), 400
    # Robustness: a valid-JSON-but-non-object body (e.g. a bare list) would
    # otherwise crash on .get() and surface as an unhandled 500.
    if not isinstance(data, dict):
        return jsonify({"error": "cannot parse JSON", "detail": "body must be a JSON object"}), 400
    user_query = data.get('question')
    session_id = data.get('session_id', 'default')
    if not user_query:
        return jsonify({"error": "No question provided"}), 400
    if session_id not in chat_histories:
        chat_histories[session_id] = []
    current_history = chat_histories[session_id]

    # Follow-ups (score 0) reuse the last turn's context; otherwise run
    # fresh hybrid retrieval and expand the top hit to its full procedure.
    if classify_followup(user_query) == 0 and current_history:
        context = current_history[-1].get('context', '')
    else:
        idxs = retrieve(user_query, top_k=TOP_K)
        if idxs:
            parent_id = metadatas[idxs[0]].get("parent_id") or metadatas[idxs[0]].get("nguon")
            context = get_full_procedure_text_by_parent(parent_id)
        else:
            context = ""

    history_str = "\n".join([f"{h['role']}: {h['content']}" for h in current_history])
    # FIX: restored the dropped word "có" in the fallback instruction —
    # "Mình chưa thông tin" was ungrammatical ("Mình chưa có thông tin").
    prompt = f"""Bạn là trợ lý eGov-Bot dịch vụ công Việt Nam. Trả lời tiếng Việt dựa vào DỮ LIỆU cung cấp.
Nếu thiếu dữ liệu, hãy nói "Mình chưa có thông tin" và đưa link nguồn.

Lịch sử: {history_str}
DỮ LIỆU:
---
{context}
---
CÂU HỎI: {user_query}"""
    # NOTE(review): answer_cache / cache_key_for_query are never consulted
    # here — presumably intended for response caching; confirm.
    try:
        if generation_model is None:
            raise RuntimeError("generation_model not available.")
        response = generation_model.generate_content(prompt)
        final_answer = getattr(response, "text", str(response))
    except Exception as e:
        # HTTP 200 on LLM failure is deliberate in the original; kept so
        # existing clients that only inspect the body keep working.
        return jsonify({"error": "LLM call failed", "detail": str(e), "trace": traceback.format_exc()}), 200
    current_history.append({'role': 'user', 'content': user_query})
    current_history.append({'role': 'model', 'content': final_answer, 'context': context})
    return jsonify({"answer": final_answer})
254
 
255
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; 7860 is the local default.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port)