Spaces:
Runtime error
Runtime error
Upload app (3).py
Browse files- app (3).py +651 -0
app (3).py
ADDED
|
@@ -0,0 +1,651 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
import transformers
|
| 7 |
+
import chardet
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
# --- Work around the Matplotlib permission problem and relocate caches ---
# Every config/cache directory used by the libraries below is pointed at /tmp.
# NOTE(review): torch, transformers, huggingface_hub and gradio are imported
# above this point; any of them that reads these variables at import time
# will not see the values set here -- confirm the ordering is intentional.
os.environ.update(
    {
        "MPLCONFIGDIR": "/tmp/matplotlib",
        # === Enterprise Environment Settings ===
        "HOME": "/tmp",
        "XDG_CACHE_HOME": "/tmp/.cache",
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
        "HF_DATASETS_CACHE": "/tmp/huggingface/datasets",
        "HF_METRICS_CACHE": "/tmp/huggingface/metrics",
        "GRADIO_FLAGGING_DIR": "/tmp/flagged",
        "SENTENCE_TRANSFORMERS_HOME": "/tmp/sentence_transformers",
        "HF_HUB_CACHE": "/tmp/huggingface/hf_cache",
        "HF_HUB_DOWNLOAD_TIMEOUT": "60",
    }
)
|
| 27 |
+
|
| 28 |
+
# === Load Required Modules ===
|
| 29 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 30 |
+
from langchain.vectorstores import Chroma, FAISS
|
| 31 |
+
from langchain.chains import RetrievalQA
|
| 32 |
+
from langchain.prompts import PromptTemplate
|
| 33 |
+
from langchain.llms import HuggingFacePipeline
|
| 34 |
+
from langchain.chat_models import ChatOpenAI
|
| 35 |
+
from langchain.chains import ConversationalRetrievalChain
|
| 36 |
+
from langchain.memory import ConversationBufferMemory
|
| 37 |
+
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader
|
| 38 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 39 |
+
from langchain.chains.summarize import load_summarize_chain
|
| 40 |
+
from tempfile import mkdtemp
|
| 41 |
+
import pandas as pd
|
| 42 |
+
|
| 43 |
+
# === Multi-Agent Imports ===
|
| 44 |
+
from serpapi import GoogleSearch
|
| 45 |
+
# CrewAI: use only CrewAI's own Agent, Task, Crew and the @tool decorator
|
| 46 |
+
from crewai import Crew, Agent, Task, Process
|
| 47 |
+
from crewai.tools import tool
|
| 48 |
+
from langchain_experimental.agents import create_pandas_dataframe_agent
|
| 49 |
+
|
| 50 |
+
# Module-level session state. It starts empty and is assigned by
# multi_agent_chat_advanced when a file is uploaded.
session_retriever = None
session_qa_chain = None
csv_dataframe = None  # CSV tool will use this

# === Model and Device Setup ===
# Preference order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device => {device}")
|
| 61 |
+
|
| 62 |
+
# Credentials come from the environment; os.environ.get returns None when unset.
hf_token = os.environ.get("HF_TOKEN")
openai_api_key = os.environ.get("OPENAI_API_KEY")
model_id = "ChienChung/my-llama-1b"

# Download config.json on its own so it can be patched before the model is built.
config_path = hf_hub_download(
    repo_id=model_id,
    filename="config.json",
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
# Any existing rope_scaling entry is replaced by a two-key {"type", "factor"}
# dict; the original factor is kept when present, otherwise 32.0 is used.
# NOTE(review): presumably a workaround for a rope_scaling schema that the
# installed transformers version rejects -- confirm.
if "rope_scaling" in config_dict:
    config_dict["rope_scaling"] = {"type": "dynamic", "factor": config_dict["rope_scaling"].get("factor", 32.0)}
model_config = LlamaConfig.from_dict(config_dict)
model_config.trust_remote_code = True

# Build the causal LM from the patched config and move it to the chosen device.
print("Loading Llama model...")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=model_config,
    trust_remote_code=True,
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
model.to(device)
print("Model loaded!")

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
    use_auth_token=hf_token,
    cache_dir="/tmp/huggingface"
)
# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded!")

# Text-generation pipeline around the already-loaded model and tokenizer:
# sampling disabled, at most 200 new tokens, prompt not echoed back.
# NOTE(review): torch_dtype / device_map are passed although the model object
# is already instantiated and moved above -- confirm they have the intended effect.
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device != "cpu" else None,
    do_sample=False,
    temperature=0.0,
    max_new_tokens=200,
    return_full_text=False
)
|
| 112 |
+
|
| 113 |
+
# === Chroma DB and Document Retrieval Setup ===
print("Loading Chroma DB for Biden Speech...")
# Copy the bundled ./chroma_db to /tmp once; later starts reuse the copy.
if not os.path.exists("/tmp/chroma_db"):
    shutil.copytree("./chroma_db", "/tmp/chroma_db")
# The same embeddings object is reused for the FAISS indexes built from uploads.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vectordb = Chroma(persist_directory="/tmp/chroma_db", embedding_function=embeddings)
retriever = vectordb.as_retriever()

# Prompt shared by the retrieval chains: answer only from {context}, reply
# "No relevant info found." otherwise, and keep the answer to 1-3 sentences.
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""You are a helpful AI assistant. Use only the text from the context below to answer the user's question.
If the answer is not in the context, say "No relevant info found."
If the question is not in the context, say "No relevant info found."

Return only the final answer in one to three sentences.
Do not restate the question or context.
Do not include these instructions in your final output.

Context:
{context}

Question: {question}

Answer:
"""
)

# llm_local wraps the local pipeline; llm_gpt4 and crew_llm are two separately
# constructed ChatOpenAI clients with identical settings.
llm_local = HuggingFacePipeline(pipeline=query_pipeline)
llm_gpt4 = ChatOpenAI(model_name="gpt-4o-mini", temperature=0.2, openai_api_key=openai_api_key)
crew_llm = ChatOpenAI(
    model_name="gpt-4o-mini",
    temperature=0.2,
    openai_api_key=openai_api_key
)

# Conversational chain over the Chroma retriever.
# NOTE: `memory` is a single module-level object, so the chat history it
# holds is shared by every caller of qa_gpt.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
qa_gpt = ConversationalRetrievalChain.from_llm(
    llm=llm_gpt4,
    retriever=retriever,
    memory=memory,
    combine_docs_chain_kwargs={"prompt": custom_prompt}
)
|
| 155 |
+
|
| 156 |
+
# === Helper: extract a filesystem path from an uploaded-file value ===
def get_file_path(file):
    """Return a local filesystem path for an uploaded-file value, or None.

    Supported inputs:

    * ``str`` -- treated as a path and returned unchanged.
    * ``dict`` -- the ``"data"`` entry if present, otherwise ``"name"``
      (``None`` when neither key exists).
    * object with a ``save`` method -- saved into a fresh temporary
      directory under the base name of ``file.name``; that path is returned.
    * object with only a string ``name`` attribute (for example a
      ``tempfile`` wrapper) -- ``name`` is returned as the path.

    Anything else, including ``None``, yields ``None``.
    """
    if isinstance(file, str):
        return file
    if isinstance(file, dict):
        # Prefer the "data" key, then fall back to "name".
        return file.get("data", file.get("name"))
    if hasattr(file, "save"):
        temp_dir = mkdtemp()
        # Use only the base name: os.path.join would discard temp_dir if
        # file.name were absolute, and "../" segments could escape it.
        file_path = os.path.join(temp_dir, os.path.basename(file.name))
        file.save(file_path)
        return file_path
    # File-like wrappers without save() still expose their path as `.name`.
    name = getattr(file, "name", None)
    if isinstance(name, str):
        return name
    return None
|
| 170 |
+
|
| 171 |
+
# === Original features (Tab 1 - Tab 4) ===
def rag_llama_qa(query):
    """Answer *query* with the local LLaMA pipeline over the Chroma retriever.

    A "stuff" RetrievalQA chain is built on every call. If the generated text
    contains "answer:" (case-insensitive), only the stripped text after its
    first occurrence is returned; otherwise the raw output is returned.
    """
    chain = RetrievalQA.from_chain_type(
        llm=llm_local,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": custom_prompt},
    )
    output = chain.run(query)

    marker = "answer:"
    marker_pos = output.lower().find(marker)
    if marker_pos == -1:
        return output
    return output[marker_pos + len(marker):].strip()
|
| 183 |
+
|
| 184 |
+
def rag_gpt4_qa(query):
    """Answer *query* through the module-level conversational chain ``qa_gpt``."""
    answer = qa_gpt.run(query)
    return answer
|
| 186 |
+
|
| 187 |
+
def upload_and_chat(file, query):
    """Answer *query* from the content of an uploaded PDF, DOCX or text file.

    The file is loaded, split into 500-character chunks (50 overlap), indexed
    in a throw-away FAISS store and queried with a "stuff" RetrievalQA chain
    on the GPT model.

    Args:
        file: Upload value accepted by ``get_file_path``.
        query: The user's question.

    Returns:
        The chain's answer, or an explanatory message when the path cannot be
        resolved or the file contains no extractable text.
    """
    file_path = get_file_path(file)
    if file_path is None:
        return "Unable to obtain the uploaded file path."

    # Pick the loader by extension; anything else is read as plain text.
    lowered_path = file_path.lower()
    if lowered_path.endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif lowered_path.endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(file_path)
    else:
        loader = TextLoader(file_path)

    docs = loader.load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)
    if not chunks:
        # FAISS.from_documents fails on an empty list (e.g. an image-only
        # PDF), so report the condition instead of crashing.
        return "No readable text found in the uploaded file."

    db = FAISS.from_documents(chunks, embeddings)
    qa_temp = RetrievalQA.from_chain_type(
        llm=llm_gpt4,
        chain_type="stuff",
        retriever=db.as_retriever(),
        return_source_documents=False,
        chain_type_kwargs={"prompt": custom_prompt},
    )
    return qa_temp.run(query)
|
| 209 |
+
|
| 210 |
+
# Prompt for the first pass of the refine summarisation chain: summarise {text}.
initial_prompt = PromptTemplate(
    input_variables=["text"],
    template="""Write a concise and structured summary of the following content. Focus on capturing the main ideas and key details:

{text}

--- Summary ---
"""
)
# Prompt for each later pass: update {existing_answer} using the next {text}.
refine_prompt = PromptTemplate(
    input_variables=["existing_answer", "text"],
    template="""You already have an existing summary:
{existing_answer}

Refine the summary based on the new content below. Add or update information only if it's relevant. Keep it concise:

{text}

--- Refined Summary ---
"""
)
|
| 231 |
+
|
| 232 |
+
def document_summarize(file):
    """Summarise an uploaded PDF, DOCX or text file with a refine chain.

    Args:
        file: Upload value accepted by ``get_file_path`` (a plain path string
            also works).

    Returns:
        The chain's ``output_text``, or an explanatory message when the path
        cannot be resolved or the loader returns no documents.
    """
    file_path = get_file_path(file)
    if file_path is None:
        return "Unable to obtain the uploaded file."

    # Pick the loader by extension; anything else is read as plain text.
    lowered_path = file_path.lower()
    if lowered_path.endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif lowered_path.endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(file_path)
    else:
        loader = TextLoader(file_path)

    docs = loader.load()
    if not docs:
        # The refine chain needs at least one document to seed the summary.
        return "No readable text found in the uploaded file."

    summarize_chain = load_summarize_chain(
        llm_gpt4,
        chain_type="refine",
        question_prompt=initial_prompt,
        refine_prompt=refine_prompt,
    )
    summary = summarize_chain.invoke(docs)
    return summary['output_text']
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def csv_agent(file, query):
    """Evaluate a Python expression *query* against an uploaded CSV file.

    The file encoding is detected with chardet, the CSV is read into a
    DataFrame, and *query* is evaluated with only ``df`` and ``pd`` as local
    names. The result, or an error message, is returned as a string.
    """
    file_path = get_file_path(file)
    if file_path is None:
        return "Unable to obtain the uploaded CSV file."

    try:
        with open(file_path, 'rb') as f:
            raw_bytes = f.read()
        detected = chardet.detect(raw_bytes)
        df = pd.read_csv(file_path, encoding=detected['encoding'])
    except Exception as e:
        return f"Error reading CSV: {e}"

    namespace = {"df": df, "pd": pd}
    # SECURITY(review): eval() runs caller-supplied text. Emptying
    # __builtins__ is not a sandbox -- pandas objects still expose file and
    # attribute access. Do not feed this untrusted input without a safe
    # expression evaluator.
    try:
        value = eval(query, {"__builtins__": {}}, namespace)
        return str(value)
    except Exception as e:
        return f"Query error: {e}"
|
| 265 |
+
|
| 266 |
+
def search_agent(query):
    """Run a Google search through SerpAPI and return the results as text.

    Requires the SERPAPI_API_KEY environment variable. Each organic result is
    rendered as "Title / Link / Snippet" lines, with a blank line between
    results.
    """
    api_key = os.environ.get("SERPAPI_API_KEY")
    if not api_key:
        return "SERPAPI_API_KEY not set. Please set the environment variable."

    results = GoogleSearch(
        {"engine": "google", "q": query, "api_key": api_key, "num": 5}
    ).get_dict()
    if "organic_results" not in results:
        return "No results found."

    entries = []
    for item in results["organic_results"]:
        title = item.get("title", "No Title")
        link = item.get("link", "No Link")
        snippet = item.get("snippet", "No Snippet")
        entries.append(f"Title: {title}\nLink: {link}\nSnippet: {snippet}")
    return "\n\n".join(entries).strip()
|
| 283 |
+
|
| 284 |
+
def uploaded_qa(file, query):
    """Answer *query* from the content of an uploaded PDF, DOCX or text file.

    This was a line-for-line copy of ``upload_and_chat``; it now delegates to
    that function so the two entry points cannot drift apart.

    Args:
        file: Upload value accepted by ``get_file_path``.
        query: The user's question.

    Returns:
        Whatever ``upload_and_chat`` returns for the same arguments.
    """
    return upload_and_chat(file, query)
|
| 306 |
+
|
| 307 |
+
# === CrewAI Multi-Agent System (Tab 5) ===
|
| 308 |
+
# langchain.agents.Tool is dropped entirely; tools are defined with CrewAI's @tool decorator
|
| 309 |
+
from pydantic import BaseModel
|
| 310 |
+
# Pydantic model holding a single free-text query string.
# NOTE(review): no use of this class is visible in this file -- confirm it is still needed.
class SimpleQuery(BaseModel):
    # The user's query text.
    query: str
|
| 312 |
+
|
| 313 |
+
@tool("summarise")
def summarise_tool(query: str) -> str:
    """Summarise: Use document summarisation functionality."""
    # The docstring above is passed through @tool, so it is left untouched.
    # Returns a summary of the documents retrieved for `query` from the
    # current session retriever, or a (Chinese) status/error message.
    global session_retriever
    if session_retriever is None:
        return "尚未上傳文件。"  # "No document has been uploaded yet."
    try:
        # A blank query falls back to the generic search text "summary".
        search_text = query if query.strip() else "summary"
        docs = session_retriever.get_relevant_documents(search_text)
        if not docs:
            # "Could not find relevant content in the document."
            return "無法從文件中找到相關內容。"
        chain = load_summarize_chain(
            llm_gpt4,
            chain_type="refine",
            question_prompt=initial_prompt,
            refine_prompt=refine_prompt,
        )
        return chain.invoke(docs)['output_text']
    except Exception as e:
        return f"摘要錯誤: {e}"  # "Summary error"
|
| 328 |
+
|
| 329 |
+
@tool("python_calc")
def python_calc_tool(query: str) -> str:
    """Python Calculation: Perform basic arithmetic or logical operations."""
    # The docstring above is passed through @tool, so it is left untouched.
    # SECURITY(review): eval() runs with full builtins on text chosen by the
    # LLM from user input, which allows arbitrary code execution. Replace
    # with a restricted arithmetic evaluator before exposing this publicly.
    try:
        value = eval(query)
        return str(value)
    except Exception as e:
        return f"計算錯誤: {e}"  # "Calculation error"
|
| 336 |
+
|
| 337 |
+
@tool("search_agent")
def search_tool_func(query: str) -> str:
    """Search: Perform web searches using external search engines."""
    # The docstring above is passed through @tool, so it is left untouched.
    # Thin wrapper: forwards the query to the module-level search_agent name.
    search_output = search_agent(query)
    return search_output
|
| 341 |
+
|
| 342 |
+
@tool("uploaded_qa")
def uploaded_qa_tool_func(query: str) -> str:
    """Document QA: Answer questions based on the uploaded document content."""
    # The docstring above is passed through @tool, so it is left untouched.
    # Answers from the session QA chain built when a document was uploaded.
    global session_qa_chain
    if session_qa_chain is None:
        return "尚未上傳文件。"  # "No document has been uploaded yet."
    try:
        return session_qa_chain.run(query)
    except Exception as e:
        return f"文檔問答錯誤: {e}"  # "Document QA error"
|
| 353 |
+
|
| 354 |
+
@tool("csv_agent")
def csv_tool_func(query: str) -> str:
    """CSV Agent: Use natural language to analyse uploaded CSV files."""
    # The docstring above is passed through @tool, so it is left untouched.
    # Builds a pandas-dataframe agent over the uploaded CSV and asks it the
    # query, prefixed with a preview of the first rows.
    # NOTE(review): some langchain_experimental releases require an explicit
    # allow_dangerous_code=True here -- confirm against the pinned version.
    global csv_dataframe
    if csv_dataframe is None:
        return "No CSV file uploaded."
    try:
        agent = create_pandas_dataframe_agent(llm=llm_gpt4, df=csv_dataframe, verbose=True)
        preview = csv_dataframe.head().to_string(index=False)
        prompt = f"Here is the table:\n{preview}\n\n{query}"
        return agent.run(prompt)
    except Exception as e:
        return f"CSV Agent error: {e}"
|
| 365 |
+
|
| 366 |
+
# Build the CrewAI agents (used only by Tab 5).
#
# The search and CSV Agent objects are deliberately NOT named `search_agent`
# / `csv_agent`: those names belong to plain functions defined earlier, and
# rebinding them to Agent instances made every later `search_agent(query)`
# call fail because an Agent is not callable.
summarizer_agent = Agent(
    role="Document Summarizer",
    goal="Summarise the content of the uploaded document.",
    backstory="You are a professional summarisation expert who can identify key points in long documents.",
    tools=[summarise_tool],
    verbose=True
)
document_qa_agent = Agent(
    role="Document QA Specialist",
    goal="Answer questions based on the uploaded document.",
    backstory="You are an expert in document understanding and can accurately extract answers.",
    tools=[uploaded_qa_tool_func],
    verbose=True
)

search_expert_agent = Agent(
    role="Search Expert",
    goal="Search the web and provide relevant information.",
    backstory="You are an expert at finding relevant information from the internet.",
    tools=[search_tool_func],
    verbose=True
)

math_agent = Agent(
    role="Math Assistant",
    goal="Perform accurate arithmetic or logical calculations.",
    backstory="You are a calculator expert skilled at quick computations.",
    tools=[python_calc_tool],
    verbose=True
)
csv_analyst_agent = Agent(
    role="CSV Analyst",
    goal="Analyse tabular data and answer questions about the uploaded CSV file.",
    backstory="You are skilled in interpreting tabular datasets and can extract numerical or logical insights.",
    tools=[csv_tool_func],
    verbose=True
)
# The router owns the single task and has access to every tool.
# NOTE(review): router_agent is the task's agent but is not listed in the
# Crew's `agents` below -- confirm that is intended.
router_agent = Agent(
    role="Query Router",
    goal="Determine the most suitable agent or tool to handle the user query.",
    backstory="You are an intelligent query dispatcher that analyses the user's intent and chooses the best AI agent to answer.",
    tools=[python_calc_tool, search_tool_func, csv_tool_func, uploaded_qa_tool_func, summarise_tool],
    verbose=True
)
router_task = Task(
    description=(
        "Based on the user's query, decide which agent or tool is best suited to handle it:\n"
        "- If the query is related to the content of an uploaded file (e.g., 'what is this document about?'), send it to the **Document QA Agent**.\n"
        "- If the query contains words like 'summarize', 'summary', or 'main points', use the **Summarizer Agent**.\n"
        "- If the query involves numbers, calculations, or logic (e.g., '50 * 23 - 5', 'what is 10% of 800'), send it to the **Math Agent**.\n"
        "- If the user uploaded a CSV file and asks about table content, data trends, or uses words like 'data', 'table', 'csv', 'column', or 'row', send it to the **CSV Agent**.\n"
        "- If the user asks about current events, trending topics, or online information (e.g., 'What is LangChain?', 'latest news'), send it to the **Search Agent**.\n"
        "- If none of these apply, use your best judgment to choose the most relevant agent."
        # kickoff(inputs={"query": ...}) is substituted into `{query}`;
        # without this placeholder the task text never contains the query.
        "\n\nUser query: {query}"
    ),
    expected_output="The final answer from the selected agent or tool.",
    agent=router_agent,
    input_variables=["query"]
)

crew = Crew(
    agents=[summarizer_agent, document_qa_agent, search_expert_agent, math_agent, csv_analyst_agent],
    tasks=[router_task],
    process=Process.sequential,
    verbose=True,
    llm=crew_llm
)
|
| 433 |
+
|
| 434 |
+
def multi_agent_chat(query: str) -> str:
    """Route *query* through the CrewAI crew and return its answer as text.

    If the crew's answer is blank or contains "I don't know.", the query is
    sent to ``search_agent`` instead. Any exception is reported as
    ``"Error: <message>"``.

    Args:
        query: The user's question.

    Returns:
        ``"[Agent: <name>]\\n<output>"`` when the result exposes
        ``agent_name`` and ``output``; otherwise the result's string form.
    """
    print(f"Routing query: {query}")
    try:
        result = crew.kickoff(inputs={"query": query})
        result_str = str(result)
        if "I don't know." in result_str or not result_str.strip():
            return search_agent(query)  # fall back to web search

        # The result object may not expose these attributes, so read them
        # defensively rather than turning a successful run into an error.
        agent_name = getattr(result, "agent_name", None)
        output = getattr(result, "output", None)
        if agent_name is None or output is None:
            return result_str
        return f"[Agent: {agent_name}]\n{output}"
    except Exception as e:
        return f"Error: {e}"
|
| 444 |
+
|
| 445 |
+
def _format_crew_result(result) -> str:
    """Render a crew result as "[Agent: name]\\noutput", or str(result) if it lacks those attributes."""
    agent_name = getattr(result, "agent_name", None)
    output = getattr(result, "output", None)
    if agent_name is None or output is None:
        return str(result)
    return f"[Agent: {agent_name}]\n{output}"


def _run_crew(query: str) -> str:
    """Kick off the crew for *query*; failures are returned as "Error: ..." text."""
    try:
        return _format_crew_result(crew.kickoff(inputs={"query": query}))
    except Exception as e:
        return f"Error: {e}"


def _run_session_chain(query: str) -> str:
    """Ask the current session QA chain; failures are returned as "Error: ..." text."""
    try:
        return session_qa_chain.run(query)
    except Exception as e:
        return f"Error: {e}"


def multi_agent_chat_advanced(query: str, file=None) -> str:
    """Answer *query*, optionally using a newly uploaded file.

    Routing:

    * CSV upload -- the file is read into the module-level ``csv_dataframe``
      and the query goes to the crew.
    * PDF/TXT/DOCX upload -- a FAISS retriever and conversational QA chain
      are built and stored as session state. Queries mentioning a summary are
      summarised; queries containing a "non-document" keyword go to the
      crew; everything else goes to the session QA chain.
    * No upload -- the existing session QA chain is used unless the query
      contains a "non-document" keyword or no session exists; then the crew
      handles it.

    Args:
        query: The user's question.
        file: Optional upload value accepted by ``get_file_path``.

    Returns:
        The answer text, or an explanatory/error message.
    """
    global session_retriever, session_qa_chain, csv_dataframe

    # Decide whether the query is unrelated to the document (substring match).
    lowered_query = query.lower()
    non_doc_keywords = ["calculate", "sum", "date", "time", "how many", "how much", "weather", "temperature"]
    use_file_chain = not any(kw in lowered_query for kw in non_doc_keywords)

    # No new file uploaded: reuse the session chain when appropriate,
    # otherwise hand the query straight to CrewAI.
    if file is None:
        if session_qa_chain is not None and use_file_chain:
            return _run_session_chain(query)
        return _run_crew(query)

    file_path = get_file_path(file)
    if file_path is None:
        return "Unable to process the file format."
    lowered_path = file_path.lower()

    # === CSV handling ===
    if lowered_path.endswith(".csv"):
        try:
            with open(file_path, 'rb') as f:
                detected = chardet.detect(f.read())
            csv_dataframe = pd.read_csv(file_path, encoding=detected['encoding'])
        except Exception as e:
            return f"Error reading CSV: {e}"
        # Kept outside the try so a crew failure is not reported as a CSV
        # read error.
        return _run_crew(query)

    # === Text-type files (PDF / DOCX / TXT) ===
    if not lowered_path.endswith((".pdf", ".txt", ".docx")):
        return "Unsupported file format."

    if lowered_path.endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    elif lowered_path.endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(file_path)
    else:
        loader = TextLoader(file_path)
    docs = loader.load()
    chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
    if not chunks:
        # FAISS.from_documents fails on an empty list (e.g. image-only PDF).
        return "No readable text found in the uploaded file."

    db = FAISS.from_documents(chunks, embeddings)
    session_retriever = db.as_retriever()
    session_qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm_gpt4,
        retriever=session_retriever,
        memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
    )

    # Choose between summarisation, document QA and the crew.
    if any(kw in lowered_query for kw in ["summarize", "summary", "摘要", "總結"]):
        return document_summarize(file_path)
    if use_file_chain:
        return _run_session_chain(query)
    return _run_crew(query)
|
| 532 |
+
|
| 533 |
+
# === Gradio Interface Settings ===
# Markdown shown on the "Biden Q&A (LLaMA)" tab.
demo_description = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on
Biden’s 2023 State of the Union Address.
All responses are grounded in this document.
If no relevant information is found in the document, the system will say "No relevant info found."

**Sample Questions**:
1. What were the main topics regarding infrastructure in this speech?
2. How does the speech address the competition with China?
3. What does Biden say about job growth in the past two years?
4. Does the speech mention anything about Social Security or Medicare?
5. What does the speech propose regarding Big Tech or online privacy?

*Note: The LLaMA module generates responses based solely on the current query without follow-up memory or chat history management.*

Feel free to ask any question related to Biden’s 2023 State of the Union Address.
"""
# Markdown shown on the "Biden Q&A (GPT-4)" tab; differs from the first only in the note.
demo_description2 = """
**Context**:
This demo uses a Retrieval-Augmented Generation (RAG) system based on
Biden’s 2023 State of the Union Address.
All responses are grounded in this document.
If no relevant information is found in the document, the system will say "No relevant info found."

**Sample Questions**:
1. What were the main topics regarding infrastructure in this speech?
2. How does the speech address the competition with China?
3. What does Biden say about job growth in the past two years?
4. Does the speech mention anything about Social Security or Medicare?
5. What does the speech propose regarding Big Tech or online privacy?

*Note: The GPT module supports follow-up questions with conversation history management, enabling more interactive and context-aware discussions.*

Feel free to ask any question related to Biden’s 2023 State of the Union Address.
"""
# Markdown shown on the "Your Docs Q&A" tab.
demo_description3 = """
**Context**:
Upload a PDF, TXT, or DOCX file and ask a question about its content.
This demo uses GPT-4 to answer questions based on the content of your uploaded document.

Feel free to ask any question related to your document.
"""
# Markdown shown on the "Multi-Agent AI Assistant" tab.
demo_description4 = """
**Context**:
This assistant performs multi-agent tasks using tools such as:
- Document summarisation
- FAQ-style document Q&A
- Financial or CSV-style logic queries
- Multi-step reasoning via agent orchestration

The system will automatically select the appropriate function based on the uploaded file (CSV, PDF, TXT, DOCX) and the query content.
For example, if the query contains "summarize"/"摘要", it will summarize the document; if it's CSV data, it will perform data analysis.
"""
|
| 588 |
+
# Markdown shown on the "Document Summarization" tab.
# The text previously said "Map-Reduce", but document_summarize builds its
# chain with chain_type="refine"; the description now matches the code.
demo_description5 = """
**Context**:
This demo uses Document Summarization via a refine chain.
Upload a PDF, TXT, or DOCX file to get an automatic summary of its contents.
"""
|
| 593 |
+
|
| 594 |
+
# Five-tab Gradio app; each tab wraps one of the handler functions above.
# interface_list and tab_names are matched by position.
demo = gr.TabbedInterface(
    interface_list=[
        # Tab 1: RAG over the speech with the local LLaMA pipeline.
        gr.Interface(
            fn=rag_llama_qa,
            inputs="text",
            outputs="text",
            title="Biden Q&A (LLaMA)",
            allow_flagging="never",
            description=demo_description
        ),
        # Tab 2: RAG over the speech with the GPT conversational chain.
        gr.Interface(
            fn=rag_gpt4_qa,
            inputs="text",
            outputs="text",
            title="Biden Q&A (GPT-4)",
            allow_flagging="never",
            description=demo_description2
        ),
        # Tab 3: question answering over an uploaded document.
        gr.Interface(
            fn=upload_and_chat,
            inputs=[gr.File(label="Upload PDF, TXT, or DOCX"), gr.Textbox(label="Ask a question")],
            outputs="text",
            title="Your Docs Q&A (Upload + GPT-4)",
            allow_flagging="never",
            description=demo_description3
        ),
        # Tab 4: summarisation of an uploaded document.
        gr.Interface(
            fn=document_summarize,
            inputs=[gr.File(label="Upload PDF, TXT, or DOCX")],
            outputs="text",
            title="Document Summarization",
            allow_flagging="never",
            description=demo_description5
        ),
        # Tab 5: multi-agent assistant. Input order is (query, file), matching
        # multi_agent_chat_advanced(query, file). The labels read
        # "Enter your query" and "Upload a file (CSV, PDF, TXT, DOCX)".
        gr.Interface(
            fn=multi_agent_chat_advanced,
            inputs=[
                gr.Textbox(label="請輸入查詢內容"),
                gr.File(label="上傳文件 (CSV, PDF, TXT, DOCX)", file_count="single")
            ],
            outputs="text",
            title="Multi-Agent AI Assistant",
            allow_flagging="never",
            description=demo_description4
        )
    ],
    tab_names=[
        "Biden Q&A (LLaMA)",
        "Biden Q&A (GPT-4)",
        "Your Docs Q&A (Upload + GPT-4)",
        "Document Summarization",
        "Multi-Agent AI Assistant"
    ],
    title="RAG + Multi-Agent Platform"
)

# Serve on all interfaces at port 7860 without creating a public share link.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|