Spaces:
Sleeping
Sleeping
Upload 11 files
Browse files- .gitattributes +3 -0
- app.py +78 -0
- data/1706.03762v7.pdf +3 -0
- data/NVIDIAAn.pdf +0 -0
- data/Usage policies _ OpenAI.pdf +3 -0
- data/recommendations-for-regulating-ai.pdf +3 -0
- ingest.py +28 -0
- requirements.txt +15 -3
- src/agent.py +68 -0
- src/eval.py +26 -0
- src/processor.py +76 -0
- src/tools.py +36 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/1706.03762v7.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/recommendations-for-regulating-ai.pdf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/Usage[[:space:]]policies[[:space:]]_[[:space:]]OpenAI.pdf filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st
import os
import logging

# Configure logging (stdout, so it shows in container/Space logs)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("APP_ENTRY")

logger.info("🚀 app.py module loaded. Streamlit starting up...")

st.set_page_config(page_title="Gemini Research Assistant", layout="wide")
st.title("💎 Agentic RAG: Gemini 2.0 Research Assistant")

# --- AUTO-INGESTION SEQUENCE ---
# This ensures the vector DB exists before the agent tries to load it.
# --- CONFIGURATION ---
DB_PATH = "./chroma_db"  # Chroma persistence dir — must match src/tools.py's persist_directory
DATA_PATH = "./data"     # source PDFs consumed by src/processor.build_index
@st.cache_resource(show_spinner=False)
def initialize_knowledge_base():
    """Build the Chroma vector DB from ./data PDFs if it does not exist yet.

    Wrapped in st.cache_resource so the check runs once per server process
    rather than on every Streamlit rerun. Re-raises on ingestion failure so
    the app does not start against a half-built index.
    """
    if not os.path.exists(DB_PATH) or not os.listdir(DB_PATH):
        logger.info("⚠️ VectorDB not found. Checking for PDF data...")
        if os.path.exists(DATA_PATH) and any(f.endswith('.pdf') for f in os.listdir(DATA_PATH)):
            logger.info("📄 Data found. Starting ingestion process...")
            # We use a placeholder to show progress since st.spinner isn't thread-safe in early startup sometimes
            status_placeholder = st.empty()
            status_placeholder.info("🧠 Initializing Knowledge Base... Check Logs for progress.")

            # Lazy import: processor pulls in heavy langchain deps we only
            # need when actually (re)building the index.
            from src.processor import build_index
            try:
                build_index(DATA_PATH, DB_PATH)
                status_placeholder.success("✅ Knowledge Base Built! Refreshing...")
                logger.info("✅ Ingestion complete.")
                status_placeholder.empty()
            except Exception as e:
                # logger.exception records the full stack trace in server logs
                # (logger.error with an f-string would drop it).
                logger.exception("❌ Ingestion FAILED")
                status_placeholder.error(f"Failed to build index: {e}")
                # Bare `raise` preserves the original traceback;
                # `raise e` would re-anchor it at this frame.
                raise
        else:
            logger.warning("No data found in 'data' directory.")
            st.warning("⚠️ No data found! Please add PDFs to the 'data' folder to use Local Research.")
    else:
        logger.info("✅ VectorDB exists. Skipping ingestion.")

# Run the initialization
initialize_knowledge_base()
# Import the agent only AFTER the DB check above, so Chroma never opens a
# store that hasn't been built yet ("Table not found").
logger.info("🤖 Loading Agent Logic...")
from src.agent import app as agent_app
logger.info("✅ Agent loaded. Ready to serve.")

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation so history survives Streamlit reruns.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Chat input
user_text = st.chat_input("Ask about internal docs or latest tech...")
if user_text:
    st.session_state.messages.append({"role": "user", "content": user_text})
    with st.chat_message("user"):
        st.markdown(user_text)

    with st.chat_message("assistant"):
        # Execute LangGraph brain; thread_id keys any checkpointed state.
        result = agent_app.invoke(
            {"messages": [("user", user_text)]},
            config={"configurable": {"thread_id": "1"}},
        )
        reply = result["messages"][-1].content

        st.markdown(reply)
        st.session_state.messages.append({"role": "assistant", "content": reply})
|
data/1706.03762v7.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bdfaa68d8984f0dc02beaca527b76f207d99b666d31d1da728ee0728182df697
|
| 3 |
+
size 2215244
|
data/NVIDIAAn.pdf
ADDED
|
Binary file (90.6 kB). View file
|
|
|
data/Usage policies _ OpenAI.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:65b3fd64e61ca4bdeac41fea6c44d8a927a3e16129b94af7118196848f0c7c6f
|
| 3 |
+
size 145434
|
data/recommendations-for-regulating-ai.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:abbefed52379ac6bd793071bc603e174b233718a4fc8ad32ac304edae1e39425
|
| 3 |
+
size 316312
|
ingest.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import shutil
from src.processor import build_index

DATA_DIR = "./data"    # source PDFs
DB_DIR = "./chroma_db" # Chroma persistence dir (same path src/tools.py opens)

def main():
    """CLI entry point: build the Chroma vector DB from PDFs in ./data.

    Prints guidance and returns early when the data directory or PDFs are
    missing; otherwise delegates to src.processor.build_index, which
    persists the store to DB_DIR as a side effect.
    """
    print(f"Checking for data in {DATA_DIR}...")
    if not os.path.exists(DATA_DIR):
        print(f"Create a '{DATA_DIR}' directory and put your PDFs there.")
        return

    if not any(f.endswith(".pdf") for f in os.listdir(DATA_DIR)):
        print("No PDF files found in data directory.")
        return

    print("Building Vector Database... (This may take a while for large docs)")

    # Optional: Clear old DB if you want a fresh start every time
    # if os.path.exists(DB_DIR):
    #     shutil.rmtree(DB_DIR)

    # build_index persists to disk; its return value is not needed here
    # (the previous unused `vectorstore =` binding was dropped).
    build_index(DATA_DIR, DB_DIR)
    print(f"Success! Database built at {DB_DIR}")

if __name__ == "__main__":
    main()
|
requirements.txt
CHANGED
|
@@ -1,3 +1,15 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langchain
|
| 2 |
+
langchain-google-genai
|
| 3 |
+
langchain-chroma
|
| 4 |
+
langchain-community
|
| 5 |
+
langchain-text-splitters
|
| 6 |
+
langgraph
|
| 7 |
+
streamlit
|
| 8 |
+
chromadb
|
| 9 |
+
pypdf
|
| 10 |
+
ftfy
|
| 11 |
+
unidecode
|
| 12 |
+
ragas
|
| 13 |
+
datasets
|
| 14 |
+
duckduckgo-search
|
| 15 |
+
python-dotenv
|
src/agent.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Annotated, List, TypedDict

from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from src.tools import tools
# The Local Knowledge Registry (Update this whenever you add new data types)
# Manual as of now
# NOTE(review): these topics do not match the PDFs actually shipped in ./data
# (Transformer paper, AI-regulation report, OpenAI usage policies, NVIDIA doc)
# — confirm whether the manifest or the data set is out of date.
LOCAL_MANIFEST = {
    "topics": ["HR Policies", "Project X Design Docs", "Q3 Financials", "Employee Handbook"],
    "date_range": "Documents updated as of Dec 2024",
    "domain": "Internal Corporate Knowledge"
}

# Routing instructions for the assistant; manifest topics are interpolated so
# the model knows what the local store can answer.
# NOTE(review): SYSTEM_PROMPT is defined here but not referenced by any node
# in this module — confirm it is injected elsewhere, or wire it into the
# model invocation.
SYSTEM_PROMPT = f"""
You are an expert Research Assistant. You have access to:
1. INTERNAL DATA: {LOCAL_MANIFEST['topics']}. (Use 'local_research_tool')
2. EXTERNAL DATA: The entire internet via duckduckgosearch. (Use 'web_search_tool')

GUIDELINES:
- Given the user's technical question and the fact that our internal documents are insufficient, generate a generic search query for the internet that does NOT include any proprietary names or internal details.
- If a query is about {LOCAL_MANIFEST['topics']}, try LOCAL first.
- If a query is TECHNICAL (e.g., PyTorch, Python APIs) or REAL-TIME, go to WEB immediately.
- If the query is ambiguous, try LOCAL first, then fallback to WEB if the results are empty or low confidence.
"""
class AgentState(TypedDict):
    """Agent state flowing through the LangGraph nodes.

    `messages` uses the `add_messages` reducer so that each node's return
    value is APPENDED to the history instead of replacing it. Without a
    reducer the channel defaults to last-value-wins, so `call_model`'s
    single-message return would wipe prior turns and break the tool loop.
    """
    messages: Annotated[list, add_messages]
    intent: str          # router's classification: "TECHNICAL" or "INTERNAL"
    is_sufficient: bool  # reserved for sufficiency grading of retrieval
# Brain: Gemini 2.0 Flash (high-speed agentic reasoning)
# temperature=0 keeps routing decisions deterministic. `llm` is used raw by
# router(); `llm_with_tools` additionally exposes the tool schemas so the
# model can emit tool calls.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
llm_with_tools = llm.bind_tools(tools)
def router(state: AgentState):
    """Classify the latest user message so retrieval can be prioritized.

    Asks the raw LLM for an intent label and collapses it to either
    "TECHNICAL" (also covers real-time queries) or "INTERNAL".
    """
    latest = state['messages'][-1].content
    verdict = llm.invoke(
        f"Categorize intent: TECHNICAL (API/Docs), INTERNAL (Proprietary), or REALTIME. Query: {latest}"
    ).content.upper()
    if "TECHNICAL" in verdict or "REALTIME" in verdict:
        return {"intent": "TECHNICAL"}
    return {"intent": "INTERNAL"}
def call_model(state: AgentState):
    """Run Gemini (with bound tools) over the accumulated message history."""
    reply = llm_with_tools.invoke(state['messages'])
    return {"messages": [reply]}
# Orchestration Graph: START -> router -> llm, then loop llm <-> tools until
# the model stops emitting tool calls.
workflow = StateGraph(AgentState)

workflow.add_node("router", router)
workflow.add_node("llm", call_model)
workflow.add_node("tools", ToolNode(tools))

workflow.add_edge(START, "router")
workflow.add_edge("router", "llm")

# Self-Correction Loop
def should_continue(state: AgentState):
    """Route to the tool node while the last AI message requests tools; end otherwise."""
    last_msg = state["messages"][-1]
    return "tools" if last_msg.tool_calls else END

workflow.add_conditional_edges("llm", should_continue)
workflow.add_edge("tools", "llm")

# Compiled graph; imported by app.py as `agent_app` and by src/eval.py.
app = workflow.compile()
src/eval.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datasets import Dataset
from ragas import evaluate
from ragas.llms import llm_factory
from ragas.metrics.collections import Faithfulness, ResponseRelevancy, ContextPrecision
from src.agent import app

# Ragas 2025: Experiment-based factory
# NOTE(review): presumably llm_factory resolves Google credentials from the
# environment like ChatGoogleGenerativeAI does — confirm before CI use.
judge_llm = llm_factory("gemini-2.0-flash")

def evaluate_agent(questions: list, references: list):
    """MNC-grade verification of RAG pipeline quality.

    Runs each question through the compiled LangGraph agent, collects the
    final answer plus tool outputs as retrieved contexts, and scores the
    batch with ragas using a Gemini judge.

    Args:
        questions: user queries to run through the agent.
        references: ground-truth answers aligned 1:1 with `questions`
            (zip silently drops extras from the longer list).

    Returns:
        The ragas evaluation result for Faithfulness, ResponseRelevancy,
        and ContextPrecision.
    """
    results = []
    for q, r in zip(questions, references):
        output = app.invoke({"messages": [("user", q)]})
        results.append({
            "user_input": q,
            "response": output["messages"][-1].content,
            # Tool outputs are the messages carrying `tool_call_id`, so this
            # filter selects the retrieved contexts from the transcript.
            "retrieved_contexts": [m.content for m in output["messages"] if hasattr(m, "tool_call_id")],
            "reference": r
        })

    dataset = Dataset.from_list(results)
    metrics = [Faithfulness(), ResponseRelevancy(), ContextPrecision()]

    # Evaluate with Gemini judge
    return evaluate(dataset=dataset, metrics=metrics, llm=judge_llm)
|
src/processor.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import unicodedata
|
| 3 |
+
import hashlib
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import ftfy
|
| 6 |
+
import unidecode
|
| 7 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 8 |
+
from langchain_chroma import Chroma
|
| 9 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 10 |
+
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
# Configure logging to show up in Docker logs
|
| 15 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 16 |
+
logger = logging.getLogger(__name__)
|
def clean_text(text: str) -> str:
    """Scrub structural and encoding noise out of extracted PDF text.

    Three passes, in order: strip page-layout artifacts, repair broken
    encodings down to normalized ASCII, then normalize whitespace.
    """
    # 1. Structural scrubbing: "Page N of M" markers, "N / M" counters,
    #    centered page-number lines, and horizontal-rule runs.
    structural_passes = (
        (r'Page\s+\d+\s+of\s+\d+', re.IGNORECASE),
        (r'\b\d+\s*/\s*\d+\b', 0),
        (r'^\s*-\s*\d+\s*-\s*$', re.MULTILINE),
        (r'[-*_]{3,}', 0),
    )
    for pattern, flags in structural_passes:
        text = re.sub(pattern, '', text, flags=flags)

    # 2. Encoding repairs: fix mojibake, transliterate to ASCII, then
    #    apply canonical Unicode normalization.
    text = unicodedata.normalize('NFKC', unidecode.unidecode(ftfy.fix_text(text)))

    # 3. Whitespace normalization: tabs/nbsp -> space, heal mid-sentence
    #    line breaks (lowercase letter on both sides), collapse runs of spaces.
    text = re.sub(r'[\t\xa0]', ' ', text)
    text = re.sub(r'(?<=[a-z])\n(?=[a-z])', ' ', text)
    return re.sub(r' +', ' ', text).strip()
| 36 |
+
|
def build_index(data_dir: str, persist_dir: str):
    """Process raw PDFs into a persisted Chroma vector store.

    Loads every PDF under `data_dir` (one Document per page via PyPDFLoader),
    cleans each page with clean_text, splits into overlapping chunks, embeds
    with Gemini embeddings, and persists the store to `persist_dir`.

    Args:
        data_dir: directory scanned recursively for "**/*.pdf".
        persist_dir: Chroma persistence directory (created/overwritten here).

    Returns:
        The populated Chroma vector store.
    """
    logger.info(f"Starting ingestion from: {data_dir}")

    loader = DirectoryLoader(data_dir, glob="**/*.pdf", loader_cls=PyPDFLoader)
    raw_docs = loader.load()
    logger.info(f"Loaded {len(raw_docs)} documents.")

    # Gemini 2025 standard embedding model
    # NOTE(review): must stay in sync with the embedding model used at query
    # time in src/tools.py, or retrieval similarity scores become meaningless.
    embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1200, chunk_overlap=150, add_start_index=True
    )

    final_chunks = []
    for i, doc in enumerate(raw_docs):
        logger.info(f"Processing doc {i+1}/{len(raw_docs)}: {doc.metadata.get('source', 'unknown')}")
        cleaned_content = clean_text(doc.page_content)
        source_name = Path(doc.metadata.get("source", "unknown")).name

        # Metadata extraction for citations; chunk_hash fingerprints the
        # cleaned page content (dedup/debugging aid, not used for lookup here).
        metadata = {
            "source": source_name,
            "page": doc.metadata.get("page", 1),
            "chunk_hash": hashlib.md5(cleaned_content.encode()).hexdigest()
        }

        chunks = splitter.create_documents([cleaned_content], metadatas=[metadata])
        final_chunks.extend(chunks)

    logger.info(f"Total chunks created: {len(final_chunks)}")
    logger.info(f"Persisting to VectorDB at: {persist_dir}")

    vectorstore = Chroma.from_documents(
        documents=final_chunks,
        embedding=embeddings,
        persist_directory=persist_dir
    )
    logger.info("VectorDB successfully built and persisted.")
    return vectorstore
|
src/tools.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 3 |
+
from langchain_chroma import Chroma
|
| 4 |
+
from langchain_community.tools.duckduckgo_search import DuckDuckGoSearchResults
|
| 5 |
+
from langchain.tools import tool
|
| 6 |
+
|
# Persistent storage setup
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
vector_db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)

@tool
def local_research_tool(query: str, search_type: Literal["similarity", "mmr"] = "similarity"):
    """
    Searches the internal corporate knowledge base.
    Use 'similarity' for exact facts and 'mmr' for broad, diverse research.
    """
    # fetch_k / lambda_mult are MMR-only knobs; passing them to a plain
    # similarity search is at best ignored and can error on some versions,
    # so only include them when MMR is requested.
    search_kwargs = {"k": 5}
    if search_type == "mmr":
        search_kwargs.update({"fetch_k": 20, "lambda_mult": 0.5})

    retriever = vector_db.as_retriever(
        search_type=search_type,
        search_kwargs=search_kwargs
    )
    docs = retriever.invoke(query)

    # Explicit signal for the agent's "fallback to WEB if results are empty"
    # guideline — an empty string is easy for the model to misread.
    if not docs:
        return "No relevant documents found in the internal knowledge base."

    # Formatted for model synthesis with citations. Metadata keys are written
    # by src/processor.build_index; .get() guards against chunks ingested
    # through another path.
    formatted = [
        f"SOURCE: {d.metadata.get('source', 'unknown')} (Pg. {d.metadata.get('page', '?')})\nCONTENT: {d.page_content}"
        for d in docs
    ]
    return "\n---\n".join(formatted)
# Gemini-optimized web search fallback
# NOTE(review): DuckDuckGoSearchResults documents `num_results` (not `k`) as
# its result-count field — confirm `k=3` is actually honored and not silently
# dropped by the tool's pydantic model.
web_search_tool = DuckDuckGoSearchResults(
    k=3,
    description="Search the internet for real-time data, technical APIs (like PyTorch/LangChain), or news."
)

# Exported tool registry; bound to the LLM and the ToolNode in src/agent.py.
tools = [local_research_tool, web_search_tool]