kerodat2004 commited on
Commit
21243cd
·
verified ·
1 Parent(s): 2ee6337

Update chatbot_rag.py

Browse files
Files changed (1) hide show
  1. chatbot_rag.py +189 -189
chatbot_rag.py CHANGED
@@ -1,190 +1,190 @@
1
- import torch
2
- import re
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from src.search import search_places, is_meaningless_query
5
-
6
- # ========================
7
- # MODEL
8
- # ========================
9
- last_search_results = []
10
- waiting_for_suggestion = False
11
- last_query_no_result = ""
12
-
13
- model_name = "Qwen/Qwen2.5-0.5B-Instruct"
14
-
15
- model = AutoModelForCausalLM.from_pretrained(
16
- model_name,
17
- torch_dtype=torch.float16,
18
- device_map="auto"
19
- )
20
-
21
- tokenizer = AutoTokenizer.from_pretrained(model_name)
22
-
23
- # ========================
24
- # MEMORY (tạm thời vẫn dùng global)
25
- # ========================
26
- last_search_results = []
27
-
28
- # ========================
29
- # PARAPHRASE
30
- # ========================
31
- def paraphrase_description(text):
32
- messages = [
33
- {
34
- "role": "system",
35
- "content": (
36
- "Bạn chỉ có nhiệm vụ viết lại câu văn cho tự nhiên hơn.\n"
37
- "KHÔNG thêm thông tin mới.\n"
38
- "KHÔNG suy diễn.\n"
39
- "Giữ nguyên ý nghĩa 100%.\n"
40
- "Viết ngắn gọn, dễ hiểu."
41
- )
42
- },
43
- {
44
- "role": "user",
45
- "content": f"Viết lại câu sau:\n{text}"
46
- }
47
- ]
48
-
49
- text_input = tokenizer.apply_chat_template(
50
- messages, tokenize=False, add_generation_prompt=True
51
- )
52
-
53
- inputs = tokenizer([text_input], return_tensors="pt").to(model.device)
54
-
55
- outputs = model.generate(
56
- **inputs,
57
- max_new_tokens=80,
58
- top_p=0.8
59
- )
60
-
61
- return tokenizer.decode(
62
- outputs[0][inputs.input_ids.shape[-1]:],
63
- skip_special_tokens=True
64
- ).strip()
65
-
66
-
67
- # ========================
68
- # BUILD ANSWER
69
- # ========================
70
- def build_natural_answer(results):
71
- answer = ""
72
- for idx, p in enumerate(results, 1):
73
- desc = paraphrase_description(p["description"])
74
- answer += f"{idx}. {p['name']} ({p['city']}): {desc}\n"
75
- return answer
76
-
77
- # ========================
78
- # INTENT
79
- # ========================
80
- def detect_intent(query):
81
- q = query.lower()
82
-
83
- if any(k in q for k in ["biển","bien","beach","đảo","dao"]): return "beach"
84
- if any(k in q for k in ["ăn","food","quán","nhà hàng"]): return "food"
85
- if any(k in q for k in ["checkin","sống ảo","đẹp"]): return "checkin"
86
- if any(k in q for k in ["chùa","di tích","lịch sử"]): return "culture"
87
- if any(k in q for k in ["núi","thác","rừng"]): return "nature"
88
- if any(k in q for k in ["chơi","giải trí","bar"]): return "entertainment"
89
-
90
- return "general"
91
-
92
- # ========================
93
- # EMBEDDING CLASSIFIER
94
- # ========================
95
- def classify_query(results, threshold=1.2):
96
- if not results:
97
- return "no_data"
98
-
99
- if isinstance(results, list) and "error" in results[0]:
100
- return "no_data"
101
-
102
- top_score = results[0].get("score", 999)
103
-
104
- if top_score > threshold:
105
- return "out_domain"
106
-
107
- return "in_domain"
108
-
109
- # ========================
110
- # MAIN RAG
111
- # ========================
112
- def rag_answer(query):
113
- global last_search_results, waiting_for_suggestion, last_query_no_result
114
-
115
- query_lower = query.lower().strip()
116
-
117
- # ========================
118
- # 1. HANDLE GỢI Ý
119
- # ========================
120
- if waiting_for_suggestion:
121
- if any(x in query_lower for x in ["có","ok","yes","ừ"]):
122
- waiting_for_suggestion = False
123
- results = search_places("du lịch nổi bật việt nam")
124
-
125
- if results:
126
- last_search_results = results
127
- return "Gợi ý cho bạn:\n\n" + build_natural_answer(results)
128
-
129
- return "Chưa có gợi ý phù hợp."
130
-
131
- elif any(x in query_lower for x in ["không","no","ko"]):
132
- waiting_for_suggestion = False
133
- return "Ok, bạn cần gì cứ hỏi mình nhé!"
134
-
135
- else:
136
- return "Bạn có muốn mình gợi ý địa điểm khác không? (có / không)"
137
-
138
- # ========================
139
- # 2. HỎI LINK
140
- # ========================
141
- if "link" in query_lower and last_search_results:
142
- nums = re.findall(r'\d+', query)
143
- if nums:
144
- idx = int(nums[0]) - 1
145
- if 0 <= idx < len(last_search_results):
146
- p = last_search_results[idx]
147
- return f"Link: {p['maps_link']}"
148
-
149
- # ========================
150
- # 3. SEARCH
151
- # ========================
152
- intent = detect_intent(query)
153
-
154
- intent_map = {
155
- "beach": "du lịch biển",
156
- "food": "ăn uống",
157
- "checkin": "checkin đẹp",
158
- "culture": "văn hóa",
159
- "nature": "thiên nhiên",
160
- "entertainment": "giải trí",
161
- "general": ""
162
- }
163
-
164
- augmented_query = query + " " + intent_map.get(intent, "")
165
-
166
- results = search_places(augmented_query)
167
-
168
- # ========================
169
- # 4. CLASSIFY
170
- # ========================
171
- query_type = classify_query(results)
172
-
173
- if query_type == "out_domain":
174
- return "Mình chỉ hỗ trợ tư vấn du lịch (địa điểm, ăn uống, vui chơi)."
175
-
176
- if query_type == "no_data":
177
- # nếu query vô nghĩa
178
- if is_meaningless_query(query):
179
- return "Mình chưa hiểu ý bạn. Mình chỉ hỗ trợ gợi ý địa điểm du lịch nhé!"
180
-
181
- # nếu có nghĩa nhưng không có dữ liệu
182
- waiting_for_suggestion = True
183
- last_query_no_result = query
184
- return "Mình chưa có dữ liệu. Bạn có muốn mình gợi ý địa điểm khác không?"
185
-
186
- # ========================
187
- # 5. BUILD
188
- # ========================
189
- last_search_results = results
190
  return build_natural_answer(results)
 
1
+ import torch
2
+ import re
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from search import search_places, is_meaningless_query
5
+
6
+ # ========================
7
+ # MODEL
8
+ # ========================
9
+ last_search_results = []
10
+ waiting_for_suggestion = False
11
+ last_query_no_result = ""
12
+
13
+ model_name = "Qwen/Qwen2.5-0.5B-Instruct"
14
+
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ model_name,
17
+ torch_dtype=torch.float16,
18
+ device_map="auto"
19
+ )
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+
23
+ # ========================
24
+ # MEMORY (tạm thời vẫn dùng global)
25
+ # ========================
26
+ last_search_results = []
27
+
28
+ # ========================
29
+ # PARAPHRASE
30
+ # ========================
31
+ def paraphrase_description(text):
32
+ messages = [
33
+ {
34
+ "role": "system",
35
+ "content": (
36
+ "Bạn chỉ có nhiệm vụ viết lại câu văn cho tự nhiên hơn.\n"
37
+ "KHÔNG thêm thông tin mới.\n"
38
+ "KHÔNG suy diễn.\n"
39
+ "Giữ nguyên ý nghĩa 100%.\n"
40
+ "Viết ngắn gọn, dễ hiểu."
41
+ )
42
+ },
43
+ {
44
+ "role": "user",
45
+ "content": f"Viết lại câu sau:\n{text}"
46
+ }
47
+ ]
48
+
49
+ text_input = tokenizer.apply_chat_template(
50
+ messages, tokenize=False, add_generation_prompt=True
51
+ )
52
+
53
+ inputs = tokenizer([text_input], return_tensors="pt").to(model.device)
54
+
55
+ outputs = model.generate(
56
+ **inputs,
57
+ max_new_tokens=80,
58
+ top_p=0.8
59
+ )
60
+
61
+ return tokenizer.decode(
62
+ outputs[0][inputs.input_ids.shape[-1]:],
63
+ skip_special_tokens=True
64
+ ).strip()
65
+
66
+
67
+ # ========================
68
+ # BUILD ANSWER
69
+ # ========================
70
+ def build_natural_answer(results):
71
+ answer = ""
72
+ for idx, p in enumerate(results, 1):
73
+ desc = paraphrase_description(p["description"])
74
+ answer += f"{idx}. {p['name']} ({p['city']}): {desc}\n"
75
+ return answer
76
+
77
+ # ========================
78
+ # INTENT
79
+ # ========================
80
+ def detect_intent(query):
81
+ q = query.lower()
82
+
83
+ if any(k in q for k in ["biển","bien","beach","đảo","dao"]): return "beach"
84
+ if any(k in q for k in ["ăn","food","quán","nhà hàng"]): return "food"
85
+ if any(k in q for k in ["checkin","sống ảo","đẹp"]): return "checkin"
86
+ if any(k in q for k in ["chùa","di tích","lịch sử"]): return "culture"
87
+ if any(k in q for k in ["núi","thác","rừng"]): return "nature"
88
+ if any(k in q for k in ["chơi","giải trí","bar"]): return "entertainment"
89
+
90
+ return "general"
91
+
92
+ # ========================
93
+ # EMBEDDING CLASSIFIER
94
+ # ========================
95
+ def classify_query(results, threshold=1.2):
96
+ if not results:
97
+ return "no_data"
98
+
99
+ if isinstance(results, list) and "error" in results[0]:
100
+ return "no_data"
101
+
102
+ top_score = results[0].get("score", 999)
103
+
104
+ if top_score > threshold:
105
+ return "out_domain"
106
+
107
+ return "in_domain"
108
+
109
+ # ========================
110
+ # MAIN RAG
111
+ # ========================
112
+ def rag_answer(query):
113
+ global last_search_results, waiting_for_suggestion, last_query_no_result
114
+
115
+ query_lower = query.lower().strip()
116
+
117
+ # ========================
118
+ # 1. HANDLE GỢI Ý
119
+ # ========================
120
+ if waiting_for_suggestion:
121
+ if any(x in query_lower for x in ["có","ok","yes","ừ"]):
122
+ waiting_for_suggestion = False
123
+ results = search_places("du lịch nổi bật việt nam")
124
+
125
+ if results:
126
+ last_search_results = results
127
+ return "Gợi ý cho bạn:\n\n" + build_natural_answer(results)
128
+
129
+ return "Chưa có gợi ý phù hợp."
130
+
131
+ elif any(x in query_lower for x in ["không","no","ko"]):
132
+ waiting_for_suggestion = False
133
+ return "Ok, bạn cần gì cứ hỏi mình nhé!"
134
+
135
+ else:
136
+ return "Bạn có muốn mình gợi ý địa điểm khác không? (có / không)"
137
+
138
+ # ========================
139
+ # 2. HỎI LINK
140
+ # ========================
141
+ if "link" in query_lower and last_search_results:
142
+ nums = re.findall(r'\d+', query)
143
+ if nums:
144
+ idx = int(nums[0]) - 1
145
+ if 0 <= idx < len(last_search_results):
146
+ p = last_search_results[idx]
147
+ return f"Link: {p['maps_link']}"
148
+
149
+ # ========================
150
+ # 3. SEARCH
151
+ # ========================
152
+ intent = detect_intent(query)
153
+
154
+ intent_map = {
155
+ "beach": "du lịch biển",
156
+ "food": "ăn uống",
157
+ "checkin": "checkin đẹp",
158
+ "culture": "văn hóa",
159
+ "nature": "thiên nhiên",
160
+ "entertainment": "giải trí",
161
+ "general": ""
162
+ }
163
+
164
+ augmented_query = query + " " + intent_map.get(intent, "")
165
+
166
+ results = search_places(augmented_query)
167
+
168
+ # ========================
169
+ # 4. CLASSIFY
170
+ # ========================
171
+ query_type = classify_query(results)
172
+
173
+ if query_type == "out_domain":
174
+ return "Mình chỉ hỗ trợ tư vấn du lịch (địa điểm, ăn uống, vui chơi)."
175
+
176
+ if query_type == "no_data":
177
+ # nếu query vô nghĩa
178
+ if is_meaningless_query(query):
179
+ return "Mình chưa hiểu ý bạn. Mình chỉ hỗ trợ gợi ý địa điểm du lịch nhé!"
180
+
181
+ # nếu có nghĩa nhưng không có dữ liệu
182
+ waiting_for_suggestion = True
183
+ last_query_no_result = query
184
+ return "Mình chưa có dữ liệu. Bạn có muốn mình gợi ý địa điểm khác không?"
185
+
186
+ # ========================
187
+ # 5. BUILD
188
+ # ========================
189
+ last_search_results = results
190
  return build_natural_answer(results)