Fix bug and remove the duplicate initialization of vector_store
app.py CHANGED
@@ -94,7 +94,7 @@ try:
 
 except Exception as e:
     raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}")
-
+vector_store = FAISS.load_local(repo_path, embedding_model, allow_dangerous_deserialization=True)
 
 
 background_prompt = '''
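The single `+` line above is the whole fix the commit title describes: the FAISS index is now initialized exactly once, at module level, after the try/except that fetches it from the Hub. A minimal sketch of that load-once pattern, assuming a hypothetical index repo and embedding model (the Space's real values sit outside this hunk):

    # Sketch only: the repo id and embedding model below are placeholders,
    # not the values app.py actually uses.
    from huggingface_hub import snapshot_download
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    try:
        # Fetch the prebuilt FAISS index from the Hub once (hypothetical repo id)
        repo_path = snapshot_download(repo_id="user/legal-faiss-index", repo_type="dataset")
    except Exception as e:
        raise RuntimeError(f"Failed to load vector store from HuggingFace Hub: {str(e)}")
    # One module-level initialization; no second load_local call elsewhere
    vector_store = FAISS.load_local(repo_path, embedding_model, allow_dangerous_deserialization=True)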
@@ -155,19 +155,14 @@ Now, please guide me step by step to describe the legal issues I am facing, acco
 
 def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
     """
-
-    Parameters:
-        vector_store (FAISS): the vector store instance
-        query (str): the query text
-        k (int): number of documents to return
-        relevance_threshold (float): the relevance score threshold
-    Returns:
-        context (list): the retrieved context
+    Query similar documents from vector store.
     """
-    retriever = vector_store.as_retriever(search_type="similarity_score_threshold",
+    retriever = vector_store.as_retriever(search_type="similarity_score_threshold",
+                                          search_kwargs={"score_threshold": relevance_threshold, "k": k})
     similar_docs = retriever.invoke(query)
     context = [doc.page_content for doc in similar_docs]
-
+    # Join the context list into a single string
+    return " ".join(context) if context else ""
 
 @spaces.GPU(duration=120)
 def chat_llama3_8b(message: str,
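With `search_type="similarity_score_threshold"`, the retriever returns at most `k` documents whose relevance score clears `score_threshold`, so an off-topic query can legitimately come back empty; the rewrite makes the function return one joined string, and "" in that empty case, instead of a bare list. A self-contained sketch of that behavior with a toy index (texts and embedding model are made up for illustration):

    # Illustrative only: tiny corpus and placeholder embedding model.
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_store = FAISS.from_texts(
        ["Tenants may terminate a lease with 30 days written notice.",
         "Employers must pay overtime for hours beyond the statutory limit."],
        embedding_model,
    )

    def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8):
        """Query similar documents from vector store."""
        retriever = vector_store.as_retriever(search_type="similarity_score_threshold",
                                              search_kwargs={"score_threshold": relevance_threshold, "k": k})
        similar_docs = retriever.invoke(query)
        context = [doc.page_content for doc in similar_docs]
        # Join the context list into a single string
        return " ".join(context) if context else ""

    # A related query returns joined passages; an unrelated one returns ""
    print(query_vector_store(vector_store, "How do I end my apartment lease?", k=2, relevance_threshold=0.3))
    print(repr(query_vector_store(vector_store, "recipe for pancakes", k=2, relevance_threshold=0.99)))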
@@ -177,40 +172,39 @@ def chat_llama3_8b(message: str,
                    ) -> str:
     """
     Generate a streaming response using the llama3-8b model.
-    Args:
-        message (str): The input message.
-        history (list): The conversation history used by ChatInterface.
-        temperature (float): The temperature for generating the response.
-        max_new_tokens (int): The maximum number of new tokens to generate.
-    Returns:
-        str: The generated response.
     """
+    # Get citations from vector store
     citation = query_vector_store(vector_store, message, 4, 0.7)
-
-
+
+    # Build conversation history
     conversation = []
     for user, assistant in history:
-
-
-
-
+        conversation.extend([
+            {"role": "user", "content": user},
+            {"role": "assistant", "content": assistant}
+        ])
+
+    # Construct the final message with background prompt and citations
+    if citation:
+        message = f"{background_prompt}Based on these citations: {citation}\nPlease answer question: {message}"
     else:
-        message = background_prompt
+        message = f"{background_prompt}{message}"
+
     conversation.append({"role": "user", "content": message})
 
+    # Generate response
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
-        input_ids=
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
         eos_token_id=terminators,
     )
-
+
     if temperature == 0:
         generate_kwargs['do_sample'] = False
 
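`generate_kwargs` is consumed by `model.generate`, and a `TextIteratorStreamer` only yields text while generation runs concurrently; the thread start sits outside this hunk, but the unchanged code presumably follows the standard transformers pattern sketched here (`model`, `streamer`, and `generate_kwargs` assumed as defined in the diff):

    # Hedged sketch of the usual TextIteratorStreamer loop; model, streamer,
    # and generate_kwargs are assumed to exist exactly as app.py builds them.
    from threading import Thread

    def stream_response(model, streamer, generate_kwargs):
        # Run generation on a worker thread so the streamer can be drained here
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        outputs = []
        for text in streamer:       # blocks up to timeout=10.0s for the next chunk
            outputs.append(text)
            yield "".join(outputs)  # re-yield the accumulated partial reply for Gradio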
@@ -220,7 +214,6 @@ def chat_llama3_8b(message: str,
     outputs = []
     for text in streamer:
         outputs.append(text)
-        #print(outputs)
         yield "".join(outputs)
 
 