Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,11 +6,12 @@ import torch
|
|
| 6 |
import re
|
| 7 |
import transformers
|
| 8 |
import chardet
|
| 9 |
-
|
| 10 |
-
from
|
| 11 |
-
from
|
|
|
|
| 12 |
|
| 13 |
-
# ---
|
| 14 |
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
|
| 15 |
|
| 16 |
# === Enterprise Environment Settings ===
|
|
@@ -25,36 +26,34 @@ os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/sentence_transformers"
|
|
| 25 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hf_cache"
|
| 26 |
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
|
| 27 |
|
| 28 |
-
# ===
|
| 29 |
-
from
|
| 30 |
-
from
|
| 31 |
from langchain.chains import RetrievalQA
|
| 32 |
from langchain.prompts import PromptTemplate
|
| 33 |
-
from
|
| 34 |
-
from
|
| 35 |
from langchain.chains import ConversationalRetrievalChain
|
| 36 |
from langchain.memory import ConversationBufferMemory
|
| 37 |
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader
|
| 38 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 39 |
from langchain.chains.summarize import load_summarize_chain
|
|
|
|
| 40 |
from langchain.schema import AIMessage
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
# ===
|
| 43 |
from serpapi import GoogleSearch
|
| 44 |
-
|
| 45 |
-
# CrewAI Imports
|
| 46 |
from crewai import Crew, Agent, Task, Process
|
| 47 |
from crewai.tools import tool
|
| 48 |
from langchain_experimental.agents import create_pandas_dataframe_agent
|
| 49 |
-
from pydantic import BaseModel
|
| 50 |
-
|
| 51 |
-
import numexpr as ne
|
| 52 |
|
| 53 |
-
# Global variables for non‑math functions (unchanged)
|
| 54 |
session_retriever = None
|
| 55 |
session_qa_chain = None
|
| 56 |
csv_dataframe = None # CSV tool will use this
|
| 57 |
-
|
| 58 |
# === Safe Result Formatter ===
|
| 59 |
def safe_format_result(result) -> str:
|
| 60 |
try:
|
|
@@ -70,7 +69,6 @@ def safe_format_result(result) -> str:
|
|
| 70 |
return str(result)
|
| 71 |
except Exception as e:
|
| 72 |
return f"Error formatting result: {e}"
|
| 73 |
-
|
| 74 |
# === Model and Device Setup ===
|
| 75 |
if torch.backends.mps.is_available():
|
| 76 |
device = "mps"
|
|
@@ -174,11 +172,12 @@ qa_gpt = ConversationalRetrievalChain.from_llm(
|
|
| 174 |
combine_docs_chain_kwargs={"prompt": custom_prompt}
|
| 175 |
)
|
| 176 |
|
| 177 |
-
# === Helper Function
|
| 178 |
def get_file_path(file):
|
| 179 |
if isinstance(file, str):
|
| 180 |
return file
|
| 181 |
elif isinstance(file, dict):
|
|
|
|
| 182 |
return file.get("data", file.get("name", None))
|
| 183 |
elif hasattr(file, "save"):
|
| 184 |
temp_dir = mkdtemp()
|
|
@@ -188,7 +187,7 @@ def get_file_path(file):
|
|
| 188 |
else:
|
| 189 |
return None
|
| 190 |
|
| 191 |
-
# ===
|
| 192 |
def rag_llama_qa(query):
|
| 193 |
output = RetrievalQA.from_chain_type(
|
| 194 |
llm=llm_local,
|
|
@@ -264,6 +263,7 @@ def document_summarize(file):
|
|
| 264 |
summary = summarize_chain.invoke(docs)
|
| 265 |
return summary['output_text']
|
| 266 |
|
|
|
|
| 267 |
def csv_agent(file, query):
|
| 268 |
file_path = get_file_path(file)
|
| 269 |
if file_path is None:
|
|
@@ -323,130 +323,418 @@ def uploaded_qa(file, query):
|
|
| 323 |
)
|
| 324 |
return qa_temp.run(query)
|
| 325 |
|
| 326 |
-
# ===
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
pattern = r"(\d+(?:\.\d+)?)%\s+of\s+(\d+(?:\.\d+)?)"
|
| 332 |
-
match = re.search(pattern, query)
|
| 333 |
-
if match:
|
| 334 |
-
perc = float(match.group(1)) / 100.0
|
| 335 |
-
number = match.group(2)
|
| 336 |
-
return f"{perc} * {number}"
|
| 337 |
-
return query
|
| 338 |
-
|
| 339 |
-
def convert_trig(expr: str) -> str:
|
| 340 |
-
"""
|
| 341 |
-
Convert sine functions from degrees to radians.
|
| 342 |
-
For example: sin(45) -> sin(pi*45/180)
|
| 343 |
-
"""
|
| 344 |
-
pattern = r"sin\(\s*(\d+(?:\.\d+)?)\s*\)"
|
| 345 |
-
return re.sub(pattern, lambda m: f"sin(pi*{m.group(1)}/180)", expr)
|
| 346 |
-
|
| 347 |
-
def convert_nl_to_expr(query: str) -> str:
|
| 348 |
-
"""
|
| 349 |
-
Use GPT-4 to convert a natural language math question into a valid mathematical expression.
|
| 350 |
-
"""
|
| 351 |
-
prompt = (
|
| 352 |
-
"Convert the following natural language math question into a valid mathematical expression "
|
| 353 |
-
"using standard arithmetic symbols. Output only the expression.\n\n"
|
| 354 |
-
"Question: " + query
|
| 355 |
-
)
|
| 356 |
-
response = llm_gpt4.invoke(prompt)
|
| 357 |
-
if isinstance(response, AIMessage):
|
| 358 |
-
response = response.content
|
| 359 |
-
return response.strip()
|
| 360 |
-
|
| 361 |
-
def generate_nl_explanation(expression: str, result: str) -> str:
|
| 362 |
-
"""
|
| 363 |
-
Use GPT-4 to generate a natural language explanation for the evaluated expression.
|
| 364 |
-
"""
|
| 365 |
-
prompt = (
|
| 366 |
-
f"Explain in a clear, human-friendly sentence the result of the expression '{expression}' which evaluates to '{result}'."
|
| 367 |
-
)
|
| 368 |
-
explanation = llm_gpt4.invoke(prompt)
|
| 369 |
-
if isinstance(explanation, AIMessage):
|
| 370 |
-
explanation = explanation.content
|
| 371 |
-
return explanation.strip()
|
| 372 |
|
| 373 |
-
def
|
| 374 |
try:
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
if "
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
expression = convert_trig(expression)
|
| 382 |
-
# Step 4: Evaluate the expression using numexpr
|
| 383 |
-
evaluated_result = ne.evaluate(expression)
|
| 384 |
-
# Step 5: Generate natural language explanation
|
| 385 |
-
nl_explanation = generate_nl_explanation(expression, str(evaluated_result))
|
| 386 |
-
return nl_explanation
|
| 387 |
except Exception as e:
|
| 388 |
-
return f"
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
math_agent = Agent(
|
| 404 |
role="Math Assistant",
|
| 405 |
-
goal="Perform
|
| 406 |
-
backstory="You
|
| 407 |
-
tools=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
verbose=True
|
| 409 |
)
|
| 410 |
-
|
| 411 |
router_agent = Agent(
|
| 412 |
role="Query Router",
|
| 413 |
-
goal="
|
| 414 |
-
backstory="You
|
| 415 |
-
tools=[
|
| 416 |
verbose=True
|
| 417 |
)
|
| 418 |
-
|
| 419 |
router_task = Task(
|
| 420 |
description="""
|
| 421 |
-
Based on the user's query,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
""",
|
| 423 |
-
expected_output="The final answer from the selected agent.",
|
| 424 |
agent=router_agent,
|
| 425 |
input_variables=["query"]
|
| 426 |
)
|
| 427 |
|
| 428 |
-
crew_llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0.2, openai_api_key=openai_api_key)
|
| 429 |
-
|
| 430 |
crew = Crew(
|
| 431 |
-
agents=[math_agent],
|
| 432 |
tasks=[router_task],
|
| 433 |
process=Process.sequential,
|
| 434 |
verbose=True,
|
| 435 |
llm=crew_llm
|
| 436 |
)
|
| 437 |
|
| 438 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
def multi_agent_chat_advanced(query: str, file=None) -> str:
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
)
|
| 451 |
|
| 452 |
if __name__ == "__main__":
|
|
|
|
| 6 |
import re
|
| 7 |
import transformers
|
| 8 |
import chardet
|
| 9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 10 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
import gradio as gr
|
| 13 |
|
| 14 |
+
# --- 解決 Matplotlib 權限問題 ---
|
| 15 |
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
|
| 16 |
|
| 17 |
# === Enterprise Environment Settings ===
|
|
|
|
| 26 |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hf_cache"
|
| 27 |
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
|
| 28 |
|
| 29 |
+
# === Load Required Modules ===
|
| 30 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 31 |
+
from langchain.vectorstores import Chroma, FAISS
|
| 32 |
from langchain.chains import RetrievalQA
|
| 33 |
from langchain.prompts import PromptTemplate
|
| 34 |
+
from langchain.llms import HuggingFacePipeline
|
| 35 |
+
from langchain.chat_models import ChatOpenAI
|
| 36 |
from langchain.chains import ConversationalRetrievalChain
|
| 37 |
from langchain.memory import ConversationBufferMemory
|
| 38 |
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader
|
| 39 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 40 |
from langchain.chains.summarize import load_summarize_chain
|
| 41 |
+
from tempfile import mkdtemp
|
| 42 |
from langchain.schema import AIMessage
|
| 43 |
+
from datetime import datetime
|
| 44 |
+
import numexpr as ne
|
| 45 |
+
import pandas as pd
|
| 46 |
|
| 47 |
+
# === Multi-Agent Imports ===
|
| 48 |
from serpapi import GoogleSearch
|
| 49 |
+
# CrewAI 部分:完全使用 CrewAI 的 Agent、Task、Crew 與 @tool 裝飾器
|
|
|
|
| 50 |
from crewai import Crew, Agent, Task, Process
|
| 51 |
from crewai.tools import tool
|
| 52 |
from langchain_experimental.agents import create_pandas_dataframe_agent
|
|
|
|
|
|
|
|
|
|
| 53 |
|
|
|
|
| 54 |
session_retriever = None
|
| 55 |
session_qa_chain = None
|
| 56 |
csv_dataframe = None # CSV tool will use this
|
|
|
|
| 57 |
# === Safe Result Formatter ===
|
| 58 |
def safe_format_result(result) -> str:
|
| 59 |
try:
|
|
|
|
| 69 |
return str(result)
|
| 70 |
except Exception as e:
|
| 71 |
return f"Error formatting result: {e}"
|
|
|
|
| 72 |
# === Model and Device Setup ===
|
| 73 |
if torch.backends.mps.is_available():
|
| 74 |
device = "mps"
|
|
|
|
| 172 |
combine_docs_chain_kwargs={"prompt": custom_prompt}
|
| 173 |
)
|
| 174 |
|
| 175 |
+
# === Helper Function:從上傳文件中提取檔案路徑 ===
|
| 176 |
def get_file_path(file):
|
| 177 |
if isinstance(file, str):
|
| 178 |
return file
|
| 179 |
elif isinstance(file, dict):
|
| 180 |
+
# 優先使用 "data" 鍵,其次是 "name"
|
| 181 |
return file.get("data", file.get("name", None))
|
| 182 |
elif hasattr(file, "save"):
|
| 183 |
temp_dir = mkdtemp()
|
|
|
|
| 187 |
else:
|
| 188 |
return None
|
| 189 |
|
| 190 |
+
# === 原有功能(Tab 1~Tab 4)函式 ===
|
| 191 |
def rag_llama_qa(query):
|
| 192 |
output = RetrievalQA.from_chain_type(
|
| 193 |
llm=llm_local,
|
|
|
|
| 263 |
summary = summarize_chain.invoke(docs)
|
| 264 |
return summary['output_text']
|
| 265 |
|
| 266 |
+
|
| 267 |
def csv_agent(file, query):
|
| 268 |
file_path = get_file_path(file)
|
| 269 |
if file_path is None:
|
|
|
|
| 323 |
)
|
| 324 |
return qa_temp.run(query)
|
| 325 |
|
| 326 |
+
# === CrewAI Multi-Agent System (Tab 5) ===
|
| 327 |
+
# 完全捨棄 langchain.agents.Tool,使用 CrewAI 的 @tool 裝飾器來定義工具
|
| 328 |
+
from pydantic import BaseModel
|
| 329 |
+
class SimpleQuery(BaseModel):
|
| 330 |
+
query: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
+
def _general_chat(query: str) -> str:
|
| 333 |
try:
|
| 334 |
+
response = llm_gpt4.invoke(query)
|
| 335 |
+
if isinstance(response, AIMessage):
|
| 336 |
+
response = response.content # ✅ 取出真實字串
|
| 337 |
+
if any(kw in response.lower() for kw in ["i'm not sure", "i don't know", "no information", "can't find"]):
|
| 338 |
+
return _search_web_tool(query)
|
| 339 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
except Exception as e:
|
| 341 |
+
return f"General chat error: {e}"
|
| 342 |
+
@tool("general_chat")
|
| 343 |
+
def general_chat_tool(query: str) -> str:
|
| 344 |
+
"""General assistant: Answer general questions without relying on documents."""
|
| 345 |
+
try:
|
| 346 |
+
response = llm_gpt4.invoke(query)
|
| 347 |
+
if isinstance(response, AIMessage):
|
| 348 |
+
response = response.content # ✅ 取出真實字串
|
| 349 |
+
if any(kw in response.lower() for kw in ["i'm not sure", "i don't know", "no information", "can't find"]):
|
| 350 |
+
return search_web(query)
|
| 351 |
+
return response
|
| 352 |
+
except Exception as e:
|
| 353 |
+
return f"General chat error: {e}"
|
| 354 |
+
|
| 355 |
+
def get_time_tool(query: str) -> str:
|
| 356 |
+
now = datetime.now()
|
| 357 |
+
return now.strftime("The current local time is %I:%M %p on %B %d, %Y.")
|
| 358 |
+
|
| 359 |
+
@tool("summarise")
|
| 360 |
+
def summarise_tool(query: str) -> str:
|
| 361 |
+
"""Summarise: Use document summarisation functionality."""
|
| 362 |
+
global session_retriever, session_qa_chain
|
| 363 |
+
if session_retriever is None:
|
| 364 |
+
return "尚未上傳文件。"
|
| 365 |
+
try:
|
| 366 |
+
docs = session_retriever.get_relevant_documents(query if query.strip() else "summary")
|
| 367 |
+
if not docs:
|
| 368 |
+
return "無法從文件中找到相關內容。"
|
| 369 |
+
summarize_chain = load_summarize_chain(llm_gpt4, chain_type="refine", question_prompt=initial_prompt, refine_prompt=refine_prompt)
|
| 370 |
+
summary = summarize_chain.invoke(docs)
|
| 371 |
+
return summary['output_text']
|
| 372 |
+
except Exception as e:
|
| 373 |
+
return f"摘要錯誤: {e}"
|
| 374 |
+
|
| 375 |
+
def _calc_tool(query: str) -> str:
|
| 376 |
+
# 判斷是否是明確的數學算式(只包含數學相關符號)
|
| 377 |
+
if re.match(r"^[a-zA-Z\d\s\+\-\*\/\%\(\)\.]+$", query.strip()):
|
| 378 |
+
try:
|
| 379 |
+
result = ne.evaluate(query)
|
| 380 |
+
return f"The result of the calculation is: {result}"
|
| 381 |
+
except Exception as e:
|
| 382 |
+
return f"Error in formula calculation: {e}"
|
| 383 |
+
# 否則交給 GPT-4o 處理自然語言的問題 + 生成自然語言回答
|
| 384 |
+
try:
|
| 385 |
+
response = llm_gpt4.invoke(query)
|
| 386 |
+
if isinstance(response, AIMessage):
|
| 387 |
+
response = response.content
|
| 388 |
+
return response
|
| 389 |
+
except Exception as e:
|
| 390 |
+
return f"Natural language math error: {e}"
|
| 391 |
+
|
| 392 |
+
@tool("python_calc")
|
| 393 |
+
def python_calc_tool(query: str) -> str:
|
| 394 |
+
"""Python Calculation: Perform basic arithmetic or logical operations."""
|
| 395 |
+
try:
|
| 396 |
+
result = ne.evaluate(query)
|
| 397 |
+
return str(result)
|
| 398 |
+
except Exception as e:
|
| 399 |
+
return f"計算錯誤: {e}"
|
| 400 |
+
def _search_web_tool(query: str) -> str:
|
| 401 |
+
return search_web(query)
|
| 402 |
+
@tool("search_tool")
|
| 403 |
+
def search_tool_func(query: str) -> str:
|
| 404 |
+
"""Search: Perform web searches using external search engines."""
|
| 405 |
+
return search_web(query)
|
| 406 |
+
|
| 407 |
+
@tool("uploaded_qa")
|
| 408 |
+
def uploaded_qa_tool_func(query: str) -> str:
|
| 409 |
+
"""Document QA: Answer questions based on the uploaded document content."""
|
| 410 |
+
global session_qa_chain
|
| 411 |
+
if session_qa_chain is not None:
|
| 412 |
+
try:
|
| 413 |
+
return session_qa_chain.run(query)
|
| 414 |
+
except Exception as e:
|
| 415 |
+
return f"文檔問答錯誤: {e}"
|
| 416 |
+
else:
|
| 417 |
+
return "尚未上傳文件。"
|
| 418 |
+
|
| 419 |
+
@tool("csv_agent")
|
| 420 |
+
def csv_tool_func(query: str) -> str:
|
| 421 |
+
"""CSV Agent: Use natural language to analyse uploaded CSV files."""
|
| 422 |
+
global csv_dataframe
|
| 423 |
+
if csv_dataframe is None:
|
| 424 |
+
return "No CSV file uploaded."
|
| 425 |
+
try:
|
| 426 |
+
agent = create_pandas_dataframe_agent(llm=llm_gpt4, df=csv_dataframe, verbose=True)
|
| 427 |
+
return agent.run(f"Here is the table:\n{csv_dataframe.head().to_string(index=False)}\n\n{query}")
|
| 428 |
+
except Exception as e:
|
| 429 |
+
return f"CSV Agent error: {e}"
|
| 430 |
+
|
| 431 |
+
# 建立 CrewAI 代理(僅針對 Tab 5)
|
| 432 |
+
general_agent = Agent(
|
| 433 |
+
role="General Assistant",
|
| 434 |
+
goal="Respond to any general query that is not related to documents or CSV files.",
|
| 435 |
+
backstory="You're an intelligent assistant who answers questions about anything general, such as math, dates, or general knowledge.",
|
| 436 |
+
tools=[general_chat_tool],
|
| 437 |
+
verbose=True
|
| 438 |
+
)
|
| 439 |
+
summarizer_agent = Agent(
|
| 440 |
+
role="Document Summarizer",
|
| 441 |
+
goal="Summarise the content of the uploaded document.",
|
| 442 |
+
backstory="You are a professional summarisation expert who can identify key points in long documents.",
|
| 443 |
+
tools=[summarise_tool],
|
| 444 |
+
verbose=True
|
| 445 |
+
)
|
| 446 |
+
document_qa_agent = Agent(
|
| 447 |
+
role="Document QA Specialist",
|
| 448 |
+
goal="Answer questions based on the uploaded document.",
|
| 449 |
+
backstory="You are an expert in document understanding and can accurately extract answers.",
|
| 450 |
+
tools=[uploaded_qa_tool_func],
|
| 451 |
+
verbose=True
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
search_agent = Agent(
|
| 455 |
+
role="Search Expert",
|
| 456 |
+
goal="Search the web and provide relevant information.",
|
| 457 |
+
backstory="You are an expert at finding relevant information from the internet.",
|
| 458 |
+
tools=[search_tool_func],
|
| 459 |
+
verbose=True
|
| 460 |
+
)
|
| 461 |
|
| 462 |
math_agent = Agent(
|
| 463 |
role="Math Assistant",
|
| 464 |
+
goal="Perform accurate arithmetic or logical calculations.",
|
| 465 |
+
backstory="You are a calculator expert skilled at quick computations.",
|
| 466 |
+
tools=[python_calc_tool],
|
| 467 |
+
verbose=True
|
| 468 |
+
)
|
| 469 |
+
csv_agent = Agent(
|
| 470 |
+
role="CSV Analyst",
|
| 471 |
+
goal="Analyse tabular data and answer questions about the uploaded CSV file.",
|
| 472 |
+
backstory="You are skilled in interpreting tabular datasets and can extract numerical or logical insights.",
|
| 473 |
+
tools=[csv_tool_func],
|
| 474 |
verbose=True
|
| 475 |
)
|
|
|
|
| 476 |
router_agent = Agent(
|
| 477 |
role="Query Router",
|
| 478 |
+
goal="Determine the most suitable agent or tool to handle the user query.",
|
| 479 |
+
backstory="You are an intelligent query dispatcher that analyses the user's intent and chooses the best AI agent to answer.",
|
| 480 |
+
tools=[python_calc_tool, search_tool_func, csv_tool_func, uploaded_qa_tool_func, summarise_tool, general_chat_tool],
|
| 481 |
verbose=True
|
| 482 |
)
|
|
|
|
| 483 |
router_task = Task(
|
| 484 |
description="""
|
| 485 |
+
Based on the user's query, decide which agent or tool is best suited to handle it:
|
| 486 |
+
- If the query is related to the content of an uploaded file (e.g., 'what is this document about?'), send it to the **Document QA Agent**.
|
| 487 |
+
- If the query contains words like 'summarize', 'summary', or 'main points', use the **Summarizer Agent**.
|
| 488 |
+
- If the query involves numbers, calculations, or logic (e.g., '50 * 23 - 5', 'what is 10% of 800'), send it to the **Math Agent**.
|
| 489 |
+
- If the user uploaded a CSV file and asks about table content, data trends, or uses words like 'data', 'table', 'csv', 'column', or 'row', send it to the **CSV Agent**.
|
| 490 |
+
- If the user asks about current events, trending topics, or online information (e.g., 'What is LangChain?', 'latest news'), send it to the **Search Agent**.
|
| 491 |
+
- If the query is about current date, time, or day of week (e.g., 'what is today\\'s date?', 'what time is it?', 'what day is it?'), send it to the **Search Agent** instead of the General Agent.
|
| 492 |
+
- If the question is general and not related to documents, calculations, CSVs, or the internet (e.g., 'Who are you?', 'Tell me a fun fact'), send it to the **General Agent**.
|
| 493 |
+
- If none of these apply, use your best judgment to choose the most relevant agent.
|
| 494 |
""",
|
| 495 |
+
expected_output="The final answer from the selected agent or tool.",
|
| 496 |
agent=router_agent,
|
| 497 |
input_variables=["query"]
|
| 498 |
)
|
| 499 |
|
|
|
|
|
|
|
| 500 |
crew = Crew(
|
| 501 |
+
agents=[general_agent, summarizer_agent, document_qa_agent, search_agent, math_agent, csv_agent],
|
| 502 |
tasks=[router_task],
|
| 503 |
process=Process.sequential,
|
| 504 |
verbose=True,
|
| 505 |
llm=crew_llm
|
| 506 |
)
|
| 507 |
|
| 508 |
+
#def multi_agent_chat(query: str) -> str:
|
| 509 |
+
# print(f"Routing query: {query}")
|
| 510 |
+
# try:
|
| 511 |
+
# 🔧 補這段,用來轉換 upload file 語意(擺在 kickoff 前)
|
| 512 |
+
# if file is not None and ("upload file" in query.lower() or "the file" in query.lower()):
|
| 513 |
+
# query = f"The user uploaded the following file:\n\n{query}"
|
| 514 |
+
|
| 515 |
+
# result = crew.kickoff(inputs={"query": query})
|
| 516 |
+
# result_str = safe_format_result(result)
|
| 517 |
+
# if any(kw in result_str.lower() for kw in ["i don't know", "no relevant", "no information", "no answer"]) or result_str.strip() == "":
|
| 518 |
+
# try:
|
| 519 |
+
# result = search_agent.execute(query)
|
| 520 |
+
# return safe_format_result(result)
|
| 521 |
+
# except Exception as e:
|
| 522 |
+
# return f"[Fallback Search Error] {e}"
|
| 523 |
+
# return result_str
|
| 524 |
+
# except Exception as e:
|
| 525 |
+
# return f"Error: {e}"
|
| 526 |
+
|
| 527 |
def multi_agent_chat_advanced(query: str, file=None) -> str:
|
| 528 |
+
global session_retriever, session_qa_chain, csv_dataframe
|
| 529 |
+
|
| 530 |
+
# === Step 0: Smart routing without needing uploaded files ===
|
| 531 |
+
lower_query = query.lower()
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
if re.match(r"^[a-zA-Z\d\s\+\-\*\/\%\(\)\.]+$", lower_query):
|
| 535 |
+
return _calc_tool(query)
|
| 536 |
+
|
| 537 |
+
date_keywords = ["what date", "today", "what time", "what day", "現在幾點", "今天幾號", "禮拜幾"]
|
| 538 |
+
if any(k in lower_query for k in date_keywords):
|
| 539 |
+
return get_time_tool(query)
|
| 540 |
+
|
| 541 |
+
general_keywords = ["who are you", "tell me", "what is your name", "what can you do", "fun fact"]
|
| 542 |
+
if any(k in lower_query for k in general_keywords):
|
| 543 |
+
return _general_chat(query)
|
| 544 |
+
|
| 545 |
+
# === Step 1: 檢查 file 是否存在與格式 ===
|
| 546 |
+
file_path = get_file_path(file) if file is not None else None
|
| 547 |
+
|
| 548 |
+
# === Step 2: 判斷 Query 是否需要走文件處理邏輯 ===
|
| 549 |
+
non_doc_keywords = ["calculate", "sum", "date", "time", "how many", "how much", "weather", "temperature"]
|
| 550 |
+
use_file_chain = not any(kw in query.lower() for kw in non_doc_keywords)
|
| 551 |
+
|
| 552 |
+
# === Step 3: 有上傳檔案 ===
|
| 553 |
+
if file_path:
|
| 554 |
+
file_lower = file_path.lower()
|
| 555 |
+
|
| 556 |
+
# === 3-1: 處理 CSV ===
|
| 557 |
+
if file_lower.endswith(".csv"):
|
| 558 |
+
try:
|
| 559 |
+
with open(file_path, 'rb') as f:
|
| 560 |
+
result = chardet.detect(f.read())
|
| 561 |
+
encoding = result['encoding']
|
| 562 |
+
df = pd.read_csv(file_path, encoding=encoding)
|
| 563 |
+
csv_dataframe = df # 👈 保證 global 賦值成功
|
| 564 |
+
|
| 565 |
+
# 若 query 有提到 file,加 context
|
| 566 |
+
if "file" in query.lower() or "upload" in query.lower():
|
| 567 |
+
query = f"The user uploaded the following CSV file:\n\n{query}"
|
| 568 |
+
|
| 569 |
+
result = crew.kickoff(inputs={"query": query})
|
| 570 |
+
return safe_format_result(result)
|
| 571 |
+
except Exception as e:
|
| 572 |
+
return f"CSV Parsing Error: {e}"
|
| 573 |
+
|
| 574 |
+
# === 3-2: 處理 PDF / DOCX / TXT ===
|
| 575 |
+
elif file_lower.endswith((".pdf", ".txt", ".docx")):
|
| 576 |
+
try:
|
| 577 |
+
loader = (
|
| 578 |
+
PyPDFLoader(file_path) if file_lower.endswith(".pdf")
|
| 579 |
+
else UnstructuredWordDocumentLoader(file_path) if file_lower.endswith(".docx")
|
| 580 |
+
else TextLoader(file_path)
|
| 581 |
+
)
|
| 582 |
+
docs = loader.load()
|
| 583 |
+
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
|
| 584 |
+
db = FAISS.from_documents(chunks, embeddings)
|
| 585 |
+
session_retriever = db.as_retriever()
|
| 586 |
+
session_qa_chain = ConversationalRetrievalChain.from_llm(
|
| 587 |
+
llm=llm_gpt4,
|
| 588 |
+
retriever=session_retriever,
|
| 589 |
+
memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
# 若是摘要類 Query,先用 Summarize Chain 回傳
|
| 593 |
+
if any(kw in query.lower() for kw in ["summarize", "summary", "摘要", "總結"]):
|
| 594 |
+
return document_summarize(file_path)
|
| 595 |
+
|
| 596 |
+
# 若判斷需用 QA Chain
|
| 597 |
+
if use_file_chain:
|
| 598 |
+
try:
|
| 599 |
+
return session_qa_chain.run(query)
|
| 600 |
+
except Exception as e:
|
| 601 |
+
return f"Document QA Error: {e}"
|
| 602 |
+
|
| 603 |
+
# 否則進入 Multi-Agent 推理
|
| 604 |
+
if "file" in query.lower() or "upload" in query.lower():
|
| 605 |
+
query = f"The user uploaded the following document:\n\n{query}"
|
| 606 |
+
|
| 607 |
+
result = crew.kickoff(inputs={"query": query})
|
| 608 |
+
return safe_format_result(result)
|
| 609 |
+
|
| 610 |
+
except Exception as e:
|
| 611 |
+
return f"Document Processing Error: {e}"
|
| 612 |
+
|
| 613 |
+
else:
|
| 614 |
+
return "Unsupported file format."
|
| 615 |
+
|
| 616 |
+
# === Step 4: 沒有上傳檔案,直接用 CrewAI 推理 ===
|
| 617 |
+
try:
|
| 618 |
+
result = crew.kickoff(inputs={"query": query})
|
| 619 |
+
return safe_format_result(result)
|
| 620 |
+
except Exception as e:
|
| 621 |
+
return f"Multi-Agent Error: {e}"
|
| 622 |
+
|
| 623 |
+
# === Gradio Interface Settings ===
|
| 624 |
+
demo_description = """
|
| 625 |
+
**Context**:
|
| 626 |
+
This demo uses a Retrieval-Augmented Generation (RAG) system based on
|
| 627 |
+
Biden’s 2023 State of the Union Address.
|
| 628 |
+
All responses are grounded in this document.
|
| 629 |
+
If no relevant information is found in the document, the system will say "No relevant info found."
|
| 630 |
+
|
| 631 |
+
**Sample Questions**:
|
| 632 |
+
1. What were the main topics regarding infrastructure in this speech?
|
| 633 |
+
2. How does the speech address the competition with China?
|
| 634 |
+
3. What does Biden say about job growth in the past two years?
|
| 635 |
+
4. Does the speech mention anything about Social Security or Medicare?
|
| 636 |
+
5. What does the speech propose regarding Big Tech or online privacy?
|
| 637 |
+
|
| 638 |
+
*Note: The LLaMA module generates responses based solely on the current query without follow-up memory or chat history management.*
|
| 639 |
+
|
| 640 |
+
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
|
| 641 |
+
"""
|
| 642 |
+
demo_description2 = """
|
| 643 |
+
**Context**:
|
| 644 |
+
This demo uses a Retrieval-Augmented Generation (RAG) system based on
|
| 645 |
+
Biden’s 2023 State of the Union Address.
|
| 646 |
+
All responses are grounded in this document.
|
| 647 |
+
If no relevant information is found in the document, the system will say "No relevant info found."
|
| 648 |
+
|
| 649 |
+
**Sample Questions**:
|
| 650 |
+
1. What were the main topics regarding infrastructure in this speech?
|
| 651 |
+
2. How does the speech address the competition with China?
|
| 652 |
+
3. What does Biden say about job growth in the past two years?
|
| 653 |
+
4. Does the speech mention anything about Social Security or Medicare?
|
| 654 |
+
5. What does the speech propose regarding Big Tech or online privacy?
|
| 655 |
+
|
| 656 |
+
*Note: The GPT module supports follow-up questions with conversation history management, enabling more interactive and context-aware discussions.*
|
| 657 |
+
|
| 658 |
+
Feel free to ask any question related to Biden’s 2023 State of the Union Address.
|
| 659 |
+
"""
|
| 660 |
+
demo_description3 = """
|
| 661 |
+
**Context**:
|
| 662 |
+
Upload a PDF, TXT, or DOCX file and ask a question about its content.
|
| 663 |
+
This demo uses GPT-4 to answer questions based on the content of your uploaded document.
|
| 664 |
+
|
| 665 |
+
Feel free to ask any question related to your document.
|
| 666 |
+
"""
|
| 667 |
+
demo_description4 = """
|
| 668 |
+
**Context**:
|
| 669 |
+
This assistant performs multi-agent tasks using tools such as:
|
| 670 |
+
- Document summarisation
|
| 671 |
+
- FAQ-style document Q&A
|
| 672 |
+
- Financial or CSV-style logic queries
|
| 673 |
+
- Multi-step reasoning via agent orchestration
|
| 674 |
+
|
| 675 |
+
The system will automatically select the appropriate function based on the uploaded file (CSV, PDF, TXT, DOCX) and the query content.
|
| 676 |
+
For example, if the query contains "summarize"/"摘要", it will summarize the document; if it's CSV data, it will perform data analysis.
|
| 677 |
+
"""
|
| 678 |
+
demo_description5 = """
|
| 679 |
+
**Context**:
|
| 680 |
+
This demo uses Document Summarization via a Map-Reduce chain.
|
| 681 |
+
Upload a PDF, TXT, or DOCX file to get an automatic summary of its contents.
|
| 682 |
+
"""
|
| 683 |
+
|
| 684 |
+
demo = gr.TabbedInterface(
|
| 685 |
+
interface_list=[
|
| 686 |
+
gr.Interface(
|
| 687 |
+
fn=rag_llama_qa,
|
| 688 |
+
inputs="text",
|
| 689 |
+
outputs="text",
|
| 690 |
+
title="Biden Q&A (LLaMA)",
|
| 691 |
+
allow_flagging="never",
|
| 692 |
+
description=demo_description
|
| 693 |
+
),
|
| 694 |
+
gr.Interface(
|
| 695 |
+
fn=rag_gpt4_qa,
|
| 696 |
+
inputs="text",
|
| 697 |
+
outputs="text",
|
| 698 |
+
title="Biden Q&A (GPT-4)",
|
| 699 |
+
allow_flagging="never",
|
| 700 |
+
description=demo_description2
|
| 701 |
+
),
|
| 702 |
+
gr.Interface(
|
| 703 |
+
fn=upload_and_chat,
|
| 704 |
+
inputs=[gr.File(label="Upload PDF, TXT, or DOCX"), gr.Textbox(label="Ask a question")],
|
| 705 |
+
outputs="text",
|
| 706 |
+
title="Your Docs Q&A (Upload + GPT-4)",
|
| 707 |
+
allow_flagging="never",
|
| 708 |
+
description=demo_description3
|
| 709 |
+
),
|
| 710 |
+
gr.Interface(
|
| 711 |
+
fn=document_summarize,
|
| 712 |
+
inputs=[gr.File(label="Upload PDF, TXT, or DOCX")],
|
| 713 |
+
outputs="text",
|
| 714 |
+
title="Document Summarization",
|
| 715 |
+
allow_flagging="never",
|
| 716 |
+
description=demo_description5
|
| 717 |
+
),
|
| 718 |
+
gr.Interface(
|
| 719 |
+
fn=multi_agent_chat_advanced,
|
| 720 |
+
inputs=[
|
| 721 |
+
gr.Textbox(label="請輸入查詢內容"),
|
| 722 |
+
gr.File(label="上傳文件 (CSV, PDF, TXT, DOCX)", file_count="single")
|
| 723 |
+
],
|
| 724 |
+
outputs="text",
|
| 725 |
+
title="Multi-Agent AI Assistant",
|
| 726 |
+
allow_flagging="never",
|
| 727 |
+
description=demo_description4
|
| 728 |
+
)
|
| 729 |
+
],
|
| 730 |
+
tab_names=[
|
| 731 |
+
"Biden Q&A (LLaMA)",
|
| 732 |
+
"Biden Q&A (GPT-4)",
|
| 733 |
+
"Your Docs Q&A (Upload + GPT-4)",
|
| 734 |
+
"Document Summarization",
|
| 735 |
+
"Multi-Agent AI Assistant"
|
| 736 |
+
],
|
| 737 |
+
title="RAG + Multi-Agent Platform"
|
| 738 |
)
|
| 739 |
|
| 740 |
if __name__ == "__main__":
|