Spaces:
Sleeping
Sleeping
"""Retrieval component for the RAG system."""
import faiss
import numpy as np
from typing import List, Dict, Tuple
from elasticsearch import Elasticsearch
from transformers import AutoTokenizer, AutoModel
import torch
class FinancialDataRetriever:
    """Retrieve relevant financial documents via FAISS dense search or Elasticsearch.

    The backend is chosen by ``config['rag']['retriever']``: ``"faiss"`` uses
    BERT [CLS] embeddings with an exact L2 index; anything else falls through
    to an Elasticsearch full-text ``match`` query.
    """

    def __init__(self, config: Dict):
        """Initialize the retriever with configuration.

        Args:
            config: Nested config dict; reads ``config['rag']['retriever']``,
                ``config['rag']['max_documents']`` and
                ``config['rag']['similarity_threshold']``.
        """
        self.retriever_type = config['rag']['retriever']
        self.max_documents = config['rag']['max_documents']
        self.similarity_threshold = config['rag']['similarity_threshold']

        # Exact (flat) L2 index over BERT embeddings.
        self.dimension = 768  # bert-base hidden size
        self.index = faiss.IndexFlatL2(self.dimension)
        # Documents stored in FAISS insertion order so row ids map back to
        # them. Initialized here so retrieve() before index_documents()
        # returns [] instead of raising AttributeError.
        self.documents: List[Dict] = []

        # Transformer encoder for dense embeddings; eval() disables dropout
        # so repeated encodings of the same text are deterministic.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = AutoModel.from_pretrained('bert-base-uncased')
        self.model.eval()

        # Elasticsearch client only when that backend is configured.
        if self.retriever_type == "elasticsearch":
            self.es = Elasticsearch()

    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Encode texts as [CLS] embeddings.

        Args:
            texts: Batch of raw strings (truncated to 512 tokens each).

        Returns:
            Contiguous float32 array of shape ``(len(texts), 768)`` —
            FAISS requires contiguous float32 input.
        """
        tokens = self.tokenizer(texts, padding=True, truncation=True,
                                return_tensors="pt", max_length=512)
        with torch.no_grad():
            outputs = self.model(**tokens)
        # Position 0 is the [CLS] token, used as the sequence embedding.
        embeddings = outputs.last_hidden_state[:, 0, :].numpy()
        return np.ascontiguousarray(embeddings, dtype=np.float32)

    def index_documents(self, documents: List[Dict]):
        """Index documents for retrieval.

        Args:
            documents: Dicts; on the FAISS path each must have a ``'text'`` key.
        """
        if self.retriever_type == "faiss":
            if not documents:
                return  # nothing to add; avoid an empty encoder batch
            texts = [doc['text'] for doc in documents]
            embeddings = self.encode_text(texts)
            self.index.add(embeddings)
            # Extend — not overwrite — so repeated calls keep self.documents
            # aligned with the FAISS row ids already in the index.
            self.documents.extend(documents)
        else:
            for doc in documents:
                self.es.index(index="financial_data", document=doc)

    def retrieve(self, query: str, k: int = None) -> List[Dict]:
        """Retrieve up to ``k`` documents relevant to ``query``.

        Args:
            query: Free-text query string.
            k: Max results; defaults to ``max_documents`` from the config.

        Returns:
            List of ``{'document': ..., 'score': float}`` filtered by
            ``similarity_threshold``. FAISS scores are ``1 / (1 + L2 dist)``
            (in (0, 1]); Elasticsearch scores are raw ``_score`` values —
            the two scales are not comparable.
        """
        k = k or self.max_documents
        if self.retriever_type == "faiss":
            # Encode only on the dense path; the ES branch is lexical and
            # a BERT forward pass there would be wasted work.
            query_embedding = self.encode_text([query])
            distances, indices = self.index.search(query_embedding, k)
            results = []
            for dist, idx in zip(distances[0], indices[0]):
                # FAISS pads with idx == -1 when the index holds fewer than
                # k vectors; without this guard, Python's negative indexing
                # would silently return the last document.
                if idx < 0:
                    continue
                score = float(1 / (1 + dist))
                if score >= self.similarity_threshold:
                    results.append({'document': self.documents[idx],
                                    'score': score})
        else:
            response = self.es.search(
                index="financial_data",
                query={"match": {"text": query}},
                size=k,
            )
            results = [
                {'document': hit['_source'], 'score': hit['_score']}
                for hit in response['hits']['hits']
                if hit['_score'] >= self.similarity_threshold
            ]
        return results