Spaces:
Sleeping
Sleeping
| """ | |
| RAG (Retrieval Augmented Generation) store for fraud pattern matching | |
| """ | |
| import os | |
| import json | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| import pickle | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| logger = logging.getLogger(__name__) | |
class RAGStore:
    """Simple RAG store using sentence transformers and local file storage.

    Documents live in three parallel structures: a list of raw texts, a list
    of metadata dicts, and a 2-D numpy array of embeddings (one row per
    document). All three are persisted under ``collection_dir``.
    """

    def __init__(self, collection_dir: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Create (or reopen) a store rooted at ``collection_dir``.

        Args:
            collection_dir: Directory holding the persisted files; created if missing.
            model_name: SentenceTransformer model used to embed texts.
        """
        self.collection_dir = collection_dir
        self.model_name = model_name
        self.embeddings_file = os.path.join(collection_dir, "embeddings.pkl")
        self.texts_file = os.path.join(collection_dir, "texts.json")
        self.metadata_file = os.path.join(collection_dir, "metadata.json")
        os.makedirs(collection_dir, exist_ok=True)
        # Initialize sentence transformer. A failure leaves the store readable,
        # but add()/query() degrade to no-ops (they check self.encoder).
        try:
            self.encoder = SentenceTransformer(model_name)
            logger.info("Initialized SentenceTransformer: %s", model_name)
        except Exception as e:
            logger.error("Failed to initialize SentenceTransformer: %s", e)
            self.encoder = None
        # Load existing data
        self.texts = []
        self.metadatas = []
        self.embeddings = None
        self._load_data()

    def _load_data(self):
        """Load existing embeddings, texts, and metadata from disk.

        If the persisted files are missing pieces, unreadable, or mutually
        inconsistent (differing document counts), the store is reset to empty
        rather than left in a state where query() would pair the wrong
        text/metadata with an embedding row.
        """
        try:
            if os.path.exists(self.texts_file):
                with open(self.texts_file, 'r') as f:
                    self.texts = json.load(f)
            if os.path.exists(self.metadata_file):
                with open(self.metadata_file, 'r') as f:
                    self.metadatas = json.load(f)
            if os.path.exists(self.embeddings_file):
                with open(self.embeddings_file, 'rb') as f:
                    # NOTE(review): pickle is only safe because these files are
                    # written by this same process; never point the store at
                    # untrusted data.
                    self.embeddings = pickle.load(f)
            # Fix: guard against partially-written or stale files. Mismatched
            # lengths would silently misalign query() results.
            if len(self.texts) != len(self.metadatas) or (
                    self.embeddings is not None
                    and len(self.embeddings) != len(self.texts)):
                raise ValueError("persisted texts/metadata/embeddings are out of sync")
            logger.info("Loaded %d existing documents", len(self.texts))
        except Exception as e:
            logger.error("Error loading RAG data: %s", e)
            self.texts = []
            self.metadatas = []
            self.embeddings = None

    def _save_data(self):
        """Save embeddings, texts, and metadata to files (best-effort)."""
        try:
            with open(self.texts_file, 'w') as f:
                json.dump(self.texts, f)
            with open(self.metadata_file, 'w') as f:
                # default=str so non-JSON values (e.g. datetimes) persist as text
                json.dump(self.metadatas, f, default=str)
            if self.embeddings is not None:
                with open(self.embeddings_file, 'wb') as f:
                    pickle.dump(self.embeddings, f)
            logger.info("Saved %d documents to storage", len(self.texts))
        except Exception as e:
            logger.error("Error saving RAG data: %s", e)

    def add(self, texts: List[str], metadatas: List[Dict[str, Any]]):
        """Add new documents to the RAG store.

        Args:
            texts: Document texts to embed and store.
            metadatas: One metadata dict per text (same length as ``texts``).
        """
        if not self.encoder:
            logger.warning("No encoder available, cannot add documents")
            return
        if len(texts) != len(metadatas):
            logger.error("Texts and metadatas must have the same length")
            return
        try:
            # Fix: encode and build the combined matrix BEFORE mutating any
            # state, so a failure here cannot leave texts/metadatas/embeddings
            # with mismatched lengths.
            new_embeddings = self.encoder.encode(texts)
            if self.embeddings is None:
                combined = new_embeddings
            else:
                combined = np.vstack([self.embeddings, new_embeddings])
            self.texts.extend(texts)
            self.metadatas.extend(metadatas)
            self.embeddings = combined
            # Save to disk
            self._save_data()
            logger.info("Added %d new documents to RAG store", len(texts))
        except Exception as e:
            logger.error("Error adding documents to RAG store: %s", e)

    def query(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Query the RAG store for similar documents.

        Args:
            query: Free-text query to embed and match.
            k: Maximum number of results to return.

        Returns:
            Up to ``k`` dicts with keys ``text``, ``metadata`` and
            ``similarity`` (cosine), best match first; results below the 0.1
            similarity floor are dropped. Empty list on error or empty store.
        """
        if not self.encoder or self.embeddings is None or len(self.texts) == 0:
            logger.warning("RAG store is empty or encoder unavailable")
            return []
        try:
            # Encode the query
            query_embedding = self.encoder.encode([query])
            # Calculate similarities
            similarities = cosine_similarity(query_embedding, self.embeddings)[0]
            # Get top k results (argsort ascending, reversed for descending)
            top_indices = np.argsort(similarities)[::-1][:k]
            results = []
            for idx in top_indices:
                if similarities[idx] > 0.1:  # Minimum similarity threshold
                    results.append({
                        "text": self.texts[idx],
                        "metadata": self.metadatas[idx],
                        "similarity": float(similarities[idx])
                    })
            logger.info("Query returned %d results", len(results))
            return results
        except Exception as e:
            logger.error("Error querying RAG store: %s", e)
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the RAG store."""
        return {
            "total_documents": len(self.texts),
            "has_embeddings": self.embeddings is not None,
            "encoder_available": self.encoder is not None,
            "collection_dir": self.collection_dir
        }

    def clear(self):
        """Clear all data from the RAG store (memory and on-disk files)."""
        try:
            self.texts = []
            self.metadatas = []
            self.embeddings = None
            # Remove files
            for file_path in [self.embeddings_file, self.texts_file, self.metadata_file]:
                if os.path.exists(file_path):
                    os.remove(file_path)
            logger.info("RAG store cleared")
        except Exception as e:
            logger.error("Error clearing RAG store: %s", e)
| # Utility functions for fraud-specific RAG queries | |
def build_fraud_context(transaction_data: Dict[str, Any]) -> str:
    """Build a searchable text representation of transaction data.

    Args:
        transaction_data: Transaction fields; only recognised keys are used.

    Returns:
        Space-separated "label:value" fragments for each present field.
    """
    # (input key, label used in the output fragment) — note 'timestamp'
    # is rendered with the label 'time'.
    field_labels = (
        ("amount", "amount"),
        ("merchant", "merchant"),
        ("category", "category"),
        ("description", "description"),
        ("timestamp", "time"),
    )
    return " ".join(
        f"{label}:{transaction_data[key]}"
        for key, label in field_labels
        if key in transaction_data
    )
def extract_fraud_patterns(rag_results: List[Dict[str, Any]]) -> List[str]:
    """Extract common fraud patterns from RAG results.

    Args:
        rag_results: Result dicts as produced by RAGStore.query(), each with
            optional 'metadata' and 'similarity' keys.

    Returns:
        Unique human-readable pattern strings for every result whose
        similarity exceeds 0.7, in first-seen order.
    """
    patterns = []
    for result in rag_results:
        metadata = result.get('metadata', {})
        similarity = result.get('similarity', 0)
        if similarity > 0.7:  # High similarity threshold — only trust close matches
            if 'merchant' in metadata:
                patterns.append(f"Similar merchant: {metadata['merchant']}")
            if 'amount' in metadata:
                patterns.append(f"Similar amount: ${metadata['amount']}")
            if 'category' in metadata:
                patterns.append(f"Similar category: {metadata['category']}")
    # Fix: dict.fromkeys dedupes while preserving insertion order;
    # list(set(...)) returned the patterns in arbitrary hash order,
    # making the output nondeterministic across runs.
    return list(dict.fromkeys(patterns))