DrPie commited on
Commit
d80e6d2
·
verified ·
1 Parent(s): 76c06ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -95
app.py CHANGED
@@ -1,103 +1,204 @@
1
- import os
2
- import re
3
- import unicodedata
4
- import pickle
5
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from flask import Flask, request, jsonify
 
 
 
 
7
  from rank_bm25 import BM25Okapi
8
- from huggingface_hub import InferenceClient
9
-
10
- # ===================== #
11
- # TIỀN XỬ LÝ VĂN BẢN #
12
- # ===================== #
13
-
14
- def normalize_text(text: str) -> str:
15
- text = text.lower()
16
- text = ''.join(c for c in unicodedata.normalize('NFD', text) if unicodedata.category(c) != 'Mn') # bỏ dấu tiếng Việt
17
- text = re.sub(r'[^a-z0-9\s]', ' ', text) # bỏ ký tự đặc biệt
18
- return text
19
-
20
- def tokenize(text: str):
21
- return normalize_text(text).split()
22
-
23
- # ===================== #
24
- # LOAD DỮ LIỆU #
25
- # ===================== #
26
-
27
- # File id_to_record.pkl chứa dict: id -> {ten_thu_tuc, mo_ta, yeu_cau, co_quan, link ...}
28
- with open("id_to_record.pkl", "rb") as f:
29
- id_to_record = pickle.load(f)
30
-
31
- # Tạo corpus cho BM25: mỗi record nối các trường thành 1 text
32
- corpus = []
33
- for rid, rec in id_to_record.items():
34
- fields = [str(rec.get(k, "")) for k in ["ten_thu_tuc", "mo_ta", "yeu_cau", "co_quan", "linh_vuc"]]
35
- text = " ".join(fields)
36
- corpus.append(tokenize(text))
37
-
38
- bm25 = BM25Okapi(corpus)
39
-
40
- # ===================== #
41
- # KHỞI TẠO FLASK APP #
42
- # ===================== #
43
-
44
- app = Flask(__name__)
45
- HF_TOKEN = os.getenv("HF_TOKEN")
46
- HF_MODEL = os.getenv("HF_MODEL", "gemini-pro") # đổi sang model bạn dùng
47
- client = InferenceClient(token=HF_TOKEN)
48
-
49
- # ===================== #
50
- # HÀM LẤY CONTEXT #
51
- # ===================== #
52
-
53
- def retrieve_context(query: str, top_k: int = 5):
54
- tokens = tokenize(query)
55
- scores = bm25.get_scores(tokens)
56
- top_idx = np.argsort(-scores)[:top_k]
57
- context_parts = []
58
- for idx in top_idx:
59
- if scores[idx] > 0: # chỉ lấy nếu score > 0
60
- rid = list(id_to_record.keys())[idx]
61
- rec = id_to_record[rid]
62
- # context gồm tên, mô tả, yêu cầu và link nếu có
63
- ctx = f"Tên: {rec.get('ten_thu_tuc','')}\nMô tả: {rec.get('mo_ta','')}\nYêu cầu: {rec.get('yeu_cau','')}\nCơ quan: {rec.get('co_quan','')}\nLink: {rec.get('link','')}"
64
- context_parts.append(ctx)
65
- return "\n\n".join(context_parts)
66
-
67
- # ===================== #
68
- # ROUTE /chat #
69
- # ===================== #
70
-
71
- @app.route("/chat", methods=["POST"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def chat():
73
- user_query = request.json.get("query", "")
74
- if not user_query.strip():
75
- return jsonify({"answer": "Bạn chưa nhập câu hỏi."})
76
-
77
- context = retrieve_context(user_query)
78
-
79
- prompt = f"""
80
- Bạn là trợ lý eGov-Bot, trả lời bằng tiếng Việt.
81
- Ưu tiên dùng thông tin từ DỮ LIỆU dưới đây để trả lời.
82
- Nếu dữ liệu không đủ, có thể suy luận hợp lý hoặc trả lời rằng chưa có đủ thông tin.
83
- Nếu có link nguồn trong dữ liệu, hãy cung cấp.
84
-
 
 
 
 
 
 
 
 
 
85
  DỮ LIỆU:
86
- {context if context.strip() else "Không tìm thấy thông tin nào khớp trực tiếp."}
87
-
88
- CÂU HỎI: {user_query}
89
- """
90
 
 
 
 
91
  try:
92
- response = client.text_generation(model=HF_MODEL, prompt=prompt, max_new_tokens=512)
93
- return jsonify({"answer": response.strip()})
94
  except Exception as e:
95
- return jsonify({"answer": f"Lỗi khi gọi model: {e}"})
96
-
97
- # ===================== #
98
- # MAIN APP #
99
- # ===================== #
100
 
101
- if __name__ == "__main__":
102
- # Debug mode cho dev, production có thể bỏ
103
- app.run(host="0.0.0.0", port=7860)
 
1
+ # =================== #
2
+ # Cache + Env setup #
3
+ # =================== #
4
+ import os, shutil
5
+
6
+ # Đặt cache vào /tmp để tránh lỗi permission trên Spaces
7
+ os.environ["HF_HOME"] = "/tmp/hf_home"
8
+ os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache"
9
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
+ os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets"
11
+ os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
12
+ os.environ["HOME"] = "/tmp"
13
+ for p in ["/tmp/hf_home","/tmp/hf_cache","/tmp/hf_datasets","/tmp/.cache"]:
14
+ os.makedirs(p, exist_ok=True)
15
+ # Xóa cache cũ nếu có
16
+ shutil.rmtree("/.cache", ignore_errors=True)
17
+
18
+ # =================== #
19
+ # Import thư viện #
20
+ # =================== #
21
+ import time, hashlib, gzip, pickle, json, traceback, re
22
  from flask import Flask, request, jsonify
23
+ from flask_cors import CORS
24
+ import numpy as np
25
+ import faiss
26
+ from sentence_transformers import SentenceTransformer
27
  from rank_bm25 import BM25Okapi
28
+ import google.generativeai as genai
29
+ from cachetools import TTLCache
30
+ from huggingface_hub import login, hf_hub_download
31
+
32
+ # ================ #
33
+ # Load ENV & HF #
34
+ # ================ #
35
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
36
+ if HF_TOKEN:
37
+ try:
38
+ login(HF_TOKEN)
39
+ print("HF login successful")
40
+ except Exception as e:
41
+ print("Warning: HF login failed:", e)
42
+ else:
43
+ print("Warning: HF_TOKEN not found")
44
+
45
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "DrPie/eGoV_Data")
46
+ REPO_TYPE = os.environ.get("REPO_TYPE", "dataset")
47
+ EMB_MODEL = os.environ.get("EMB_MODEL", "AITeamVN/Vietnamese_Embedding")
48
+ GENAI_MODEL = os.environ.get("GENAI_MODEL", "gemini-2.5-flash")
49
+ TOP_K = int(os.environ.get("TOP_K", "3"))
50
+ FAISS_CANDIDATES = int(os.environ.get("FAISS_CANDIDATES", str(max(10, TOP_K*5))))
51
+ BM25_PREFILTER = int(os.environ.get("BM25_PREFILTER", "200"))
52
+ CACHE_TTL = int(os.environ.get("CACHE_TTL", "3600"))
53
+ CACHE_MAX = int(os.environ.get("CACHE_MAX", "2000"))
54
+
55
+ print("--- KHỞI ĐỘNG MÁY CHỦ CHATBOT (optimized & id_to_record) ---")
56
+
57
+ # ================ #
58
+ # Download data #
59
+ # ================ #
60
+ RAW_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="toan_bo_du_lieu_final.json", repo_type=REPO_TYPE)
61
+ FAISS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="index.faiss", repo_type=REPO_TYPE)
62
+ BM25_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="bm25.pkl.gz", repo_type=REPO_TYPE)
63
+ METAS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="metas.pkl.gz", repo_type=REPO_TYPE)
64
+ # Load id_to_record.pkl nếu có
65
+ try:
66
+ ID2REC_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="id_to_record.pkl", repo_type=REPO_TYPE)
67
+ with open(ID2REC_PATH,"rb") as f:
68
+ id_to_record = pickle.load(f)
69
+ except Exception as e:
70
+ print("⚠️ Không tải được id_to_record.pkl:", e)
71
+ id_to_record = {}
72
+
73
+ # ================ #
74
+ # Load resources #
75
+ # ================ #
76
+ faiss_index = faiss.read_index(FAISS_PATH)
77
+ with gzip.open(BM25_PATH,"rb") as f: bm25 = pickle.load(f)
78
+ with gzip.open(METAS_PATH,"rb") as f:
79
+ metas = pickle.load(f)
80
+ if isinstance(metas,dict) and "corpus" in metas:
81
+ corpus = metas["corpus"]
82
+ else:
83
+ corpus = metas
84
+
85
+ # Lưu list key để tránh tạo lại nhiều lần
86
+ meta_keys = list(range(len(corpus)))
87
+
88
+ # Load embedding model
89
+ device = os.environ.get("DEVICE","cpu")
90
+ embedding_model = SentenceTransformer(EMB_MODEL, device=device)
91
+
92
+ # Load raw_data làm fallback để build procedure_map
93
+ try:
94
+ with open(RAW_PATH,"r",encoding="utf-8") as f:
95
+ raw_data = json.load(f)
96
+ procedure_map = {item.get('nguon') or item.get('parent_id') or str(i): item for i,item in enumerate(raw_data)}
97
+ except Exception:
98
+ procedure_map = {}
99
+
100
+ # GenAI init
101
+ API_KEY = os.environ.get("GOOGLE_API_KEY")
102
+ generation_model = None
103
+ if API_KEY:
104
+ try:
105
+ genai.configure(api_key=API_KEY)
106
+ generation_model = genai.GenerativeModel(GENAI_MODEL)
107
+ except Exception as e:
108
+ print("Warning: cannot init GenAI:", e)
109
+
110
+ answer_cache = TTLCache(maxsize=CACHE_MAX, ttl=CACHE_TTL)
111
+
112
+ # =================== #
113
+ # Utility / Retrieve #
114
+ # =================== #
115
+ def minmax_scale(arr):
116
+ arr=np.array(arr,dtype="float32")
117
+ return np.zeros_like(arr) if len(arr)==0 or np.max(arr)==np.min(arr) else (arr-np.min(arr))/(np.max(arr)-np.min(arr))
118
+
119
+ def classify_followup(text:str)->int:
120
+ # như code gốc, bỏ bớt regex nặng để nhanh hơn
121
+ t=text.lower().strip()
122
+ if len(t.split())<=4: return 0
123
+ if re.search(r"\b(nó|cái này|thế thì|vậy thì)\b",t): return 0
124
+ return 1
125
+
126
+ def retrieve(query:str, top_k=TOP_K):
127
+ qv = embedding_model.encode([query],normalize_embeddings=True).astype("float32")
128
+ D,I = faiss_index.search(qv, max(FAISS_CANDIDATES, top_k*5))
129
+ vec_idx = I[0].tolist()
130
+ vec_scores = (1-D[0]).tolist()
131
+ # BM25 prefilter
132
+ try:
133
+ bm25_scores_all = bm25.get_scores(query.split())
134
+ bm25_top_idx = np.argsort(-bm25_scores_all)[:BM25_PREFILTER].tolist()
135
+ except Exception:
136
+ bm25_top_idx=[]
137
+ union_idx = list(dict.fromkeys(vec_idx+bm25_top_idx))
138
+ vec_map = {i:s for i,s in zip(vec_idx,vec_scores)}
139
+ vec_list=[vec_map.get(i,0.0) for i in union_idx]
140
+ bm25_list=[bm25_scores_all[i] if i<len(bm25_scores_all) else 0.0 for i in union_idx]
141
+ fused=0.7*minmax_scale(vec_list)+0.3*minmax_scale(bm25_list)
142
+ order=np.argsort(-fused)
143
+ return [union_idx[i] for i in order[:top_k]]
144
+
145
+ def get_full_procedure_text_by_parent(pid):
146
+ rec=None
147
+ if id_to_record:
148
+ rec=id_to_record.get(pid)
149
+ if not rec:
150
+ rec=procedure_map.get(pid)
151
+ if not rec: return "Không tìm thấy thủ tục."
152
+ field_map={"ten_thu_tuc":"Tên thủ tục","cach_thuc_thuc_hien":"Cách thức thực hiện","thanh_phan_ho_so":"Thành phần hồ sơ","trinh_tu_thuc_hien":"Trình tự thực hiện","co_quan_thuc_hien":"Cơ quan thực hiện","yeu_cau_dieu_kien":"Yêu cầu, điều kiện","nguon":"Nguồn"}
153
+ return "\n\n".join([f"{field_map[k]}:\n{v}" for k,v in rec.items() if k in field_map and v])
154
+
155
+ # ================ #
156
+ # Flask endpoints #
157
+ # ================ #
158
+ app=Flask(__name__)
159
+ CORS(app)
160
+ chat_histories={}
161
+
162
+ @app.route("/health")
163
+ def health(): return {"status":"ok"}
164
+
165
+ @app.route("/chat",methods=["POST"])
166
  def chat():
167
+ data=request.get_json(force=True)
168
+ user_query=data.get("question")
169
+ sid=data.get("session_id","default")
170
+ if not user_query: return jsonify({"error":"No question provided"}),400
171
+ if sid not in chat_histories: chat_histories[sid]=[]
172
+ hist=chat_histories[sid]
173
+ if classify_followup(user_query)==0 and hist:
174
+ context=hist[-1].get("context","")
175
+ else:
176
+ idxs=retrieve(user_query,TOP_K)
177
+ if idxs:
178
+ meta=metas[idxs[0]]
179
+ pid=meta.get("parent_id") or meta.get("nguon")
180
+ context=get_full_procedure_text_by_parent(pid)
181
+ else: context=""
182
+ history_str="\n".join([f"{m['role']}: {m['content']}" for m in hist])
183
+ prompt=f"""Bạn là trợ lý eGov-Bot dịch vụ công Việt Nam.
184
+ Trả lời tiếng Việt, chính xác, dựa dữ liệu nếu có.
185
+ Nếu thiếu dữ liệu, nói "Mình chưa có thông tin" và đưa link nguồn trong dữ liệu.
186
+
187
+ Lịch sử: {history_str}
188
  DỮ LIỆU:
189
+ {context}
 
 
 
190
 
191
+ CÂU HỎI: {user_query}"""
192
+ if not generation_model:
193
+ return jsonify({"answer":"LLM model chưa sẵn sàng (kiểm tra GOOGLE_API_KEY)."})
194
  try:
195
+ resp=generation_model.generate_content(prompt)
196
+ ans=getattr(resp,"text",str(resp))
197
  except Exception as e:
198
+ return jsonify({"error":"LLM call failed","detail":str(e)}),200
199
+ hist.append({'role':'user','content':user_query})
200
+ hist.append({'role':'model','content':ans,'context':context})
201
+ return jsonify({"answer":ans})
 
202
 
203
+ if __name__=="__main__":
204
+ app.run(host="0.0.0.0",port=int(os.environ.get("PORT",7860)))