Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| from sentence_transformers import SentenceTransformer | |
| import os | |
| import logging | |
# Configure root logging at import time so INFO-level messages are emitted
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")
# Logger named after this module.
logger = logging.getLogger(name=__name__)
class RagEngine:
    """Semantic search over an FAQ CSV stored in a persistent Chroma collection.

    On construction the collection is opened (or created). If it holds no
    documents, it is populated from ``data_path``. ``retrieve`` then queries
    the collection by text similarity.
    """

    # CSV columns combined into each document; a missing column contributes "".
    _COLUMNS = ("Topic", "Subtopic", "Question", "Answer")
    # Maximum number of records sent in one ``collection.add`` call. Chroma caps
    # the size of a single add, so larger CSVs are split into chunks of this size.
    ingest_batch_size = 1000

    def __init__(
        self,
        data_path="data/tesco_faq.csv",
        collection_name="tesco_faq",
        persist_path="./chroma_db",
        model_name="all-MiniLM-L6-v2",
    ):
        """Open (or create) the collection and ingest the CSV if it is empty.

        Args:
            data_path: Path of the FAQ CSV to ingest.
            collection_name: Name of the Chroma collection to use.
            persist_path: Directory used by the persistent Chroma client.
            model_name: sentence-transformers model used for embeddings.
        """
        self.data_path = data_path
        self.collection_name = collection_name
        self.client = chromadb.PersistentClient(path=persist_path)
        # Explicit sentence-transformers embedding function for documents and queries.
        self.sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=model_name
        )
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.sentence_transformer_ef,
        )
        # Only a fresh (empty) collection is ingested; an existing one is reused
        # as-is, so CSV edits are not picked up until the collection is cleared.
        if self.collection.count() == 0:
            self.ingest_data()

    @staticmethod
    def _clean_cell(value):
        """Return ``value`` as text, mapping missing values (NaN/None) to ""."""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    @classmethod
    def _build_records(cls, df):
        """Convert FAQ rows into parallel ``(documents, metadatas, ids)`` lists."""
        documents, metadatas, ids = [], [], []
        for index, row in df.iterrows():
            topic, subtopic, question, answer = (
                cls._clean_cell(row.get(column, "")) for column in cls._COLUMNS
            )
            # Rich text representation so the embedding sees Topic, Subtopic,
            # Question and Answer together.
            documents.append(
                f"Topic: {topic}\nSubtopic: {subtopic}\nQuestion: {question}\nAnswer: {answer}"
            )
            metadatas.append({
                "topic": topic,
                "subtopic": subtopic,
                "question": question,
            })
            # IDs are derived from the DataFrame index label of each row.
            ids.append(f"faq_{index}")
        return documents, metadatas, ids

    def ingest_data(self):
        """Read the FAQ CSV and add one document per row to the collection.

        Raises:
            Exception: Any error from reading the CSV or writing to Chroma is
                logged with its traceback and re-raised.
        """
        logger.info("Ingesting data from %s...", self.data_path)
        try:
            # Read every cell as text and keep blank cells as "" rather than NaN,
            # so empty fields (or answers such as "N/A"/"None") never become "nan".
            df = pd.read_csv(self.data_path, dtype=str, keep_default_na=False)
            documents, metadatas, ids = self._build_records(df)
            if not documents:
                # Chroma rejects an add with empty lists; nothing to do anyway.
                logger.warning("No rows found in %s; nothing to ingest.", self.data_path)
                return
            step = max(1, self.ingest_batch_size)
            for start in range(0, len(ids), step):
                stop = start + step
                self.collection.add(
                    documents=documents[start:stop],
                    metadatas=metadatas[start:stop],
                    ids=ids[start:stop],
                )
            logger.info("Successfully ingested %d documents.", len(documents))
        except Exception as e:
            logger.exception("Error ingesting data: %s", e)
            raise

    def retrieve(self, query, n_results=3):
        """Return the raw result of ``collection.query`` for a single query text.

        Args:
            query: Free-text question to search for.
            n_results: Number of nearest documents to request.
        """
        return self.collection.query(query_texts=[query], n_results=n_results)
if __name__ == "__main__":
    # Manual smoke test: build (or reuse) the index, then run one sample query.
    engine = RagEngine()
    print(engine.retrieve("Where do you deliver?"))