# adasRAGassist / retriever.py
# (originally published by gk2410 — revision 7610555, "Update retriever.py")
# retriever.py
import json
import networkx as nx
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.schema import Document
# 1. Load ADAS logs from JSONL
def load_logs(path="adas_logs.jsonl"):
    """Load ADAS event logs from a JSONL file into LangChain Documents.

    Each line is expected to be a JSON object with a "context" field
    (free text) and an "event" field (event label).

    Args:
        path: Path to the JSONL log file.

    Returns:
        list[Document]: one Document per log line, with the context as
        page_content and the event label stored under metadata["event"].

    Raises:
        KeyError: if a line lacks "context" or "event".
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Explicit UTF-8 so behavior doesn't depend on the platform locale;
    # skip blank lines so a trailing newline doesn't crash json.loads.
    with open(path, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    return [
        Document(page_content=rec["context"], metadata={"event": rec["event"]})
        for rec in records
    ]
# 2. Build FAISS vectorstore
def get_vectorstore(documents):
    """Split *documents* into chunks and index them in a FAISS store.

    Args:
        documents: Iterable of LangChain Documents to index.

    Returns:
        FAISS: vector store over 200-char chunks (20-char overlap),
        embedded with the all-MiniLM-L6-v2 sentence transformer.
    """
    text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=20)
    doc_chunks = text_splitter.split_documents(documents)
    embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    store = FAISS.from_documents(doc_chunks, embedder)
    return store
# 3. Build simple graph from documents
def build_graph(documents):
    """Build an undirected event/context graph from the documents.

    Each document contributes one "event" node (from its metadata,
    defaulting to "unknown_event") connected to one "context" node per
    ". "-separated sentence of its page content.

    Args:
        documents: Iterable of LangChain Documents.

    Returns:
        nx.Graph: graph whose nodes carry a type attribute of
        "event" or "context".
    """
    graph = nx.Graph()
    for doc in documents:
        event_label = doc.metadata.get("event", "unknown_event")
        # Only tag the node as an event the first time we see it.
        if not graph.has_node(event_label):
            graph.add_node(event_label, type="event")
        # One context node per sentence, linked back to its event.
        for raw_sentence in doc.page_content.split(". "):
            sentence = raw_sentence.strip()
            if not sentence:
                continue
            graph.add_node(sentence, type="context")
            graph.add_edge(event_label, sentence)
    return graph
# 4. Graph search: retrieve related context by keyword match
def graph_search(graph, query, top_k=3):
    """Return up to *top_k* node labels containing *query* (case-insensitive).

    Args:
        graph: Graph whose node labels are strings (see build_graph).
        query: Substring to look for in node labels.
        top_k: Maximum number of matches to return.

    Returns:
        list[str]: matching node labels, in graph iteration order.
    """
    needle = query.lower()
    hits = [
        label
        for label, _attrs in graph.nodes(data=True)
        if needle in label.lower()
    ]
    return hits[:top_k]
# 5. Hybrid retrieval from FAISS + Graph
def hybrid_retrieve(vectorstore, graph, query):
    """Combine FAISS semantic hits with graph keyword hits for *query*.

    Args:
        vectorstore: FAISS store built by get_vectorstore.
        graph: Graph built by build_graph.
        query: User query string.

    Returns:
        str: top-3 FAISS chunk texts joined by newlines, then a newline,
        then the graph_search matches joined by newlines.
    """
    # Semantic retrieval from the vector index.
    semantic_hits = vectorstore.similarity_search(query, k=3)
    semantic_context = "\n".join(doc.page_content for doc in semantic_hits)
    # Lexical retrieval from the graph.
    keyword_context = "\n".join(graph_search(graph, query))
    # Concatenate both sources into a single context string.
    return semantic_context + "\n" + keyword_context