Spaces:
Paused
Paused
- rag_pipeline.py +21 -16
rag_pipeline.py
CHANGED
|
@@ -82,13 +82,13 @@ def initialize_components(data_path):
|
|
| 82 |
|
| 83 |
def generate_response(query: str, components: dict) -> str:
|
| 84 |
"""
|
| 85 |
-
Tạo câu trả lời (single-turn).
|
| 86 |
-
Phiên bản
|
|
|
|
| 87 |
"""
|
| 88 |
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
# 1. Truy xuất ngữ cảnh
|
| 92 |
retrieved_results = search_relevant_laws(
|
| 93 |
query_text=query,
|
| 94 |
embedding_model=components["embedding_model"],
|
|
@@ -99,25 +99,24 @@ def generate_response(query: str, components: dict) -> str:
|
|
| 99 |
initial_k_multiplier=15
|
| 100 |
)
|
| 101 |
|
| 102 |
-
#
|
| 103 |
-
# 2. Định dạng Context
|
| 104 |
if not retrieved_results:
|
| 105 |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
|
| 106 |
else:
|
| 107 |
context_parts = []
|
| 108 |
for i, res in enumerate(retrieved_results):
|
| 109 |
metadata = res.get('metadata', {})
|
| 110 |
-
# Tạo header đơn giản, không có gợi ý
|
| 111 |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
| 112 |
text = res.get('text', '*Nội dung không có*')
|
| 113 |
context_parts.append(f"{header}\n{text}")
|
| 114 |
context = "\n\n---\n\n".join(context_parts)
|
| 115 |
|
| 116 |
-
# 3
|
| 117 |
-
print("---
|
| 118 |
llm_model = components["llm_model"]
|
| 119 |
tokenizer = components["tokenizer"]
|
| 120 |
|
|
|
|
| 121 |
messages = [
|
| 122 |
{
|
| 123 |
"role": "system",
|
|
@@ -136,13 +135,17 @@ def generate_response(query: str, components: dict) -> str:
|
|
| 136 |
}
|
| 137 |
]
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
| 143 |
|
| 144 |
-
inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
|
| 145 |
-
|
| 146 |
generation_config = dict(
|
| 147 |
max_new_tokens=256,
|
| 148 |
temperature=0.1,
|
|
@@ -151,8 +154,10 @@ def generate_response(query: str, components: dict) -> str:
|
|
| 151 |
pad_token_id=tokenizer.eos_token_id
|
| 152 |
)
|
| 153 |
|
| 154 |
-
output_ids = llm_model.generate(
|
| 155 |
-
|
|
|
|
|
|
|
| 156 |
|
| 157 |
print("--- Tạo câu trả lời hoàn tất ---")
|
| 158 |
return response_text
|
|
|
|
| 82 |
|
| 83 |
def generate_response(query: str, components: dict) -> str:
|
| 84 |
"""
|
| 85 |
+
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
|
| 86 |
+
Phiên bản cuối cùng, sửa lỗi ValueError cho mô hình Vision bằng cách
|
| 87 |
+
sử dụng apply_chat_template để tokenization trực tiếp.
|
| 88 |
"""
|
| 89 |
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
| 90 |
|
| 91 |
+
# --- Bước 1: Truy xuất Ngữ cảnh ---
|
|
|
|
| 92 |
retrieved_results = search_relevant_laws(
|
| 93 |
query_text=query,
|
| 94 |
embedding_model=components["embedding_model"],
|
|
|
|
| 99 |
initial_k_multiplier=15
|
| 100 |
)
|
| 101 |
|
| 102 |
+
# --- Bước 2: Định dạng Ngữ cảnh ---
|
|
|
|
| 103 |
if not retrieved_results:
|
| 104 |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
|
| 105 |
else:
|
| 106 |
context_parts = []
|
| 107 |
for i, res in enumerate(retrieved_results):
|
| 108 |
metadata = res.get('metadata', {})
|
|
|
|
| 109 |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
| 110 |
text = res.get('text', '*Nội dung không có*')
|
| 111 |
context_parts.append(f"{header}\n{text}")
|
| 112 |
context = "\n\n---\n\n".join(context_parts)
|
| 113 |
|
| 114 |
+
# --- Bước 3: Chuẩn bị Dữ liệu và Tokenize bằng Chat Template (Phần sửa lỗi cốt lõi) ---
|
| 115 |
+
print("--- Chuẩn bị và tokenize prompt bằng chat template ---")
|
| 116 |
llm_model = components["llm_model"]
|
| 117 |
tokenizer = components["tokenizer"]
|
| 118 |
|
| 119 |
+
# Tạo cấu trúc tin nhắn theo chuẩn
|
| 120 |
messages = [
|
| 121 |
{
|
| 122 |
"role": "system",
|
|
|
|
| 135 |
}
|
| 136 |
]
|
| 137 |
|
| 138 |
+
# SỬA LỖI: Dùng apply_chat_template để tokenize trực tiếp
|
| 139 |
+
# Nó sẽ tự động định dạng và chuyển thành tensor, tương thích với mô hình Vision
|
| 140 |
+
inputs = tokenizer.apply_chat_template(
|
| 141 |
+
messages,
|
| 142 |
+
return_tensors="pt",
|
| 143 |
+
add_generation_prompt=True,
|
| 144 |
+
).to(llm_model.device)
|
| 145 |
+
|
| 146 |
+
# --- Bước 4: Tạo câu trả lời từ LLM ---
|
| 147 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
| 148 |
|
|
|
|
|
|
|
| 149 |
generation_config = dict(
|
| 150 |
max_new_tokens=256,
|
| 151 |
temperature=0.1,
|
|
|
|
| 154 |
pad_token_id=tokenizer.eos_token_id
|
| 155 |
)
|
| 156 |
|
| 157 |
+
output_ids = llm_model.generate(inputs, **generation_config)
|
| 158 |
+
|
| 159 |
+
# Decode như cũ, nhưng đầu vào là `inputs` thay vì `inputs.input_ids`
|
| 160 |
+
response_text = tokenizer.decode(output_ids[0][inputs.shape[1]:], skip_special_tokens=True)
|
| 161 |
|
| 162 |
print("--- Tạo câu trả lời hoàn tất ---")
|
| 163 |
return response_text
|