returnpoilcytest / rag_utils.py
mali08890's picture
Upload 5 files
f3c8548 verified
import json
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
class RAGEngine:
    """Minimal retrieval engine over a JSON file of return-policy records.

    On construction the engine loads the records, renders each one into a
    text passage, embeds all passages with the
    ``sentence-transformers/all-MiniLM-L6-v2`` model and stores the vectors
    in a FAISS ``IndexFlatL2``. ``retrieve`` then returns the passages whose
    embeddings are closest to a query.

    Attributes set by ``__init__``:
        embedder: the ``SentenceTransformer`` used for passages and queries.
        data: the object parsed from the JSON file (iterated as a sequence
            of records by ``build_corpus``).
        texts: the rendered passages, in the same order as ``data``.
        index: the FAISS index over ``texts``, or ``None`` when the corpus
            is empty.
    """

    def __init__(self, json_path):
        """Load records from *json_path* and build the search index.

        Args:
            json_path: path to a JSON file holding the records.
        """
        self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default locale encoding when reading it.
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.texts = []
        self.index = None
        self.build_corpus()
        self.build_index()

    def build_corpus(self):
        """Render every record in ``self.data`` into one text passage.

        Several fields are combined into a single passage so the embedding
        carries more context than any single field would.

        Raises:
            KeyError: if a record lacks ``product_name``, ``category``,
                ``return_policy`` or ``return_reason``.
        """
        self.texts = [
            f"Product: {item['product_name']}\n"
            f"Category: {item['category']}\n"
            f"Policy: {item['return_policy']}\n"
            f"Reason: {item['return_reason']}"
            for item in self.data
        ]

    def build_index(self):
        """Embed ``self.texts`` and (re)build ``self.index`` from them.

        With an empty corpus there is nothing to embed and no vector
        dimension to read, so ``self.index`` is set to ``None`` instead.
        """
        if not self.texts:
            self.index = None
            return

        embeddings = self.embedder.encode(self.texts, convert_to_numpy=True)
        # FAISS only accepts C-contiguous float32 matrices.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(embeddings)

    def retrieve(self, query, top_k=3):
        """Return up to *top_k* passages nearest to *query*.

        Args:
            query: the text to search for.
            top_k: maximum number of passages to return.

        Returns:
            A list of passages from ``self.texts`` in the order reported by
            the index search. The list is empty when *top_k* is not
            positive or when the corpus is empty.
        """
        if top_k <= 0 or not self.texts or self.index is None:
            return []

        # Never ask for more neighbours than there are passages.
        k = min(top_k, len(self.texts))
        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        query_emb = np.ascontiguousarray(query_emb, dtype=np.float32)
        _distances, indices = self.index.search(query_emb, k)
        # -1 marks a slot for which the index had no neighbour to report.
        return [self.texts[idx] for idx in indices[0] if idx != -1]