Hugging Face Spaces status: runtime error.
import os
import re

from datasets import load_dataset
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.tools import DuckDuckGoSearchRun
from langchain.vectorstores import FAISS
from smolagents import CodeAgent, DuckDuckGoSearchTool, InferenceClientModel
# System prompt for formatting answers: instructs the model to end every
# response with "FINAL ANSWER: ..." in a strict, GAIA-style answer format
# (plain numbers, unabbreviated strings, or comma-separated lists).
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. | |
"""
# Initialize web search tool (LangChain wrapper around DuckDuckGo).
# NOTE(review): `search_tool` is reassigned to smolagents' DuckDuckGoSearchTool
# further down in this file, so this binding is short-lived.
search_tool = DuckDuckGoSearchRun()
# Create custom prompt template with system instructions; `{context}` and
# `{question}` are filled in by the RetrievalQA "stuff" chain at query time.
prompt_template = SYSTEM_PROMPT + "\n\nContext: {context}\nQuestion: {question}\n"
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)
# Load the GAIA dataset and build a retrieval-augmented QA chain over it.
def load_gaia_and_setup_rag():
    """Build a RetrievalQA chain over the GAIA dataset.

    Returns:
        The RetrievalQA chain on success, or None if any step fails
        (e.g. missing Hub/OpenAI credentials or dataset access).
    """
    try:
        # GAIA is a gated dataset on the Hub; requires HUGGINGFACE_HUB_TOKEN.
        # Fix: the bare "GAIA" id used previously does not resolve on the Hub —
        # the canonical repository id is "gaia-benchmark/GAIA".
        # NOTE(review): confirm the intended split; GAIA ships validation/test.
        dataset = load_dataset("gaia-benchmark/GAIA", split="train")
        texts = [item["text"] for item in dataset if "text" in item]
        # Embed the texts and index them in an in-memory FAISS store.
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_texts(texts, embeddings)
        retriever = vectorstore.as_retriever()
        # Fix: the original referenced an undefined `SmoalAgent` (a NameError
        # silently swallowed by the except below). Use the OpenAI LLM wrapper,
        # consistent with the second setup path later in this file.
        qa_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=0),
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": PROMPT}
        )
        return qa_chain
    except Exception as e:
        print(f"RAG initialization error: {str(e)}")
        return None
# Extract final answer from model response
def extract_final_answer(response):
    """Pull the answer out of a "FINAL ANSWER: ..." templated response.

    Matching is case-insensitive. When the template is absent, the
    response is returned unmodified as a fallback.
    """
    found = re.search(r"FINAL ANSWER: (.*)", response, re.IGNORECASE)
    return found.group(1).strip() if found else response
# Module-level handle for the RAG chain; populated by load_gaia_and_setup_rag()
# after its (second) definition below.
# Fixes: the original used a no-op module-level `global` statement, and called
# the first initializer only to overwrite its result with None immediately —
# both removed.
rag_chain = None

# Initialize the web search tool the agent will use (smolagents' tool,
# replacing the earlier LangChain DuckDuckGoSearchRun binding).
search_tool = DuckDuckGoSearchTool()
def load_gaia_and_setup_rag():
    """Load the GAIA benchmark and build the module-level RAG chain.

    On success, binds the global ``rag_chain`` to a RetrievalQA chain over
    the dataset's contexts and returns True. On any failure the error is
    printed, ``rag_chain`` is left untouched, and False is returned.
    """
    try:
        from datasets import load_dataset

        # GAIA is gated on the Hub; this requires valid HF credentials.
        # NOTE(review): the upstream repository id is cased
        # "gaia-benchmark/GAIA" — confirm the lowercase id resolves.
        dataset = load_dataset("gaia-benchmark/gaia", split="test")
        # Keep only rows that carry a non-empty context for indexing.
        contexts = [item["context"] for item in dataset if "context" in item and item["context"]]
        # Embed the contexts and index them in an in-memory FAISS store.
        embeddings = OpenAIEmbeddings()
        vector_store = FAISS.from_texts(contexts, embeddings)
        # Retrieve the 3 most similar contexts per query.
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        # Strict answer-format instructions for the QA LLM.
        SYSTEM_PROMPT = """
You are a precise QA system. Answer ONLY with the exact answer, no explanations.
Answers must be in one of these formats:
- A single number
- A single string
- A comma-separated list of numbers or strings
Do not include any additional text, explanations, or formatting.
"""
        prompt_template = PromptTemplate(
            template=SYSTEM_PROMPT + "\nContext: {context}\nQuestion: {question}\nAnswer:",
            input_variables=["context", "question"]
        )
        # Fix: `OpenAI` was previously an undefined name here — the resulting
        # NameError was swallowed by the broad except, so this function always
        # returned False. It is now imported at the top of the file.
        global rag_chain
        rag_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=0),
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt_template}
        )
        print(f"Successfully loaded GAIA dataset and created RAG chain with {len(contexts)} contexts")
        return True
    except Exception as e:
        print(f"Error setting up RAG: {e}")
        return False
# Initialize RAG when the module is loaded (sets the module-level rag_chain
# on success; prints an error and leaves it as-is on failure).
load_gaia_and_setup_rag()
# Build the smolagents CodeAgent that drives tool-using question answering.
def initialize_code_agent():
    """Create a CodeAgent wired to the module-level search tool.

    Returns:
        The CodeAgent on success, or None if construction fails.
    """
    try:
        # Fix: smolagents' InferenceClientModel takes `model_id` and `token`;
        # the previous `model_name`/`api_key` kwargs are not part of its
        # signature. NOTE(review): InferenceClientModel talks to HF Inference
        # providers, so "gpt-3.5-turbo" and an OpenAI API key may not be valid
        # here — confirm the intended provider, model id, and credential.
        model = InferenceClientModel(
            model_id="gpt-3.5-turbo",
            token=os.getenv("OPENAI_API_KEY")
        )
        # Create agent with the web search tool.
        agent = CodeAgent(
            tools=[search_tool],
            model=model
        )
        print("CodeAgent initialized successfully")
        return agent
    except Exception as e:
        print(f"Error initializing CodeAgent: {e}")
        return None
# Final answer extraction
def extract_final_answer(text):
    """Extract the answer from a "FINAL ANSWER: <answer>" response.

    Matching is case-insensitive and (fix) tolerates missing or extra
    whitespace after the colon — the previous pattern required exactly one
    space. If the template is absent, the whole text is returned stripped
    of surrounding whitespace.
    """
    match = re.search(r'FINAL ANSWER:\s*(.*)', text, re.IGNORECASE)
    if match:
        return match.group(1).strip()
    # If no pattern found, return the text as is (with cleanup)
    return text.strip()