# kingkaikai's picture
# Upload my tools and update app.py
# 2a8991b verified
# raw
# history blame
# 2.7 kB
import os
import re
from langchain.tools import DuckDuckGoSearchRun
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from datasets import load_dataset
from agent import SmoalAgent
# System prompt for formatting answers: it tells the model to end its reply with
# "FINAL ANSWER: ...", which is the marker extract_final_answer() searches for.
# It is prepended to the RAG prompt template (prompt_template / PROMPT).
# The text is sent to the model verbatim, including the leading and trailing
# newline produced by the triple-quoted layout, so any wording change here
# changes the model's input.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""
# Web-search tool instance, constructed once when this module is imported.
search_tool = DuckDuckGoSearchRun()

# Prompt used by the RetrievalQA chain. It consists of the system instructions,
# a blank line, and then slots for the retrieved documents ({context}) and the
# user's question ({question}). Only the first piece is an f-string; the plain
# literals keep their braces so PromptTemplate sees the placeholders.
prompt_template = (
    f"{SYSTEM_PROMPT}\n\n"
    "Context: {context}\n"
    "Question: {question}\n"
)
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)
# Load GAIA dataset and setup RAG components
def load_gaia_and_setup_rag(*, dataset_name="GAIA", split="train", text_field="text", config_name=None):
    """Build a RetrievalQA chain over text records from a Hugging Face dataset.

    The function proceeds in four steps:

    1. Load ``dataset_name`` / ``split`` with ``datasets.load_dataset``.
    2. Embed the non-empty ``text_field`` value of every record with
       ``OpenAIEmbeddings``.
    3. Index the embeddings in a FAISS vector store.
    4. Wire the store's retriever into a "stuff" RetrievalQA chain that uses
       ``SmoalAgent`` as the LLM and the module-level ``PROMPT`` template.

    Called with no arguments, it makes the same ``load_dataset("GAIA",
    split="train")`` call and reads the ``"text"`` column as before.

    NOTE(review): the public GAIA benchmark appears to be published as
    "gaia-benchmark/GAIA". It has configs such as "2023_all", "validation"/"test"
    splits, and a "Question" column rather than "text". Confirm this and pass
    those values explicitly if so.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Dataset split to load.
        text_field: Record key whose value is indexed. Records that lack the
            key, or whose value is empty/None, are skipped.
        config_name: Optional dataset configuration (the ``name=`` argument of
            ``load_dataset``). ``None`` keeps the library default.

    Returns:
        The RetrievalQA chain, or ``None`` when no indexable text was found or
        any step raised. In either failure case an error message is printed.
    """
    try:
        # Load GAIA dataset (requires HUGGINGFACE_HUB_TOKEN)
        dataset = load_dataset(dataset_name, name=config_name, split=split)
        # Skip empty/None values: they cannot be embedded meaningfully and
        # would make the whole initialization fail.
        texts = [item[text_field] for item in dataset if item.get(text_field)]
        if not texts:
            # FAISS cannot build an index from zero documents and fails with an
            # unhelpful error, so report the actual cause instead.
            print(
                f"RAG initialization error: no non-empty '{text_field}' values "
                f"found in dataset {dataset_name!r} (split={split!r})"
            )
            return None

        # Create embeddings and vector store
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_texts(texts, embeddings)

        # The "stuff" chain type places the retrieved documents into PROMPT's
        # {context} variable and the query into {question}.
        qa_chain = RetrievalQA.from_chain_type(
            llm=SmoalAgent(),
            chain_type="stuff",
            retriever=vectorstore.as_retriever(),
            chain_type_kwargs={"prompt": PROMPT},
        )
        return qa_chain
    except Exception as e:
        # Best-effort initialization: report the failure and signal it with
        # None rather than propagating the exception.
        print(f"RAG initialization error: {str(e)}")
        return None
# Extract final answer from model response
def extract_final_answer(response):
    """Return the text after the last "FINAL ANSWER:" marker in *response*.

    The system prompt asks the model to *finish* with
    ``FINAL ANSWER: [YOUR FINAL ANSWER]``, so the last marker is the one that
    carries the answer. Earlier occurrences are typically the model restating
    the template while reasoning.

    Matching rules:

    - The marker is matched case-insensitively.
    - Any whitespace after the colon is skipped, so an answer on the following
      line is also found.
    - The captured answer runs to the end of its line and is then stripped.

    Args:
        response: Raw text produced by the model.

    Returns:
        The stripped answer text, or *response* unchanged when no marker is
        present.
    """
    matches = re.findall(r"FINAL ANSWER:\s*(.*)", response, re.IGNORECASE)
    if matches:
        return matches[-1].strip()
    # Fallback to return full response if pattern not found
    return response
# Initialize RAG chain once, at import time. This is already module scope, so
# no `global` declaration is needed. rag_chain is None when
# load_gaia_and_setup_rag() could not build the chain, so check it before use.
rag_chain = load_gaia_and_setup_rag()