# alaselababatunde's picture
# Updated
# 2a14080
import pandas as pd
import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
import os
import logging
# Module-level logger for this file. basicConfig attaches a default
# INFO-level handler to the root logger (the stdlib makes this a no-op when
# the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class RagEngine:
    """Retrieval helper backed by a persistent ChromaDB collection of FAQ rows.

    On construction the collection is opened (or created); if it holds no
    documents, the FAQ CSV at ``data_path`` is ingested into it.
    """

    def __init__(
        self,
        data_path="data/tesco_faq.csv",
        collection_name="tesco_faq",
        persist_path="./chroma_db",
        model_name="all-MiniLM-L6-v2",
    ):
        """Open (or create) the collection and ingest the CSV if it is empty.

        Args:
            data_path: CSV file ingested when the collection has no documents.
            collection_name: Name of the ChromaDB collection to use.
            persist_path: Directory used by the ChromaDB persistent client.
            model_name: sentence-transformers model used to embed documents
                and queries.
        """
        self.data_path = data_path
        self.collection_name = collection_name
        self.client = chromadb.PersistentClient(path=persist_path)
        # Embedding function backed by a sentence-transformers model.
        self.sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=model_name
        )
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.sentence_transformer_ef,
        )
        # Only ingest into an empty collection; an already-populated persistent
        # collection is reused as-is, so later edits to the CSV are not picked up.
        if self.collection.count() == 0:
            self.ingest_data()

    @staticmethod
    def _clean(value):
        """Return *value* as a string, mapping missing cells (NaN/None) to ''."""
        return "" if pd.isna(value) else str(value)

    @classmethod
    def _build_records(cls, df):
        """Convert FAQ rows into parallel ``(documents, metadatas, ids)`` lists.

        Each document concatenates Topic, Subtopic, Question and Answer so the
        embedding sees the full context. Missing columns and empty cells become
        '' instead of the literal string 'nan'. Ids are ``faq_<row index>``.
        """
        documents, metadatas, ids = [], [], []
        # to_dict("records") keeps each column's own dtype (iterrows would
        # upcast a whole row to a common dtype).
        for index, record in zip(df.index, df.to_dict("records")):
            topic = cls._clean(record.get("Topic", ""))
            subtopic = cls._clean(record.get("Subtopic", ""))
            question = cls._clean(record.get("Question", ""))
            answer = cls._clean(record.get("Answer", ""))
            documents.append(
                f"Topic: {topic}\nSubtopic: {subtopic}\nQuestion: {question}\nAnswer: {answer}"
            )
            metadatas.append({
                "topic": topic,
                "subtopic": subtopic,
                "question": question,
            })
            ids.append(f"faq_{index}")
        return documents, metadatas, ids

    def ingest_data(self, batch_size=1000):
        """Read the FAQ CSV and add every row to the collection.

        Rows are added in chunks of at most ``batch_size`` so large files stay
        under ChromaDB's per-call batch limit.

        Args:
            batch_size: Maximum number of documents per ``collection.add`` call.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
            Exception: Any error from reading the CSV or writing to the
                collection is logged with its traceback and re-raised.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        logger.info("Ingesting data from %s...", self.data_path)
        try:
            df = pd.read_csv(self.data_path)
            documents, metadatas, ids = self._build_records(df)
            if not documents:
                # ChromaDB rejects an add() with empty inputs; an empty CSV
                # is reported rather than treated as an error.
                logger.warning("No rows found in %s; nothing ingested.", self.data_path)
                return
            for start in range(0, len(ids), batch_size):
                stop = start + batch_size
                self.collection.add(
                    documents=documents[start:stop],
                    metadatas=metadatas[start:stop],
                    ids=ids[start:stop],
                )
            logger.info("Successfully ingested %d documents.", len(documents))
        except Exception as e:
            # Keep the traceback in the log, then propagate to the caller.
            logger.exception("Error ingesting data: %s", e)
            raise

    def retrieve(self, query, n_results=3):
        """Return the ``n_results`` stored documents most similar to ``query``.

        Args:
            query: Free-text question to embed and search for.
            n_results: Maximum number of matches to return.

        Returns:
            The unmodified result of ``collection.query`` for the single
            query text.
        """
        return self.collection.query(
            query_texts=[query],
            n_results=n_results,
        )
if __name__ == "__main__":
    # Manual smoke test: build (or reuse) the index, run one sample query,
    # and dump the raw query result.
    engine = RagEngine()
    print(engine.retrieve("Where do you deliver?"))