Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,10 +9,8 @@ from transformers.models.llama.configuration_llama import LlamaConfig
|
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
import gradio as gr
|
| 11 |
|
| 12 |
-
# --- 解決 Matplotlib 權限問題 ---
|
| 13 |
-
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
|
| 14 |
-
|
| 15 |
# === Enterprise Environment Settings ===
|
|
|
|
| 16 |
os.environ["HOME"] = "/tmp"
|
| 17 |
os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
|
| 18 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
|
@@ -20,9 +18,10 @@ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
|
|
| 20 |
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
|
| 21 |
os.environ["HF_METRICS_CACHE"] = "/tmp/huggingface/metrics"
|
| 22 |
os.environ["GRADIO_FLAGGING_DIR"] = "/tmp/flagged"
|
|
|
|
|
|
|
| 23 |
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers"
|
| 24 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hf_cache"
|
| 25 |
-
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
|
| 26 |
|
| 27 |
# === Load Required Modules ===
|
| 28 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
@@ -35,15 +34,20 @@ from langchain.chains import ConversationalRetrievalChain
|
|
| 35 |
from langchain.memory import ConversationBufferMemory
|
| 36 |
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader
|
| 37 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 38 |
-
from langchain.chains.summarize import load_summarize_chain
|
| 39 |
from tempfile import mkdtemp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
import pandas as pd
|
| 41 |
|
| 42 |
# === Multi-Agent Imports ===
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
from serpapi import GoogleSearch
|
| 44 |
-
# CrewAI 部分:完全使用 CrewAI 的 Agent、Task、Crew 與 @tool 裝飾器
|
| 45 |
-
from crewai import Crew, Agent, Task, Process
|
| 46 |
-
from crewai.tools import tool
|
| 47 |
|
| 48 |
# === Model and Device Setup ===
|
| 49 |
if torch.backends.mps.is_available():
|
|
@@ -54,10 +58,14 @@ else:
|
|
| 54 |
device = "cpu"
|
| 55 |
print(f"Using device => {device}")
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
| 59 |
model_id = "ChienChung/my-llama-1b"
|
| 60 |
|
|
|
|
| 61 |
config_path = hf_hub_download(
|
| 62 |
repo_id=model_id,
|
| 63 |
filename="config.json",
|
|
@@ -66,8 +74,13 @@ config_path = hf_hub_download(
|
|
| 66 |
)
|
| 67 |
with open(config_path, "r", encoding="utf-8") as f:
|
| 68 |
config_dict = json.load(f)
|
|
|
|
| 69 |
if "rope_scaling" in config_dict:
|
| 70 |
-
config_dict["rope_scaling"] = {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
model_config = LlamaConfig.from_dict(config_dict)
|
| 72 |
model_config.trust_remote_code = True
|
| 73 |
|
|
@@ -93,6 +106,7 @@ if tokenizer.pad_token is None:
|
|
| 93 |
tokenizer.pad_token = tokenizer.eos_token
|
| 94 |
print("Tokenizer loaded!")
|
| 95 |
|
|
|
|
| 96 |
query_pipeline = transformers.pipeline(
|
| 97 |
"text-generation",
|
| 98 |
model=model,
|
|
@@ -109,6 +123,7 @@ query_pipeline = transformers.pipeline(
|
|
| 109 |
print("Loading Chroma DB for Biden Speech...")
|
| 110 |
if not os.path.exists("/tmp/chroma_db"):
|
| 111 |
shutil.copytree("./chroma_db", "/tmp/chroma_db")
|
|
|
|
| 112 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
|
| 113 |
vectordb = Chroma(persist_directory="/tmp/chroma_db", embedding_function=embeddings)
|
| 114 |
retriever = vectordb.as_retriever()
|
|
@@ -143,22 +158,9 @@ qa_gpt = ConversationalRetrievalChain.from_llm(
|
|
| 143 |
combine_docs_chain_kwargs={"prompt": custom_prompt}
|
| 144 |
)
|
| 145 |
|
| 146 |
-
# ===
|
| 147 |
-
def get_file_path(file):
|
| 148 |
-
if isinstance(file, str):
|
| 149 |
-
return file
|
| 150 |
-
elif isinstance(file, dict):
|
| 151 |
-
# 優先使用 "data" 鍵,其次是 "name"
|
| 152 |
-
return file.get("data", file.get("name", None))
|
| 153 |
-
elif hasattr(file, "save"):
|
| 154 |
-
temp_dir = mkdtemp()
|
| 155 |
-
file_path = os.path.join(temp_dir, file.name)
|
| 156 |
-
file.save(file_path)
|
| 157 |
-
return file_path
|
| 158 |
-
else:
|
| 159 |
-
return None
|
| 160 |
|
| 161 |
-
#
|
| 162 |
def rag_llama_qa(query):
|
| 163 |
output = RetrievalQA.from_chain_type(
|
| 164 |
llm=llm_local,
|
|
@@ -171,19 +173,32 @@ def rag_llama_qa(query):
|
|
| 171 |
idx = lower_text.find("answer:")
|
| 172 |
return output[idx + len("answer:"):].strip() if idx != -1 else output
|
| 173 |
|
|
|
|
| 174 |
def rag_gpt4_qa(query):
|
| 175 |
return qa_gpt.run(query)
|
| 176 |
|
|
|
|
| 177 |
def upload_and_chat(file, query):
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
if file_path.lower().endswith(".pdf"):
|
| 182 |
loader = PyPDFLoader(file_path)
|
| 183 |
elif file_path.lower().endswith(".docx"):
|
| 184 |
loader = UnstructuredWordDocumentLoader(file_path)
|
| 185 |
else:
|
| 186 |
loader = TextLoader(file_path)
|
|
|
|
| 187 |
docs = loader.load()
|
| 188 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
|
| 189 |
db = FAISS.from_documents(chunks, embeddings)
|
|
@@ -197,63 +212,96 @@ def upload_and_chat(file, query):
|
|
| 197 |
)
|
| 198 |
return qa_temp.run(query)
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
initial_prompt = PromptTemplate(
|
| 201 |
input_variables=["text"],
|
| 202 |
-
template="""
|
|
|
|
| 203 |
|
| 204 |
-
{text}
|
| 205 |
|
| 206 |
-
|
| 207 |
"""
|
| 208 |
)
|
|
|
|
| 209 |
refine_prompt = PromptTemplate(
|
| 210 |
input_variables=["existing_answer", "text"],
|
| 211 |
-
template="""
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
Refine the summary based on the new content below. Add or update information only if it's relevant. Keep it concise:
|
| 215 |
|
| 216 |
-
|
|
|
|
| 217 |
|
| 218 |
-
|
| 219 |
"""
|
| 220 |
)
|
| 221 |
-
|
| 222 |
def document_summarize(file):
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
if file_path.lower().endswith(".pdf"):
|
| 227 |
loader = PyPDFLoader(file_path)
|
| 228 |
elif file_path.lower().endswith(".docx"):
|
| 229 |
loader = UnstructuredWordDocumentLoader(file_path)
|
| 230 |
else:
|
| 231 |
loader = TextLoader(file_path)
|
|
|
|
| 232 |
docs = loader.load()
|
| 233 |
summarize_chain = load_summarize_chain(llm_gpt4, chain_type="refine", question_prompt=initial_prompt, refine_prompt=refine_prompt)
|
| 234 |
summary = summarize_chain.invoke(docs)
|
| 235 |
return summary['output_text']
|
| 236 |
|
|
|
|
| 237 |
def csv_agent(file, query):
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
try:
|
| 242 |
df = pd.read_csv(file_path)
|
| 243 |
except Exception as e:
|
| 244 |
-
return f"Error reading CSV: {e}"
|
| 245 |
safe_dict = {"df": df, "pd": pd}
|
| 246 |
try:
|
| 247 |
result = eval(query, {"__builtins__": {}}, safe_dict)
|
| 248 |
return str(result)
|
| 249 |
except Exception as e:
|
| 250 |
-
return f"Query error: {e}"
|
| 251 |
|
|
|
|
| 252 |
def search_agent(query):
|
| 253 |
api_key = os.environ.get("SERPAPI_API_KEY")
|
| 254 |
if not api_key:
|
| 255 |
return "SERPAPI_API_KEY not set. Please set the environment variable."
|
| 256 |
-
params = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
search = GoogleSearch(params)
|
| 258 |
results = search.get_dict()
|
| 259 |
if "organic_results" in results:
|
|
@@ -267,10 +315,41 @@ def search_agent(query):
|
|
| 267 |
else:
|
| 268 |
return "No results found."
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
def uploaded_qa(file, query):
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
if file_path.lower().endswith(".pdf"):
|
| 275 |
loader = PyPDFLoader(file_path)
|
| 276 |
elif file_path.lower().endswith(".docx"):
|
|
@@ -290,115 +369,46 @@ def uploaded_qa(file, query):
|
|
| 290 |
)
|
| 291 |
return qa_temp.run(query)
|
| 292 |
|
| 293 |
-
|
| 294 |
-
# 完全捨棄 langchain.agents.Tool,使用 CrewAI 的 @tool 裝飾器來定義工具
|
| 295 |
-
from pydantic import BaseModel
|
| 296 |
-
class SimpleQuery(BaseModel):
|
| 297 |
-
query: str
|
| 298 |
-
|
| 299 |
-
@tool("summarise")
|
| 300 |
-
def summarise_tool(query: str) -> str:
|
| 301 |
-
"""Summarise: 使用文件摘要功能。"""
|
| 302 |
-
global session_retriever, session_qa_chain
|
| 303 |
-
if session_retriever is None:
|
| 304 |
-
return "尚未上傳文件。"
|
| 305 |
-
try:
|
| 306 |
-
docs = session_retriever.get_relevant_documents(query if query.strip() else "summary")
|
| 307 |
-
if not docs:
|
| 308 |
-
return "無法從文件中找到相關內容。"
|
| 309 |
-
summarize_chain = load_summarize_chain(llm_gpt4, chain_type="refine", question_prompt=initial_prompt, refine_prompt=refine_prompt)
|
| 310 |
-
summary = summarize_chain.invoke(docs)
|
| 311 |
-
return summary['output_text']
|
| 312 |
-
except Exception as e:
|
| 313 |
-
return f"摘要錯誤: {e}"
|
| 314 |
-
|
| 315 |
-
@tool("python_calc")
|
| 316 |
-
def python_calc_tool(query: str) -> str:
|
| 317 |
-
"""Python Calculation: 執行簡單計算。"""
|
| 318 |
-
try:
|
| 319 |
-
return str(eval(query))
|
| 320 |
-
except Exception as e:
|
| 321 |
-
return f"計算錯誤: {e}"
|
| 322 |
-
|
| 323 |
-
@tool("search_agent")
|
| 324 |
-
def search_tool_func(query: str) -> str:
|
| 325 |
-
"""Search: 執行網路搜尋。"""
|
| 326 |
-
return search_agent(query)
|
| 327 |
-
|
| 328 |
-
@tool("uploaded_qa")
|
| 329 |
-
def uploaded_qa_tool_func(query: str) -> str:
|
| 330 |
-
"""Document QA: 根據上傳文件回答問題。"""
|
| 331 |
-
global session_qa_chain
|
| 332 |
-
if session_qa_chain is not None:
|
| 333 |
-
try:
|
| 334 |
-
return session_qa_chain.run(query)
|
| 335 |
-
except Exception as e:
|
| 336 |
-
return f"文檔問答錯誤: {e}"
|
| 337 |
-
else:
|
| 338 |
-
return "尚未上傳文件。"
|
| 339 |
-
|
| 340 |
-
# 建立 CrewAI 代理(僅針對 Tab 5)
|
| 341 |
-
summarizer_agent = Agent(
|
| 342 |
-
role="文件摘要助手",
|
| 343 |
-
goal="對上傳文件內容進行摘要",
|
| 344 |
-
backstory="你是一位專業的摘要專家,能抓住長文的重點。",
|
| 345 |
-
tools=[summarise_tool],
|
| 346 |
-
verbose=True
|
| 347 |
-
)
|
| 348 |
-
document_qa_agent = Agent(
|
| 349 |
-
role="文件問答專家",
|
| 350 |
-
goal="根據上傳文件回答問題",
|
| 351 |
-
backstory="你精通文檔內容,能從中找出問題答案。",
|
| 352 |
-
tools=[uploaded_qa_tool_func],
|
| 353 |
-
verbose=True
|
| 354 |
-
)
|
| 355 |
-
general_agent = Agent(
|
| 356 |
-
role="綜合助手",
|
| 357 |
-
goal="回答一般問題,執行計算與網路搜尋",
|
| 358 |
-
backstory="你是一位多才多藝的AI助理,能根據需要使用工具。",
|
| 359 |
-
tools=[python_calc_tool, search_tool_func],
|
| 360 |
-
verbose=True
|
| 361 |
-
)
|
| 362 |
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
)
|
|
|
|
|
|
|
| 368 |
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
tasks=[router_task],
|
| 372 |
-
process=Process.sequential,
|
| 373 |
-
verbose=True
|
| 374 |
-
)
|
| 375 |
|
| 376 |
-
def multi_agent_chat(query
|
| 377 |
try:
|
| 378 |
-
return
|
| 379 |
except Exception as e:
|
| 380 |
-
return f"Error: {e}"
|
| 381 |
|
| 382 |
-
|
|
|
|
| 383 |
global session_retriever, session_qa_chain
|
| 384 |
-
# 定義一些明顯與文件無關的關鍵字
|
| 385 |
-
non_doc_keywords = ["calculate", "sum", "date", "time", "how many", "how much", "weather", "temperature"]
|
| 386 |
-
use_file_chain = True
|
| 387 |
-
for kw in non_doc_keywords:
|
| 388 |
-
if kw in query.lower():
|
| 389 |
-
use_file_chain = False
|
| 390 |
-
break
|
| 391 |
|
| 392 |
if file is not None:
|
| 393 |
-
|
|
|
|
| 394 |
if file_path is None:
|
| 395 |
-
return "
|
|
|
|
| 396 |
if file_path.lower().endswith(".csv"):
|
| 397 |
return csv_agent(file, query)
|
|
|
|
| 398 |
elif file_path.lower().endswith((".pdf", ".txt", ".docx")):
|
| 399 |
-
loader = (
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
docs = loader.load()
|
| 403 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
|
| 404 |
db = FAISS.from_documents(chunks, embeddings)
|
|
@@ -409,30 +419,23 @@ def multi_agent_chat_advanced(query: str, file=None) -> str:
|
|
| 409 |
memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
|
| 410 |
combine_docs_chain_kwargs={"prompt": custom_prompt}
|
| 411 |
)
|
| 412 |
-
|
| 413 |
-
|
|
|
|
|
|
|
| 414 |
else:
|
| 415 |
-
|
| 416 |
-
return crew.run(query)
|
| 417 |
-
except Exception as e:
|
| 418 |
-
return f"Error: {e}"
|
| 419 |
else:
|
| 420 |
-
return "
|
|
|
|
|
|
|
| 421 |
elif session_qa_chain is not None:
|
| 422 |
-
|
| 423 |
-
return session_qa_chain.run(query)
|
| 424 |
-
else:
|
| 425 |
-
try:
|
| 426 |
-
return crew.run(query)
|
| 427 |
-
except Exception as e:
|
| 428 |
-
return f"Error: {e}"
|
| 429 |
else:
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
return f"Error: {e}"
|
| 434 |
|
| 435 |
-
# === Gradio Interface Settings ===
|
| 436 |
demo_description = """
|
| 437 |
**Context**:
|
| 438 |
This demo uses a Retrieval-Augmented Generation (RAG) system based on
|
|
@@ -451,6 +454,7 @@ If no relevant information is found in the document, the system will say "No rel
|
|
| 451 |
|
| 452 |
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
|
| 453 |
"""
|
|
|
|
| 454 |
demo_description2 = """
|
| 455 |
**Context**:
|
| 456 |
This demo uses a Retrieval-Augmented Generation (RAG) system based on
|
|
@@ -469,13 +473,19 @@ If no relevant information is found in the document, the system will say "No rel
|
|
| 469 |
|
| 470 |
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
|
| 471 |
"""
|
|
|
|
| 472 |
demo_description3 = """
|
| 473 |
**Context**:
|
| 474 |
-
|
| 475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
|
| 477 |
Feel free to ask any question related to your document.
|
| 478 |
"""
|
|
|
|
| 479 |
demo_description4 = """
|
| 480 |
**Context**:
|
| 481 |
This assistant performs multi-agent tasks using tools such as:
|
|
@@ -484,9 +494,10 @@ This assistant performs multi-agent tasks using tools such as:
|
|
| 484 |
- Financial or CSV-style logic queries
|
| 485 |
- Multi-step reasoning via agent orchestration
|
| 486 |
|
| 487 |
-
|
| 488 |
-
|
| 489 |
"""
|
|
|
|
| 490 |
demo_description5 = """
|
| 491 |
**Context**:
|
| 492 |
This demo uses Document Summarization via a Map-Reduce chain.
|
|
|
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
import gradio as gr
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
# === Enterprise Environment Settings ===
|
| 13 |
+
# Redirect cache directories to writable locations (e.g., /tmp)
|
| 14 |
os.environ["HOME"] = "/tmp"
|
| 15 |
os.environ["XDG_CACHE_HOME"] = "/tmp/.cache"
|
| 16 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
|
|
|
| 18 |
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
|
| 19 |
os.environ["HF_METRICS_CACHE"] = "/tmp/huggingface/metrics"
|
| 20 |
os.environ["GRADIO_FLAGGING_DIR"] = "/tmp/flagged"
|
| 21 |
+
|
| 22 |
+
# Set SentenceTransformers and HF_HUB cache directories to avoid writing to system root
|
| 23 |
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers"
|
| 24 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hf_cache"
|
|
|
|
| 25 |
|
| 26 |
# === Load Required Modules ===
|
| 27 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
|
| 34 |
from langchain.memory import ConversationBufferMemory
|
| 35 |
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader
|
| 36 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
| 37 |
from tempfile import mkdtemp
|
| 38 |
+
|
| 39 |
+
# Import Summarization Chain (Map-Reduce)
|
| 40 |
+
from langchain.chains.summarize import load_summarize_chain
|
| 41 |
+
|
| 42 |
+
# Import pandas for CSV handling
|
| 43 |
import pandas as pd
|
| 44 |
|
| 45 |
# === Multi-Agent Imports ===
|
| 46 |
+
from langchain.agents import initialize_agent, Tool
|
| 47 |
+
from langchain.agents.agent_types import AgentType
|
| 48 |
+
|
| 49 |
+
# Import SerpAPI (real external search integration)
|
| 50 |
from serpapi import GoogleSearch
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# === Model and Device Setup ===
|
| 53 |
if torch.backends.mps.is_available():
|
|
|
|
| 58 |
device = "cpu"
|
| 59 |
print(f"Using device => {device}")
|
| 60 |
|
| 61 |
+
# Ensure environment variables are set for API keys:
|
| 62 |
+
hf_token = os.environ.get("HF_TOKEN") # Hugging Face access token
|
| 63 |
+
openai_api_key = os.environ.get("OPENAI_API_KEY") # OpenAI API key
|
| 64 |
+
# SERPAPI_API_KEY must be set for external search
|
| 65 |
+
|
| 66 |
model_id = "ChienChung/my-llama-1b"
|
| 67 |
|
| 68 |
+
# Download and load model config
|
| 69 |
config_path = hf_hub_download(
|
| 70 |
repo_id=model_id,
|
| 71 |
filename="config.json",
|
|
|
|
| 74 |
)
|
| 75 |
with open(config_path, "r", encoding="utf-8") as f:
|
| 76 |
config_dict = json.load(f)
|
| 77 |
+
|
| 78 |
if "rope_scaling" in config_dict:
|
| 79 |
+
config_dict["rope_scaling"] = {
|
| 80 |
+
"type": "dynamic",
|
| 81 |
+
"factor": config_dict["rope_scaling"].get("factor", 32.0)
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
model_config = LlamaConfig.from_dict(config_dict)
|
| 85 |
model_config.trust_remote_code = True
|
| 86 |
|
|
|
|
| 106 |
tokenizer.pad_token = tokenizer.eos_token
|
| 107 |
print("Tokenizer loaded!")
|
| 108 |
|
| 109 |
+
# Build query pipeline
|
| 110 |
query_pipeline = transformers.pipeline(
|
| 111 |
"text-generation",
|
| 112 |
model=model,
|
|
|
|
| 123 |
print("Loading Chroma DB for Biden Speech...")
|
| 124 |
if not os.path.exists("/tmp/chroma_db"):
|
| 125 |
shutil.copytree("./chroma_db", "/tmp/chroma_db")
|
| 126 |
+
|
| 127 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
|
| 128 |
vectordb = Chroma(persist_directory="/tmp/chroma_db", embedding_function=embeddings)
|
| 129 |
retriever = vectordb.as_retriever()
|
|
|
|
| 158 |
combine_docs_chain_kwargs={"prompt": custom_prompt}
|
| 159 |
)
|
| 160 |
|
| 161 |
+
# === Function Definitions ===
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
# Biden Q&A (LLaMA)
|
| 164 |
def rag_llama_qa(query):
|
| 165 |
output = RetrievalQA.from_chain_type(
|
| 166 |
llm=llm_local,
|
|
|
|
| 173 |
idx = lower_text.find("answer:")
|
| 174 |
return output[idx + len("answer:"):].strip() if idx != -1 else output
|
| 175 |
|
| 176 |
+
# Biden Q&A (GPT-4)
|
| 177 |
def rag_gpt4_qa(query):
|
| 178 |
return qa_gpt.run(query)
|
| 179 |
|
| 180 |
+
# Document Q&A (Upload + GPT-4)
|
| 181 |
def upload_and_chat(file, query):
|
| 182 |
+
if isinstance(file, str):
|
| 183 |
+
file_path = file
|
| 184 |
+
elif isinstance(file, dict):
|
| 185 |
+
file_path = file.get("name", None)
|
| 186 |
+
if file_path is None:
|
| 187 |
+
return "Unable to obtain the uploaded file path."
|
| 188 |
+
elif hasattr(file, "save"):
|
| 189 |
+
temp_dir = mkdtemp()
|
| 190 |
+
file_path = os.path.join(temp_dir, file.name)
|
| 191 |
+
file.save(file_path)
|
| 192 |
+
else:
|
| 193 |
+
return "Unable to process the file format."
|
| 194 |
+
|
| 195 |
if file_path.lower().endswith(".pdf"):
|
| 196 |
loader = PyPDFLoader(file_path)
|
| 197 |
elif file_path.lower().endswith(".docx"):
|
| 198 |
loader = UnstructuredWordDocumentLoader(file_path)
|
| 199 |
else:
|
| 200 |
loader = TextLoader(file_path)
|
| 201 |
+
|
| 202 |
docs = loader.load()
|
| 203 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
|
| 204 |
db = FAISS.from_documents(chunks, embeddings)
|
|
|
|
| 212 |
)
|
| 213 |
return qa_temp.run(query)
|
| 214 |
|
| 215 |
+
def patched_combine_llm_outputs(self, llm_outputs):
    """Return an empty dict, ignoring ``llm_outputs`` entirely.

    Written with a ``self`` parameter so it can stand in for a method; the
    file assigns it to ``ChatOpenAI._combine_llm_outputs`` right after this
    definition.

    NOTE(review): presumably a workaround for an error raised while
    merging per-call LLM metadata -- confirm, since whatever the original
    method aggregated is no longer reported.
    """
    return {}
|
| 217 |
+
ChatOpenAI._combine_llm_outputs = patched_combine_llm_outputs
|
| 218 |
+
|
| 219 |
initial_prompt = PromptTemplate(
|
| 220 |
input_variables=["text"],
|
| 221 |
+
template="""
|
| 222 |
+
Write a concise summary of the following text:
|
| 223 |
|
| 224 |
+
"{text}"
|
| 225 |
|
| 226 |
+
CONCISE SUMMARY:
|
| 227 |
"""
|
| 228 |
)
|
| 229 |
+
|
| 230 |
refine_prompt = PromptTemplate(
|
| 231 |
input_variables=["existing_answer", "text"],
|
| 232 |
+
template="""
|
| 233 |
+
You have a summary already:
|
| 234 |
+
"{existing_answer}"
|
|
|
|
| 235 |
|
| 236 |
+
Now refine the summary with the following text:
|
| 237 |
+
"{text}"
|
| 238 |
|
| 239 |
+
REVISED SUMMARY:
|
| 240 |
"""
|
| 241 |
)
|
| 242 |
+
# Document Summarization (Refine)
|
| 243 |
def document_summarize(file):
    """Summarise an uploaded document with the refine summarisation chain.

    Args:
        file: One of
            * ``str`` -- used directly as the path;
            * ``dict`` -- the path is read from the ``"name"`` key;
            * an object with ``save`` and ``name`` -- saved into a fresh
              temporary directory first.

    Returns:
        The chain's ``output_text`` on success, or a short message when the
        upload cannot be resolved to a path.
    """
    if isinstance(file, str):
        file_path = file
    elif isinstance(file, dict):
        file_path = file.get("name")
        if file_path is None:
            return "Unable to obtain the uploaded file."
    elif hasattr(file, "save"):
        # Keep only the final path component: os.path.join() throws away the
        # temp dir when handed an absolute name, and "../" would escape it.
        file_path = os.path.join(mkdtemp(), os.path.basename(file.name))
        file.save(file_path)
    else:
        return "Unable to process the file format."

    # Choose a loader from the (case-insensitive) extension; anything that
    # is not PDF or DOCX is read as plain text.
    lowered = file_path.lower()
    if lowered.endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif lowered.endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(file_path)
    else:
        loader = TextLoader(file_path)

    docs = loader.load()
    summarize_chain = load_summarize_chain(llm_gpt4, chain_type="refine", question_prompt=initial_prompt, refine_prompt=refine_prompt)
    summary = summarize_chain.invoke(docs)
    return summary['output_text']
|
| 268 |
|
| 269 |
+
# CSVAgent: Process CSV queries for financial calculations and data analysis
|
| 270 |
def csv_agent(file, query):
    """Evaluate a pandas expression against an uploaded CSV file.

    Args:
        file: One of
            * ``str`` -- used directly as the CSV path;
            * ``dict`` -- the path is read from the ``"name"`` key;
            * an object with ``save`` and ``name`` -- saved into a fresh
              temporary directory first.
        query: A Python expression. Only ``df`` (the loaded DataFrame) and
            ``pd`` (pandas) are in scope; builtins are not available.

    Returns:
        ``str()`` of the expression's value, or a message describing why
        the file could not be read or the expression could not be
        evaluated.
    """
    if isinstance(file, str):
        file_path = file
    elif isinstance(file, dict):
        file_path = file.get("name")
        if file_path is None:
            return "Unable to obtain the uploaded CSV file."
    elif hasattr(file, "save"):
        # Keep only the final path component: os.path.join() throws away the
        # temp dir when handed an absolute name, and "../" would escape it.
        file_path = os.path.join(mkdtemp(), os.path.basename(file.name))
        file.save(file_path)
    else:
        return "Unable to process the file format."

    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        return f"Error reading CSV: {e}"

    # SECURITY(review): `query` is user-supplied and goes straight to eval().
    # Emptying __builtins__ is not a sandbox -- attribute access on `df` and
    # `pd` still reaches arbitrary objects (including file I/O helpers).
    # Replace with a restricted expression evaluator before exposing this
    # to untrusted users.
    safe_dict = {"df": df, "pd": pd}
    try:
        result = eval(query, {"__builtins__": {}}, safe_dict)
        return str(result)
    except Exception as e:
        return f"Query error: {e}"
|
| 293 |
|
| 294 |
+
# SearchAgent: Use SerpAPI to perform real external search
|
| 295 |
def search_agent(query):
|
| 296 |
api_key = os.environ.get("SERPAPI_API_KEY")
|
| 297 |
if not api_key:
|
| 298 |
return "SERPAPI_API_KEY not set. Please set the environment variable."
|
| 299 |
+
params = {
|
| 300 |
+
"engine": "google",
|
| 301 |
+
"q": query,
|
| 302 |
+
"api_key": api_key,
|
| 303 |
+
"num": 5 # Adjust number of results as needed
|
| 304 |
+
}
|
| 305 |
search = GoogleSearch(params)
|
| 306 |
results = search.get_dict()
|
| 307 |
if "organic_results" in results:
|
|
|
|
| 315 |
else:
|
| 316 |
return "No results found."
|
| 317 |
|
| 318 |
+
# CrewAI Agent: Simulate CrewAI response integrating chat history and tool routing
|
| 319 |
+
def crew_ai_agent(query):
    """Return a canned placeholder reply that echoes ``query``.

    No agent framework is invoked here; the function only formats a fixed
    string around the query text.
    """
    return f"CrewAI response for '{query}' with integrated chat history."
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# === Global State for Tab 5 ===
|
| 324 |
+
session_retriever = None
|
| 325 |
+
session_qa_chain = None
|
| 326 |
+
|
| 327 |
+
def save_file(file):
    """Resolve an uploaded file object to a path on disk.

    Args:
        file: One of
            * ``str`` -- already a path; returned unchanged;
            * ``dict`` -- the value under ``"name"`` is returned
              (``None`` when the key is absent);
            * an object with ``save`` and ``name`` -- saved into a fresh
              temporary directory, and that new path is returned.

    Returns:
        The resolved path, or ``None`` when the upload has an unsupported
        shape so the caller can report it.
    """
    if isinstance(file, str):
        return file
    if isinstance(file, dict):
        return file.get("name")
    if hasattr(file, "save"):
        # Keep only the final path component: os.path.join() throws away the
        # temp dir when handed an absolute name, and "../" would escape it.
        file_path = os.path.join(mkdtemp(), os.path.basename(file.name))
        file.save(file_path)
        return file_path
    return None
|
| 339 |
+
# === Added: Uploaded QA function and Tool (no other code changed) ===
|
| 340 |
def uploaded_qa(file, query):
|
| 341 |
+
if isinstance(file, str):
|
| 342 |
+
file_path = file
|
| 343 |
+
elif isinstance(file, dict):
|
| 344 |
+
file_path = file.get("name", None)
|
| 345 |
+
if file_path is None:
|
| 346 |
+
return "Unable to obtain the uploaded file path."
|
| 347 |
+
elif hasattr(file, "save"):
|
| 348 |
+
temp_dir = mkdtemp()
|
| 349 |
+
file_path = os.path.join(temp_dir, file.name)
|
| 350 |
+
file.save(file_path)
|
| 351 |
+
else:
|
| 352 |
+
return "Unable to process the file format."
|
| 353 |
if file_path.lower().endswith(".pdf"):
|
| 354 |
loader = PyPDFLoader(file_path)
|
| 355 |
elif file_path.lower().endswith(".docx"):
|
|
|
|
| 369 |
)
|
| 370 |
return qa_temp.run(query)
|
| 371 |
|
| 372 |
+
uploaded_qa_tool = Tool(name="uploaded_qa", func=uploaded_qa, description="Tool for answering questions based on uploaded documents without conflicting with Biden QA.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
+
# === Multi-Agent Tools and Initialization ===
|
| 375 |
+
# Tool wrappers handed to the ReAct agent below. Each Tool pairs a name and
# description (what the agent sees) with the callable it runs.
# NOTE(review): faq_tool is defined but is not included in `tools` below --
# confirm whether that omission is intentional.
faq_tool = Tool(name="faq_qa", func=qa_gpt.run, description="Document Q&A tool that answers questions based on document content.")
summarise_tool = Tool(name="summarise", func=lambda q: llm_gpt4.predict(f"Summarise this:\n{q}"), description="Tool for summarising long content.")
# SECURITY(review): this passes agent-supplied text to eval() with full
# builtins, i.e. arbitrary code execution in the app process. Replace with a
# restricted arithmetic evaluator before exposing to untrusted users.
data_tool = Tool(name="python_calc", func=lambda q: str(eval(q)), description="Execute Python-based calculations.")
# NOTE(review): csv_agent is declared as csv_agent(file, query); presumably
# the agent calls a tool with a single input, which would fail here --
# verify and wrap if so.
csv_tool = Tool(name="csv_agent", func=csv_agent, description="CSV file query and data analysis tool.")
search_tool = Tool(name="search_agent", func=search_agent, description="External search tool using SerpAPI.")
crew_tool = Tool(name="crew_ai", func=crew_ai_agent, description="CrewAI agent integrating chat history and tool routing.")

tools = [summarise_tool, data_tool, csv_tool, search_tool, crew_tool, uploaded_qa_tool]
agent_executor = initialize_agent(tools, llm=llm_gpt4, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
+
def multi_agent_chat(query):
    """Run ``query`` through the module-level ``agent_executor``.

    Any exception raised while running the agent is caught and returned as
    an ``"Error: ..."`` string instead of propagating, so the caller always
    receives text to display.
    """
    try:
        return agent_executor.run(query)
    except Exception as e:
        return f"Error: {e}"
|
| 390 |
|
| 391 |
+
# === Advanced Multi-Agent: one upload entry point; the backend picks the feature automatically ===
|
| 392 |
+
def multi_agent_chat_advanced(query, file=None):
|
| 393 |
global session_retriever, session_qa_chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
if file is not None:
|
| 396 |
+
# Step 1: 處理上傳並建立 retriever + QA chain
|
| 397 |
+
file_path = save_file(file)
|
| 398 |
if file_path is None:
|
| 399 |
+
return "無法處理該檔案格式。"
|
| 400 |
+
|
| 401 |
if file_path.lower().endswith(".csv"):
|
| 402 |
return csv_agent(file, query)
|
| 403 |
+
|
| 404 |
elif file_path.lower().endswith((".pdf", ".txt", ".docx")):
|
| 405 |
+
loader = (
|
| 406 |
+
PyPDFLoader(file_path)
|
| 407 |
+
if file_path.endswith(".pdf")
|
| 408 |
+
else UnstructuredWordDocumentLoader(file_path)
|
| 409 |
+
if file_path.endswith(".docx")
|
| 410 |
+
else TextLoader(file_path)
|
| 411 |
+
)
|
| 412 |
docs = loader.load()
|
| 413 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
|
| 414 |
db = FAISS.from_documents(chunks, embeddings)
|
|
|
|
| 419 |
memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
|
| 420 |
combine_docs_chain_kwargs={"prompt": custom_prompt}
|
| 421 |
)
|
| 422 |
+
|
| 423 |
+
# 如果 query 是摘要
|
| 424 |
+
if any(kw in query.lower() for kw in ["summarize", "summary", "摘要", "總結"]):
|
| 425 |
+
return document_summarize(file)
|
| 426 |
else:
|
| 427 |
+
return session_qa_chain.run(query)
|
|
|
|
|
|
|
|
|
|
| 428 |
else:
|
| 429 |
+
return "不支援的檔案格式。"
|
| 430 |
+
|
| 431 |
+
# Step 2: 已經上傳過檔案 → 使用現成的 retriever
|
| 432 |
elif session_qa_chain is not None:
|
| 433 |
+
return session_qa_chain.run(query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
else:
|
| 435 |
+
return "請先上傳檔案,再進行提問。"
|
| 436 |
+
|
| 437 |
+
# === Gradio Interface Settings (All texts in English) ===
|
|
|
|
| 438 |
|
|
|
|
| 439 |
demo_description = """
|
| 440 |
**Context**:
|
| 441 |
This demo uses a Retrieval-Augmented Generation (RAG) system based on
|
|
|
|
| 454 |
|
| 455 |
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
|
| 456 |
"""
|
| 457 |
+
|
| 458 |
demo_description2 = """
|
| 459 |
**Context**:
|
| 460 |
This demo uses a Retrieval-Augmented Generation (RAG) system based on
|
|
|
|
| 473 |
|
| 474 |
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
|
| 475 |
"""
|
| 476 |
+
|
| 477 |
demo_description3 = """
|
| 478 |
**Context**:
|
| 479 |
+
This demo uses a Retrieval-Augmented Generation (RAG) system based on
|
| 480 |
+
your uploaded document.
|
| 481 |
+
All responses are grounded in this document.
|
| 482 |
+
If no relevant information is found in the document, the system will say "No relevant info found."
|
| 483 |
+
|
| 484 |
+
*Note: The GPT module supports follow-up questions with conversation history management, enabling more interactive and context-aware discussions.*
|
| 485 |
|
| 486 |
Feel free to ask any question related to your document.
|
| 487 |
"""
|
| 488 |
+
|
| 489 |
demo_description4 = """
|
| 490 |
**Context**:
|
| 491 |
This assistant performs multi-agent tasks using tools such as:
|
|
|
|
| 494 |
- Financial or CSV-style logic queries
|
| 495 |
- Multi-step reasoning via agent orchestration
|
| 496 |
|
| 497 |
+
系統會根據上傳檔案(CSV、PDF、TXT、DOCX)以及查詢內容,自動選擇呼叫適當的功能,
|
| 498 |
+
例如:若查詢中含有 "summarize"/"摘要",則對文件進行摘要;若是 CSV 則進行資料查詢等。
|
| 499 |
"""
|
| 500 |
+
|
| 501 |
demo_description5 = """
|
| 502 |
**Context**:
|
| 503 |
This demo uses Document Summarization via a Map-Reduce chain.
|