Spaces:
Build error
Build error
import os

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline

from utils.logger import setup_logger
# NOTE(review): ModelLoader is not referenced in the visible part of this
# module; kept in case other code relies on it — confirm before removing.
from utils.model_loader import ModelLoader

# Module-level logger shared by RAGSystem.
logger = setup_logger(__name__)
class RAGSystem:
    """Answer questions about product titles with retrieval + extractive QA.

    Documents are the ``Title`` column of a CSV file. Titles are embedded once
    with a SentenceTransformer model. At query time the titles most similar to
    the query (by cosine similarity) are concatenated into a context string,
    which is passed to a Hugging Face ``question-answering`` pipeline.
    """

    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    DEFAULT_QA_MODEL = "distilbert-base-cased-distilled-squad"

    def __init__(self, csv_path="apparel.csv", *,
                 embedding_model=DEFAULT_EMBEDDING_MODEL,
                 qa_model=DEFAULT_QA_MODEL):
        """Load the models and index the documents in *csv_path*.

        Args:
            csv_path: Path to a CSV file that has a ``Title`` column.
            embedding_model: SentenceTransformer model name/path used to embed
                both documents and queries.
            qa_model: Model name/path used as both model and tokenizer of the
                question-answering pipeline.

        Raises:
            FileNotFoundError: If *csv_path* does not exist.
            Exception: Any other error while loading models or data is logged
                and re-raised.
        """
        try:
            # Fail fast on a missing file, before the slow model loading.
            if not os.path.exists(csv_path):
                raise FileNotFoundError(f"CSV file not found at {csv_path}")
            self.embedder = SentenceTransformer(embedding_model)
            self.qa_pipeline = pipeline(
                "question-answering",
                model=qa_model,
                tokenizer=qa_model,
            )
            self.setup_system(csv_path)
        except Exception as e:
            logger.error(f"Failed to initialize RAGSystem: {str(e)}")
            raise

    def setup_system(self, csv_path):
        """(Re)build the document index from the ``Title`` column of *csv_path*.

        Sets ``self.documents`` (the full DataFrame), ``self.texts`` (titles as
        strings, one per row) and ``self.embeddings`` (their embeddings).

        Raises:
            FileNotFoundError: If *csv_path* does not exist.
            Exception: Errors while reading or embedding are logged and re-raised
                (e.g. a ``KeyError`` when there is no ``Title`` column).
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found at {csv_path}")
        try:
            self.documents = pd.read_csv(csv_path)
            # NOTE(review): astype(str) turns missing titles (NaN) into the
            # literal string "nan", which is then embedded and retrievable.
            # Rows are kept so self.texts stays aligned with self.documents.
            self.texts = self.documents['Title'].astype(str).tolist()
            self.embeddings = self.embedder.encode(self.texts)
            logger.info(f"Successfully loaded {len(self.texts)} documents")
        except Exception as e:
            logger.error(f"Failed to setup RAG system: {str(e)}")
            raise

    def get_relevant_documents(self, query, top_k=5):
        """Return up to *top_k* document texts most similar to *query*.

        Results are ordered from most to least similar. Returns an empty list
        when *top_k* is not positive, when no documents are loaded, or when
        retrieval fails (the error is logged).
        """
        try:
            # Guard explicitly: with top_k == 0 a tail slice ``[-0:]`` would
            # select every document instead of none.
            if top_k <= 0 or not self.texts:
                return []
            query_embedding = self.embedder.encode([query])
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]
            # Descending order, then keep the best top_k (fewer if not enough docs).
            top_indices = np.argsort(similarities)[::-1][:top_k]
            return [self.texts[i] for i in top_indices]
        except Exception as e:
            logger.error(f"Error retrieving relevant documents: {str(e)}")
            return []

    @staticmethod
    def _build_context(docs, max_chars):
        """Join *docs* with spaces, keeping only whole documents within *max_chars*.

        A plain character slice could cut a title mid-word, letting the QA model
        return a fragment. If even the first document exceeds the budget, it is
        truncated so the context is never empty. ``None`` disables the cap.
        """
        if max_chars is None:
            return " ".join(docs)
        kept = []
        length = 0
        for doc in docs:
            added = len(doc) + (1 if kept else 0)  # +1 for the joining space
            if length + added > max_chars:
                break
            kept.append(doc)
            length += added
        if not kept:
            return docs[0][:max_chars]
        return " ".join(kept)

    def process_query(self, query, top_k=5, max_context_chars=512):
        """Answer *query* using the retrieved documents as QA context.

        Args:
            query: The user's question.
            top_k: Number of documents to retrieve as context.
            max_context_chars: Character budget for the context (``None`` for
                no cap). This is characters, not model tokens.

        Returns:
            The answer string extracted by the QA pipeline. If nothing is
            retrieved, returns "No relevant documents found.". On error, returns
            a "Failed to process query: ..." message (the error is also logged).
        """
        try:
            relevant_docs = self.get_relevant_documents(query, top_k=top_k)
            if not relevant_docs:
                return "No relevant documents found."
            context = self._build_context(relevant_docs, max_context_chars)
            qa_input = {
                "question": query,
                "context": context,
            }
            answer = self.qa_pipeline(qa_input)
            return answer['answer']
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return f"Failed to process query: {str(e)}"