Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| from langchain.tools import DuckDuckGoSearchRun | |
| from langchain.chains import RetrievalQA | |
| from langchain.embeddings import OpenAIEmbeddings | |
| from langchain.vectorstores import FAISS | |
| from langchain.prompts import PromptTemplate | |
| from datasets import load_dataset | |
| from agent import SmoalAgent | |
# Answer-format instructions placed at the top of every QA prompt. The model is
# told to reason and then finish with "FINAL ANSWER: ...", the marker that
# extract_final_answer searches for.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# DuckDuckGo web-search tool instance.
# NOTE(review): not used by the RAG chain built in this file; confirm whether
# other code relies on it before removing.
search_tool = DuckDuckGoSearchRun()

# Complete prompt text: the system instructions, followed by the retrieved
# context and the user's question. {context} and {question} stay as literal
# placeholders for PromptTemplate to fill in.
prompt_template = SYSTEM_PROMPT + (
    "\n\nContext: {context}\n"
    "Question: {question}\n"
)
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)
# Load GAIA dataset and setup RAG components
def load_gaia_and_setup_rag(
    *,
    dataset_name="GAIA",
    config_name=None,
    split="train",
    text_field="text",
    llm=None,
):
    """Build a RetrievalQA chain over the text records of a Hugging Face dataset.

    Steps:

    1. Load ``dataset_name`` at ``split`` with ``load_dataset``, using the
       ``config_name`` configuration when one is given.
    2. Collect the non-empty ``text_field`` value of each record.
    3. Embed those texts with ``OpenAIEmbeddings`` into a FAISS vector store.
    4. Connect the store's retriever to a ``"stuff"`` ``RetrievalQA`` chain
       that formats requests with the module-level ``PROMPT``.

    Initialization is best-effort. Any exception is printed to stdout together
    with its traceback, and ``None`` is returned instead of raising. Callers
    must check for ``None`` before using the result.

    All parameters are keyword-only. Their defaults are the values this
    function previously hard-coded, so the zero-argument call loads the same
    dataset, split and field as before.

    Args:
        dataset_name: Dataset id passed to ``load_dataset``.
        config_name: Optional dataset configuration (``load_dataset``'s ``name``).
        split: Dataset split to load.
        text_field: Record key whose value is indexed as a document.
        llm: Language model for the QA chain. A new ``SmoalAgent()`` is created
            when this is None.

    Returns:
        The configured ``RetrievalQA`` chain, or ``None`` if any step failed or
        no texts were found to index.

    NOTE(review): the GAIA benchmark is published on the Hub as
    ``gaia-benchmark/GAIA``. It is gated, has configurations such as
    ``2023_all`` and ``validation``/``test`` splits, and stores the question
    text under ``Question``. The defaults therefore probably load nothing;
    confirm and pass the right values.
    NOTE(review): assumes ``SmoalAgent`` is a LangChain-compatible language
    model, as ``RetrievalQA`` requires. Confirm.
    """
    try:
        # Load the dataset (gated Hub datasets require HUGGINGFACE_HUB_TOKEN).
        dataset = load_dataset(dataset_name, name=config_name, split=split)
        # Skip records whose text field is missing or empty: they add nothing
        # to retrieval.
        texts = [item[text_field] for item in dataset if item.get(text_field)]
        if not texts:
            # FAISS.from_texts needs at least one document to build an index.
            # Report the real cause here instead of an opaque error from FAISS.
            print(
                f"RAG initialization error: no non-empty '{text_field}' values "
                f"found in dataset {dataset_name!r} (split={split!r})"
            )
            return None

        # Create embeddings and an in-memory vector store over the texts.
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_texts(texts, embeddings)
        retriever = vectorstore.as_retriever()

        # The "stuff" chain type inserts the retrieved documents into PROMPT's
        # {context} variable.
        qa_chain = RetrievalQA.from_chain_type(
            llm=SmoalAgent() if llm is None else llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": PROMPT},
        )
        return qa_chain
    except Exception as e:
        # Best-effort boundary: callers treat None as "RAG unavailable". Still
        # print the traceback so the root cause is not lost.
        import traceback

        print(f"RAG initialization error: {str(e)}")
        traceback.print_exc()
        return None
# Extract final answer from model response
# Compiled once. The pattern is case-insensitive and accepts any whitespace,
# including a line break, around the colon. The answer is the rest of the line
# where the answer text starts.
_FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER\s*:\s*(.*)", re.IGNORECASE)


def extract_final_answer(response):
    """Extract the answer that follows the last "FINAL ANSWER:" marker.

    SYSTEM_PROMPT asks the model to finish with
    ``FINAL ANSWER: [YOUR FINAL ANSWER]``. The *last* marker is used because
    the model may echo that template, or state a provisional answer, before
    giving its final one.

    Args:
        response: Raw model output text, or None.

    Returns:
        The stripped text after the last marker. If no marker is found, the
        unchanged ``response`` is returned. If ``response`` is None, ``""`` is
        returned.
    """
    if response is None:
        return ""
    matches = _FINAL_ANSWER_RE.findall(response)
    if matches:
        return matches[-1].strip()
    # Fallback to return full response if pattern not found
    return response
# Initialize the RAG chain once, at import time. A module-level assignment
# already binds a global, so no `global` statement is needed here.
# load_gaia_and_setup_rag() returns None when initialization fails, so code
# using rag_chain must check for None first.
rag_chain = load_gaia_and_setup_rag()