|
|
import json |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
import numpy as np |
|
|
|
|
|
class RAGEngine:
    """Nearest-neighbour text retrieval over records loaded from a JSON file.

    Each record is rendered into a single text passage, embedded with a
    sentence-transformers model, and stored in a flat L2 FAISS index.
    ``retrieve`` embeds a query with the same model and returns the closest
    passages.

    Attributes:
        embedder: The ``SentenceTransformer`` model used for all embeddings.
        data: The parsed JSON document (iterated as a sequence of records).
        texts: One rendered passage per record, in the order of ``data``.
        index: The FAISS index over ``texts``, or ``None`` when there are no
            passages to index.
    """

    def __init__(self, json_path):
        """Load the records at *json_path* and build the search index.

        Args:
            json_path: Path to a JSON file whose top-level value is iterated
                as records; each record must provide the keys read by
                ``build_corpus``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a record lacks one of the required keys.
        """
        self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

        # JSON text is UTF-8; relying on the locale default breaks on
        # platforms whose default encoding is something else.
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.texts = []
        self.index = None
        self.build_corpus()
        self.build_index()

    def build_corpus(self):
        """Render every record in ``self.data`` into one passage in ``self.texts``.

        Raises:
            KeyError: If a record is missing ``product_name``, ``category``,
                ``return_policy`` or ``return_reason``.
        """
        self.texts = [
            f"Product: {item['product_name']}\nCategory: {item['category']}\nPolicy: {item['return_policy']}\nReason: {item['return_reason']}"
            for item in self.data
        ]

    def build_index(self):
        """Embed ``self.texts`` and (re)build ``self.index`` from the vectors.

        With no passages there is nothing to embed and no vector dimension
        to read, so ``self.index`` is set to ``None`` instead.
        """
        if not self.texts:
            self.index = None
            return

        embeddings = self.embedder.encode(self.texts, convert_to_numpy=True)
        # FAISS only accepts C-contiguous float32 matrices.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)

    def retrieve(self, query, top_k=3):
        """Return up to *top_k* passages closest to *query*, nearest first.

        Args:
            query: The text to search for.
            top_k: Maximum number of passages to return.

        Returns:
            A list of passages from ``self.texts``. The list is empty when
            *top_k* is not positive or when nothing has been indexed.
        """
        if top_k <= 0 or self.index is None:
            return []

        # Asking for more neighbours than stored vectors only yields -1
        # padding, so cap the request at the index size.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []

        query_emb = self.embedder.encode([query], convert_to_numpy=True)
        query_emb = np.ascontiguousarray(query_emb, dtype=np.float32)

        _distances, indices = self.index.search(query_emb, k)
        # -1 marks an unfilled result slot; skip it defensively.
        return [self.texts[idx] for idx in indices[0] if idx != -1]
|
|
|