# NOTE(review): Hugging Face Hub page residue was pasted above the code
# (author caption "kingkaikai's picture", commit message "delet import error",
# commit hash 521750d) — preserved here as a comment so the module parses.
import os
import re

from datasets import load_dataset
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.tools import DuckDuckGoSearchRun
from langchain.vectorstores import FAISS
from smolagents import CodeAgent, DuckDuckGoSearchTool, InferenceClientModel
# System prompt instructing the model to emit the "FINAL ANSWER: ..." template
# and the GAIA-style formatting rules (bare numbers, no articles, etc.).
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""
# Web search tool (langchain's DuckDuckGo wrapper).
# NOTE(review): `search_tool` is rebound to smolagents' DuckDuckGoSearchTool
# later in this file, so this instance is shadowed — confirm which is intended.
search_tool = DuckDuckGoSearchRun()
# Prompt template: system instructions followed by retrieval context and the
# user question; used by the first RAG chain below via `PROMPT`.
prompt_template = SYSTEM_PROMPT + "\n\nContext: {context}\nQuestion: {question}\n"
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)
# Load GAIA dataset and setup RAG components
def load_gaia_and_setup_rag():
    """Build a RetrievalQA chain over the GAIA dataset.

    Loads the dataset's 'text' fields, embeds them into an in-memory FAISS
    index, and wires a "stuff" RetrievalQA chain using the module-level
    PROMPT template.

    Returns:
        The configured RetrievalQA chain, or None if any step fails
        (the error is printed rather than raised).
    """
    try:
        # NOTE(review): a bare "GAIA" id is unlikely to resolve on the HF hub —
        # the second setup later in this file uses "gaia-benchmark/gaia";
        # confirm which dataset id is intended. Requires HUGGINGFACE_HUB_TOKEN.
        dataset = load_dataset("GAIA", split="train")
        texts = [item['text'] for item in dataset if 'text' in item]
        # Embed the raw texts into an in-memory FAISS vector store.
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_texts(texts, embeddings)
        retriever = vectorstore.as_retriever()
        # Fix: the original referenced an undefined name `SmoalAgent`
        # (guaranteed NameError); use the deterministic OpenAI LLM instead,
        # matching the second RAG setup later in this file.
        qa_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=0),
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": PROMPT}
        )
        return qa_chain
    except Exception as e:
        print(f"RAG initialization error: {str(e)}")
        return None
# Extract final answer from model response
def extract_final_answer(response):
    """Pull the text following the 'FINAL ANSWER:' marker out of *response*.

    Matching is case-insensitive and captures the remainder of that line,
    stripped of surrounding whitespace. When the marker is absent, the
    response is returned unchanged as a fallback.
    """
    marker = re.search(r"FINAL ANSWER: (.*)", response, re.IGNORECASE)
    if marker is None:
        # No template marker — hand back the raw model output untouched.
        return response
    return marker.group(1).strip()
# Initialize RAG chain at import time.
# Fix: the original had a `global rag_chain` statement here, which is a
# meaningless no-op at module level — removed.
rag_chain = load_gaia_and_setup_rag()
# Initialize search tool.
# NOTE(review): this rebinds `search_tool`, shadowing the DuckDuckGoSearchRun
# instance created earlier in the file — confirm which tool is intended.
search_tool = DuckDuckGoSearchTool()
# NOTE(review): this immediately discards the chain built above; the second
# `load_gaia_and_setup_rag` below repopulates `rag_chain` via `global`.
rag_chain = None
def load_gaia_and_setup_rag():
    """Populate the module-global `rag_chain` with a GAIA-backed QA chain.

    Loads the GAIA test split, indexes its non-empty 'context' fields in a
    FAISS store, and builds a "stuff" RetrievalQA chain with a terse
    answer-only prompt and an OpenAI LLM at temperature 0.

    Returns:
        True on success, False if any step fails (the error is printed).
    """
    try:
        # Lazy import keeps module import cheap when RAG is never used.
        from datasets import load_dataset
        # Load GAIA dataset (test split); requires hub authentication.
        dataset = load_dataset("gaia-benchmark/gaia", split="test")
        # Keep only rows carrying a non-empty context string.
        contexts = [item["context"] for item in dataset if "context" in item and item["context"]]
        if not contexts:
            # FAISS.from_texts fails obscurely on an empty list; fail loudly
            # instead (still caught and reported by the handler below).
            raise ValueError("GAIA dataset yielded no usable contexts")
        # Create embeddings and vector store.
        embeddings = OpenAIEmbeddings()
        vector_store = FAISS.from_texts(contexts, embeddings)
        # Retrieve the 3 nearest contexts per query.
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        # Prompt that forces bare, format-constrained answers.
        SYSTEM_PROMPT = """
You are a precise QA system. Answer ONLY with the exact answer, no explanations.
Answers must be in one of these formats:
- A single number
- A single string
- A comma-separated list of numbers or strings
Do not include any additional text, explanations, or formatting.
"""
        prompt_template = PromptTemplate(
            template=SYSTEM_PROMPT + "\nContext: {context}\nQuestion: {question}\nAnswer:",
            input_variables=["context", "question"]
        )
        # Publish the chain for module-level consumers.
        # Fix: `OpenAI` was never imported in the original file; it is now
        # provided by the top-of-file langchain import.
        global rag_chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=0),
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt_template}
        )
        print(f"Successfully loaded GAIA dataset and created RAG chain with {len(contexts)} contexts")
        return True
    except Exception as e:
        print(f"Error setting up RAG: {e}")
        return False
# Initialize RAG when the module is loaded.
# Side effect: sets the module-global `rag_chain` on success; on failure it
# prints the error and leaves `rag_chain` as None.
load_gaia_and_setup_rag()
# Initialize CodeAgent
def initialize_code_agent():
    """Build a CodeAgent wired to the module-level `search_tool`.

    Returns the agent on success, or None when construction fails for any
    reason (e.g. missing credentials); the error is printed, not raised.
    """
    try:
        # NOTE(review): smolagents' InferenceClientModel conventionally takes
        # `model_id`/`token` — confirm these keyword names against the
        # installed smolagents version.
        llm = InferenceClientModel(
            api_key=os.getenv("OPENAI_API_KEY"),
            model_name="gpt-3.5-turbo",
        )
        code_agent = CodeAgent(tools=[search_tool], model=llm)
        print("CodeAgent initialized successfully")
        return code_agent
    except Exception as e:
        print(f"Error initializing CodeAgent: {e}")
        return None
# Final answer extraction
def extract_final_answer(text):
    """Return the answer after a case-insensitive 'FINAL ANSWER:' marker.

    Falls back to the whitespace-stripped input when no marker is found.
    """
    hit = re.search(r'FINAL ANSWER: (.*)', text, re.IGNORECASE)
    return hit.group(1).strip() if hit else text.strip()