File size: 2,703 Bytes
2a8991b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import re
from langchain.tools import DuckDuckGoSearchRun
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from datasets import load_dataset
from agent import SmoalAgent

# System prompt enforcing the GAIA answer template: responses must end with
# "FINAL ANSWER: ..." so extract_final_answer() can parse them reliably.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# Web search tool; instantiated at import time but not used in this chunk —
# presumably wired into the agent elsewhere (TODO confirm against callers).
search_tool = DuckDuckGoSearchRun()

# Prompt fed to the RetrievalQA "stuff" chain below: {context} is filled with
# the retrieved documents and {question} with the user query by the chain.
prompt_template = SYSTEM_PROMPT + "\n\nContext: {context}\nQuestion: {question}\n"
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

# Load GAIA dataset and setup RAG components
def load_gaia_and_setup_rag():
    """Build a RetrievalQA chain over the GAIA dataset.

    Loads the GAIA train split (requires a valid HUGGINGFACE_HUB_TOKEN),
    embeds its 'text' entries with OpenAI embeddings into a FAISS index
    (requires OPENAI_API_KEY), and wires a "stuff" RetrievalQA chain that
    formats answers via the custom PROMPT.

    Returns:
        A RetrievalQA chain on success, or None on any failure (errors are
        printed rather than raised so callers can degrade gracefully).
    """
    try:
        # Load GAIA dataset (requires HUGGINGFACE_HUB_TOKEN).
        # NOTE(review): hub datasets are usually namespaced, e.g.
        # "gaia-benchmark/GAIA" — confirm the bare "GAIA" id resolves.
        dataset = load_dataset("GAIA", split="train")
        texts = [item['text'] for item in dataset if 'text' in item]

        # Guard: FAISS.from_texts raises on an empty list, which would be
        # swallowed below as a confusing generic error. Fail fast and clearly.
        if not texts:
            print("RAG initialization error: GAIA dataset contained no 'text' fields")
            return None

        # Create embeddings and vector store
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_texts(texts, embeddings)

        # Create retriever and QA chain with custom prompt
        retriever = vectorstore.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            llm=SmoalAgent(),
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": PROMPT}
        )
        return qa_chain
    except Exception as e:
        # Broad catch is deliberate: initialization is best-effort and the
        # caller is expected to handle a None chain.
        print(f"RAG initialization error: {str(e)}")
        return None

# Extract final answer from model response
def extract_final_answer(response):
    """Extract the text following the "FINAL ANSWER:" marker.

    Matches case-insensitively and tolerates variable whitespace around the
    colon (models frequently emit "FINAL ANSWER:42" or extra spaces, which
    the previous exact-one-space pattern missed).

    Args:
        response: Full model output, expected to end with the
            "FINAL ANSWER: ..." template mandated by SYSTEM_PROMPT.

    Returns:
        The stripped answer text, or the full response unchanged when the
        marker is absent (fallback keeps downstream code working).
    """
    # \s* on both sides of the colon accepts missing/extra spacing; (.*)
    # stops at end-of-line, matching the template's single-line answer.
    match = re.search(r"FINAL ANSWER\s*:\s*(.*)", response, re.IGNORECASE)
    if match:
        return match.group(1).strip()
    # Fallback to return full response if pattern not found
    return response

# Initialize the module-level RAG chain once at import time; None on failure.
# (The former `global rag_chain` statement was removed: `global` is a no-op
# at module scope, where names are already global.)
rag_chain = load_gaia_and_setup_rag()