{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0c1f6a2e",
   "metadata": {},
   "source": [
    "# GAIA Agent: Hybrid Retrieval Setup\n",
    "\n",
    "Sets up the GAIA LangGraph agent: configuration and credentials, the BM25 / vector retrievers, the reranker and LLM, and the hybrid-search retriever node.\n",
    "\n",
    "**Environment variables** (a `.env` file is loaded with `load_dotenv()`):\n",
    "- `HF_INFERENCE_KEY`: Hugging Face token passed to the LLM endpoint.\n",
    "- `SUPABASE_URL`, `SUPABASE_SERVICE_KEY`: needed only when vector search is enabled in the config; if either is missing, vector search is disabled with a warning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b9902eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "GAIA Agent with Multi-Modal File Processing and Hybrid Retrieval.\n",
    "\n",
    "This module defines a LangGraph agent that can:\n",
    "1. Retrieve similar questions using Hybrid Search (Vector + BM25) and Reranking\n",
    "2. Process files using tools (PDF, XLSX, MP3, etc.)\n",
    "3. Answer questions using web search, calculator, and other tools\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "import bm25s\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from langgraph.graph import START, END, StateGraph\n",
    "from langgraph.prebuilt import tools_condition, ToolNode\n",
    "\n",
    "from sentence_transformers import SentenceTransformer, CrossEncoder\n",
    "from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint\n",
    "from langchain_core.messages import HumanMessage\n",
    "from supabase.client import Client, create_client\n",
    "\n",
    "from utils import load_config, load_prompt, init_bm25_index, reciprocal_rank_fusion\n",
    "from tools import tools_list\n",
    "from states import AgentState"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d7e91b4",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Load `.env` and the project config, read credentials, and decide which retrievers are active."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f2a4c60",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv()\n",
    "config = load_config()\n",
    "\n",
    "# Environment details and others\n",
    "hf_key = os.getenv(\"HF_INFERENCE_KEY\")\n",
    "supabase_url = os.getenv(\"SUPABASE_URL\")\n",
    "supabase_key = os.getenv(\"SUPABASE_SERVICE_KEY\")\n",
    "\n",
    "enable_keyword_search = config[\"retrievers\"][\"enable_keyword_search\"]\n",
    "enable_vector_search = config[\"retrievers\"][\"enable_vector_search\"]\n",
    "\n",
    "# Vector search needs Supabase credentials; without them `create_client` raises\n",
    "# SupabaseException (e.g. \"supabase_url is required\"). Degrade to keyword-only\n",
    "# retrieval with a visible warning instead of failing the whole setup.\n",
    "if enable_vector_search and not (supabase_url and supabase_key):\n",
    "    warnings.warn(\n",
    "        \"Vector search is enabled in the config but SUPABASE_URL and/or \"\n",
    "        \"SUPABASE_SERVICE_KEY is not set; disabling vector search.\"\n",
    "    )\n",
    "    enable_vector_search = False\n",
    "\n",
    "# Maximum number of task ids returned by `retriever_node`.\n",
    "MAX_RETRIEVED_DOCS = 20"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6e05d2f",
   "metadata": {},
   "source": [
    "## Retrievers and models\n",
    "\n",
    "BM25 index (keyword search), embeddings + Supabase client (vector search), the cross-encoder reranker, and the tool-calling LLM. The LLM model name comes from the config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41d8a97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BM25 Retriever (keyword search)\n",
    "bm25_retriever, bm25_corpus, bm25_ids = None, None, None\n",
    "if enable_keyword_search:\n",
    "    bm25_retriever, bm25_corpus, bm25_ids = init_bm25_index(corpus_file=config[\"data\"])\n",
    "\n",
    "embeddings, supabase = None, None\n",
    "if enable_vector_search:\n",
    "    # Embeddings for Vector Search\n",
    "    embeddings = SentenceTransformer(\n",
    "        model_name_or_path=config[\"models\"][\"embeddings\"][\"model_name\"],\n",
    "        cache_folder=config[\"models\"][\"cache_folder\"],\n",
    "    )\n",
    "\n",
    "    # Supabase Vector Store\n",
    "    supabase: Client = create_client(supabase_url, supabase_key)\n",
    "\n",
    "# Reranker Model (ModernBERT Cross-Encoder)\n",
    "reranker = CrossEncoder(\n",
    "    config[\"models\"][\"reranker\"][\"model_name\"],\n",
    "    cache_folder=config[\"models\"][\"cache_folder\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2a7f315",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM for Agent\n",
    "llm = HuggingFaceEndpoint(\n",
    "    repo_id=config[\"models\"][\"llm\"][\"model_name\"],\n",
    "    temperature=config[\"models\"][\"llm\"][\"parameters\"][\"temperature\"],\n",
    "    repetition_penalty=config[\"models\"][\"llm\"][\"parameters\"][\"repetition_penalty\"],\n",
    "    provider=config[\"models\"][\"llm\"][\"parameters\"][\"provider\"],\n",
    "    huggingfacehub_api_token=hf_key\n",
    ")\n",
    "\n",
    "agent_llm = ChatHuggingFace(llm=llm)\n",
    "agent_with_tools = agent_llm.bind_tools(tools_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9b3c2d8",
   "metadata": {},
   "source": [
    "## Graph nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a6d4e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retriever_node(state: AgentState) -> AgentState:\n",
    "    \"\"\"\n",
    "    Hybrid Search Node: retrieve candidates via Vector Search + BM25, combine with RRF.\n",
    "\n",
    "    The query is the content of the first message in `state[\"messages\"]`.\n",
    "\n",
    "    Returns:\n",
    "        A partial state update `{\"retrieved_docs\": [...]}` with at most\n",
    "        `MAX_RETRIEVED_DOCS` task ids (ids only, not document contents).\n",
    "        When both searches return results, the ids are ordered by\n",
    "        `reciprocal_rank_fusion`; otherwise vector results come first,\n",
    "        followed by BM25 results. The list is empty when there are no\n",
    "        messages or no retrieval method is enabled.\n",
    "\n",
    "    Errors from either search are printed and that search contributes no\n",
    "    results. NOTE: the `reranker` loaded during setup is not applied here.\n",
    "    \"\"\"\n",
    "    print(\"--- RETRIEVER NODE ---\")\n",
    "    messages = state.get(\"messages\", [])\n",
    "    if not messages:\n",
    "        return {\"retrieved_docs\": []}\n",
    "\n",
    "    question_content = messages[0].content\n",
    "\n",
    "    if not enable_vector_search and not enable_keyword_search:\n",
    "        print(\"No retrieval method enabled.\")\n",
    "        return {\"retrieved_docs\": []}\n",
    "\n",
    "    # 1. Vector Search\n",
    "    vector_docs = []\n",
    "    if supabase is not None and embeddings is not None:\n",
    "        vector_cfg = config[\"retrievers\"][\"vector_store\"]\n",
    "        try:\n",
    "            response = supabase.rpc(\n",
    "                vector_cfg[\"query\"],\n",
    "                {\n",
    "                    \"query_embedding\": embeddings.encode(question_content).tolist(),\n",
    "                    \"match_count\": vector_cfg[\"k\"],\n",
    "                    \"match_threshold\": vector_cfg[\"threshold\"],\n",
    "                },\n",
    "            ).execute()\n",
    "            # Guard against a missing payload so the concatenation below stays valid.\n",
    "            vector_docs = response.data or []\n",
    "        except Exception as e:\n",
    "            print(f\"Vector search error: {e}\")\n",
    "\n",
    "    # 2. BM25 Search\n",
    "    bm25_docs = []\n",
    "    # Explicit `is not None` checks: truthiness is ambiguous (raises) for numpy arrays.\n",
    "    if bm25_retriever is not None and bm25_corpus is not None and bm25_ids is not None:\n",
    "        try:\n",
    "            # bm25s rejects k larger than the number of indexed documents, so clamp it\n",
    "            # (assumes `bm25_corpus` is the indexed corpus).\n",
    "            k = min(config[\"retrievers\"][\"bm25\"][\"k\"], len(bm25_corpus))\n",
    "            if k > 0:\n",
    "                query_tokens = bm25s.tokenize([question_content], stopwords=\"en\")\n",
    "                results, scores = bm25_retriever.retrieve(query_tokens, k=k)\n",
    "                for idx, score in zip(results[0], scores[0]):\n",
    "                    bm25_docs.append({\n",
    "                        \"content\": bm25_corpus[idx],\n",
    "                        # float(): keep a plain Python number (not a numpy scalar) in the graph state.\n",
    "                        \"metadata\": {\"source\": \"bm25_search\", \"task_id\": bm25_ids[idx], \"score\": float(score)},\n",
    "                    })\n",
    "        except Exception as e:\n",
    "            print(f\"BM25 search error: {e}\")\n",
    "\n",
    "    # 3. RRF Fusion\n",
    "    if vector_docs and bm25_docs:\n",
    "        fused = reciprocal_rank_fusion([vector_docs, bm25_docs])\n",
    "        # `reciprocal_rank_fusion` yields 3-tuples whose first element is the id; keep only the ids.\n",
    "        final_candidates = [doc_id for doc_id, _doc, _score in fused]\n",
    "    else:\n",
    "        final_candidates = [doc[\"metadata\"][\"task_id\"] for doc in vector_docs + bm25_docs]\n",
    "\n",
    "    return {\"retrieved_docs\": final_candidates[:MAX_RETRIEVED_DOCS]}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gaia",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}