chatbot / utils /retriever.py
raksa-the-wildcats
Add complete accessibility chatbot with knowledge base
39d67a2
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class KnowledgeRetriever:
    """Semantic search over a precomputed JSON knowledge base.

    The knowledge base file must be a JSON object with two keys:

    * ``content``: a list of items; row ``i`` of ``embeddings`` is treated
      as the vector for ``content[i]``.
    * ``embeddings``: a list of embedding vectors, loaded into a NumPy array.

    Queries are embedded with the ``all-MiniLM-L6-v2`` sentence-transformer
    model and ranked against the stored vectors by cosine similarity.
    NOTE(review): the stored embeddings are presumably produced by the same
    model; this class does not verify that.
    """

    def __init__(self, knowledge_base_path="knowledge_base.json"):
        """Load the embedding model and the knowledge base.

        Args:
            knowledge_base_path: Path to the knowledge base JSON file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the ``content`` or ``embeddings`` key is missing.
        """
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')

        # JSON is UTF-8 by specification; do not depend on the locale default,
        # which breaks on non-ASCII text on some platforms.
        with open(knowledge_base_path, 'r', encoding='utf-8') as f:
            self.kb = json.load(f)

        self.content = self.kb['content']
        self.embeddings = np.array(self.kb['embeddings'])

    def retrieve_relevant_content(self, query, top_k=5, min_similarity=0.3):
        """Return the knowledge-base items most similar to ``query``.

        Args:
            query: The text to search for.
            top_k: Maximum number of items to return. Values ``<= 0`` yield
                an empty list.
            min_similarity: Items whose cosine similarity to the query is
                below this threshold are dropped.

        Returns:
            A list of at most ``top_k`` shallow copies of content items,
            ordered from most to least similar, each with an added
            ``similarity_score`` float. Empty when nothing qualifies or the
            knowledge base has no embeddings.
        """
        # Guard before slicing: ``[-0:]`` would select *every* row, and a
        # negative ``top_k`` would slice from the wrong end.
        if top_k <= 0:
            return []

        # An empty knowledge base produces a 1-D empty array that
        # ``cosine_similarity`` rejects; there is simply nothing to return.
        if self.embeddings.size == 0:
            return []

        query_embedding = self.embedder.encode([query])
        similarities = cosine_similarity(query_embedding, self.embeddings)[0]

        # argsort is ascending: take the last ``top_k`` and reverse them so
        # the best match comes first.
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        relevant_content = []
        for idx in top_indices:
            if similarities[idx] >= min_similarity:
                # Copy so the annotation does not leak into the stored item.
                content_item = self.content[idx].copy()
                content_item['similarity_score'] = float(similarities[idx])
                relevant_content.append(content_item)
        return relevant_content

    def format_context_for_llm(self, relevant_content):
        """Render retrieved items as a text block for an LLM prompt.

        Args:
            relevant_content: Items as returned by
                ``retrieve_relevant_content``. Each item must provide the
                ``source_file``, ``page_number`` and ``text`` keys.

        Returns:
            A string with a header line followed by one numbered
            ``[Source N]`` section per item, or a fixed "nothing found"
            message when ``relevant_content`` is empty.
        """
        if not relevant_content:
            return "No relevant information found in WebAIM resources."

        parts = ["Relevant information from WebAIM resources:\n\n"]
        for i, item in enumerate(relevant_content, 1):
            parts.append(
                f"[Source {i}] From {item['source_file']} "
                f"(Page {item['page_number']}):\n"
            )
            parts.append(f"{item['text']}\n\n")
        return "".join(parts)