File size: 5,538 Bytes
c613356
 
 
 
 
 
 
 
a075fae
c613356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a075fae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import re
from langchain.tools import DuckDuckGoSearchRun
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from datasets import load_dataset
from smolagents import CodeAgent, DuckDuckGoSearchTool, InferenceClientModel

# System prompt for formatting answers.
# Instructs the model to end every reply with a parseable
# "FINAL ANSWER: ..." line, which extract_final_answer() pulls out by regex.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# Initialize web search tool (langchain flavor).
# NOTE(review): this binding is later overwritten by a smolagents
# DuckDuckGoSearchTool assigned to the same name further down the module —
# confirm which implementation is actually intended.
search_tool = DuckDuckGoSearchRun()

# Create custom prompt template with system instructions.
# {context} and {question} are the placeholders the RetrievalQA "stuff"
# chain fills in at query time.
prompt_template = SYSTEM_PROMPT + "\n\nContext: {context}\nQuestion: {question}\n"
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

# Load GAIA dataset and setup RAG components
def load_gaia_and_setup_rag():
    try:
        # Load GAIA dataset (requires HUGGINGFACE_HUB_TOKEN)
        dataset = load_dataset("GAIA", split="train")
        texts = [item['text'] for item in dataset if 'text' in item]
        
        # Create embeddings and vector store
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_texts(texts, embeddings)
        
        # Create retriever and QA chain with custom prompt
        retriever = vectorstore.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            llm=SmoalAgent(),
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": PROMPT}
        )
        return qa_chain
    except Exception as e:
        print(f"RAG initialization error: {str(e)}")
        return None

# Extract final answer from model response
def extract_final_answer(response):
    """Pull the text that follows 'FINAL ANSWER:' out of *response*.

    The marker is matched case-insensitively; when it is absent the
    original response is returned unchanged as a fallback.
    """
    found = re.search(r"FINAL ANSWER: (.*)", response, re.IGNORECASE)
    if found is None:
        # No template marker — hand the raw response back untouched.
        return response
    return found.group(1).strip()

# Initialize search tool for the CodeAgent (smolagents flavor; this
# rebinds the name previously pointing at the langchain search tool).
search_tool = DuckDuckGoSearchTool()

# RAG chain placeholder; load_gaia_and_setup_rag() below populates it
# via its `global rag_chain` assignment.
# Fix: the original declared `global rag_chain` at module scope (a no-op),
# built a chain, then immediately discarded it by reassigning None — the
# wasted build call is removed; the final module state is identical.
rag_chain = None

def load_gaia_and_setup_rag():
    """Load the GAIA test split and build the module-level RAG chain.

    On success the global ``rag_chain`` is set to a RetrievalQA chain over
    the dataset contexts. Requires OPENAI_API_KEY in the environment.

    Returns:
        True when the chain was built, False on any failure.
    """
    try:
        from datasets import load_dataset
        # Fix: OpenAI was referenced below but never imported, so this
        # function always fell into the except branch with a NameError
        # and returned False.
        from langchain.llms import OpenAI

        # Load GAIA dataset (test split).
        # NOTE(review): hub ids can be case-sensitive — the canonical id
        # appears to be "gaia-benchmark/GAIA"; confirm it resolves.
        dataset = load_dataset("gaia-benchmark/gaia", split="test")

        # Keep only non-empty contexts for indexing.
        contexts = [item["context"] for item in dataset if "context" in item and item["context"]]

        # Create embeddings and vector store.
        embeddings = OpenAIEmbeddings()
        vector_store = FAISS.from_texts(contexts, embeddings)

        # Retrieve the top-3 most similar contexts per query.
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})

        # Strict answer-only prompt (local; intentionally shadows the
        # module-level SYSTEM_PROMPT inside this function).
        SYSTEM_PROMPT = """
        You are a precise QA system. Answer ONLY with the exact answer, no explanations.
        Answers must be in one of these formats:
        - A single number
        - A single string
        - A comma-separated list of numbers or strings
        Do not include any additional text, explanations, or formatting.
        """

        prompt_template = PromptTemplate(
            template=SYSTEM_PROMPT + "\nContext: {context}\nQuestion: {question}\nAnswer:",
            input_variables=["context", "question"]
        )

        # Publish the chain for the rest of the module.
        global rag_chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=0),  # temperature 0 for deterministic answers
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt_template}
        )

        print(f"Successfully loaded GAIA dataset and created RAG chain with {len(contexts)} contexts")
        return True
    except Exception as e:
        # Best-effort init: report and keep rag_chain as None.
        print(f"Error setting up RAG: {e}")
        return False

# Initialize RAG when the module is loaded.
# NOTE(review): this performs network I/O (dataset download, embedding
# calls) as an import-time side effect; consider deferring to first use.
load_gaia_and_setup_rag()

# Initialize CodeAgent
# Initialize CodeAgent
def initialize_code_agent():
    """Create a smolagents CodeAgent wired to the module search tool.

    Credentials are read from the environment (OPENAI_API_KEY).

    Returns:
        The constructed CodeAgent, or None when initialization fails.
    """
    try:
        # Model backend for the agent.
        # NOTE(review): InferenceClientModel is a Hugging Face inference
        # client — passing an OpenAI key and "gpt-3.5-turbo" here looks
        # suspect; verify the expected kwargs for the installed smolagents.
        llm = InferenceClientModel(
            api_key=os.getenv("OPENAI_API_KEY"),
            model_name="gpt-3.5-turbo",
        )

        code_agent = CodeAgent(tools=[search_tool], model=llm)

        print("CodeAgent initialized successfully")
        return code_agent
    except Exception as err:
        print(f"Error initializing CodeAgent: {err}")
        return None

# Final answer extraction
def extract_final_answer(text):
    """Return the text following 'FINAL ANSWER:' (case-insensitive).

    When the marker is missing, the whole input is returned with
    surrounding whitespace stripped.
    """
    found = re.search(r'FINAL ANSWER: (.*)', text, re.IGNORECASE)
    return found.group(1).strip() if found else text.strip()