"""Embedding models for document vectorization.""" from typing import List, Optional import torch from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, AutoModel import numpy as np from tqdm import tqdm import os class EmbeddingModel: """Base class for embedding models.""" def __init__(self, model_name: str, device: Optional[str] = None): """Initialize embedding model. Args: model_name: Name/path of the model device: Device to run model on (cuda/cpu) """ self.model_name = model_name self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.model = None self.tokenizer = None def load_model(self): """Load the embedding model.""" raise NotImplementedError def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray: """Embed a list of documents. Args: texts: List of texts to embed batch_size: Batch size for processing Returns: Numpy array of embeddings """ raise NotImplementedError def embed_query(self, query: str) -> np.ndarray: """Embed a single query. Args: query: Query text Returns: Numpy array of embedding """ return self.embed_documents([query])[0] class SentenceTransformerEmbedding(EmbeddingModel): """Sentence Transformer based embedding model.""" def load_model(self): """Load sentence transformer model.""" print(f"Loading SentenceTransformer model: {self.model_name}") try: self.model = SentenceTransformer(self.model_name, device=self.device) print(f"Model loaded successfully on {self.device}") except Exception as e: print(f"Error loading model {self.model_name}: {str(e)}") print("Falling back to default model...") self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device) def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray: """Embed documents using sentence transformer.""" if self.model is None: self.load_model() embeddings = [] for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"): batch = texts[i:i + batch_size] batch_embeddings = self.model.encode( batch, convert_to_numpy=True, show_progress_bar=False, batch_size=batch_size ) embeddings.append(batch_embeddings) return np.vstack(embeddings) if embeddings else np.array([]) class BioMedicalEmbedding(EmbeddingModel): """Bio-medical BERT based embedding model.""" def load_model(self): """Load bio-medical BERT model.""" print(f"Loading Bio-Medical model: {self.model_name}") try: self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) self.model = AutoModel.from_pretrained(self.model_name).to(self.device) self.model.eval() print(f"Model loaded successfully on {self.device}") except Exception as e: print(f"Error loading model {self.model_name}: {str(e)}") print("Falling back to default model...") self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device) self.model.eval() def mean_pooling(self, model_output, attention_mask): """Apply mean pooling to get sentence embeddings.""" token_embeddings = model_output[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( input_mask_expanded.sum(1), min=1e-9 ) def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray: """Embed documents using bio-medical BERT.""" if self.model is None: self.load_model() embeddings = [] with torch.no_grad(): for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"): batch = texts[i:i 
                # Tokenize
                encoded_input = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                # Get embeddings
                model_output = self.model(**encoded_input)

                # Apply mean pooling
                batch_embeddings = self.mean_pooling(
                    model_output, encoded_input['attention_mask']
                )

                # Normalize
                batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)

                embeddings.append(batch_embeddings.cpu().numpy())

        return np.vstack(embeddings) if embeddings else np.array([])


class GeminiEmbedding(EmbeddingModel):
    """Gemini embedding model using Google AI API."""

    def load_model(self):
        """Load Gemini embedding model."""
        print(f"Initializing Gemini embedding model: {self.model_name}")
        try:
            import google.generativeai as genai

            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("GEMINI_API_KEY environment variable not set")
            genai.configure(api_key=api_key)
            self.model = genai
            print("Gemini model initialized successfully")
        except Exception as e:
            print(f"Error loading Gemini model: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using Gemini API."""
        if self.model is None:
            self.load_model()

        embeddings = []
        # The Gemini API is rate limited and embeds one text per request
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            batch = texts[i:i + batch_size]
            for text in batch:
                try:
                    if hasattr(self.model, 'embed_content'):
                        result = self.model.embed_content(
                            model="models/embedding-001",
                            content=text,
                            task_type="retrieval_document"
                        )
                        embeddings.append(result['embedding'])
                    else:
                        # Fallback: load_model installed a SentenceTransformer instead of the Gemini client
                        emb = self.model.encode([text])[0]
                        embeddings.append(emb)
                except Exception as e:
                    print(f"Error embedding text: {str(e)}")
                    # Use zero vector as fallback
                    embeddings.append(np.zeros(768))

        return np.array(embeddings)


class FinancialEmbedding(EmbeddingModel):
    """Financial domain BERT based embedding model."""

    def load_model(self):
        """Load financial BERT model."""
        print(f"Loading Financial domain model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Apply mean pooling to get sentence embeddings."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using financial BERT."""
        if self.model is None:
            self.load_model()

        embeddings = []
        with torch.no_grad():
            for i in tqdm(range(0, len(texts), batch_size), desc="Embedding financial documents"):
                batch = texts[i:i + batch_size]

                # Tokenize
                encoded_input = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                # Get embeddings
                model_output = self.model(**encoded_input)

                # Apply mean pooling
                batch_embeddings = self.mean_pooling(
                    model_output, encoded_input['attention_mask']
                )

                # Normalize
                batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)

                embeddings.append(batch_embeddings.cpu().numpy())

        return np.vstack(embeddings) if embeddings else np.array([])


class LawEmbedding(EmbeddingModel):
    """Legal domain BERT based embedding model."""

    def load_model(self):
        """Load legal BERT model."""
        print(f"Loading Legal domain model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Apply mean pooling to get sentence embeddings."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using legal BERT."""
        if self.model is None:
            self.load_model()

        embeddings = []
        with torch.no_grad():
            for i in tqdm(range(0, len(texts), batch_size), desc="Embedding legal documents"):
                batch = texts[i:i + batch_size]

                # Tokenize
                encoded_input = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                # Get embeddings
                model_output = self.model(**encoded_input)

                # Apply mean pooling
                batch_embeddings = self.mean_pooling(
                    model_output, encoded_input['attention_mask']
                )

                # Normalize
                batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)

                embeddings.append(batch_embeddings.cpu().numpy())

        return np.vstack(embeddings) if embeddings else np.array([])


class CustomerServiceEmbedding(EmbeddingModel):
    """Customer service domain specialized embedding model."""

    def load_model(self):
        """Load customer service domain model."""
        print(f"Loading Customer Service domain model: {self.model_name}")
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using customer service model."""
        if self.model is None:
            self.load_model()

        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding customer service documents"):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.model.encode(
                batch,
                convert_to_numpy=True,
                show_progress_bar=False,
                batch_size=batch_size
            )
            embeddings.append(batch_embeddings)

        return np.vstack(embeddings) if embeddings else np.array([])


class EmbeddingFactory:
    """Factory for creating embedding model instances."""

    # Map model names to their types
    MODEL_TYPES = {
"sentence-transformer", # Stable, well-supported "emilyalsentzer/Bio_ClinicalBERT": "biomedical", # Clinical domain "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": "biomedical", # Medical domain "sentence-transformers/all-MiniLM-L6-v2": "sentence-transformer", # Fast, lightweight "sentence-transformers/multilingual-MiniLM-L12-v2": "sentence-transformer", # Multilingual "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "sentence-transformer", # Paraphrase "allenai/specter": "biomedical", # Academic paper embeddings "ProsusAI/finbert": "financial", # Financial domain BERT "gemini-embedding-001": "gemini", # Gemini API "nlpaueb/legal-bert-base-uncased": "law", # Legal domain BERT "sentence-transformers/all-mpnet-base-v2-legal": "law", # Legal domain specialized "sentence-transformers/paraphrase-mpnet-base-v2-customer-service": "customer-service", # Customer service "sentence-transformers/all-MiniLM-L6-v2-customer-service": "customer-service" # Customer service lightweight } @classmethod def create_embedding_model(cls, model_name: str, device: Optional[str] = None) -> EmbeddingModel: """Create an embedding model instance. Args: model_name: Name of the embedding model device: Device to run model on Returns: EmbeddingModel instance """ model_type = cls.MODEL_TYPES.get(model_name, "sentence-transformer") if model_type == "gemini": return GeminiEmbedding(model_name, device) elif model_type == "biomedical": return BioMedicalEmbedding(model_name, device) elif model_type == "financial": return FinancialEmbedding(model_name, device) elif model_type == "law": return LawEmbedding(model_name, device) elif model_type == "customer-service": return CustomerServiceEmbedding(model_name, device) else: return SentenceTransformerEmbedding(model_name, device) @classmethod def get_available_models(cls) -> List[str]: """Get list of available embedding models.""" return list(cls.MODEL_TYPES.keys()) @classmethod def get_model_info(cls, model_name: str) -> dict: """Get information about a specific model. 
        Args:
            model_name: Name of the model

        Returns:
            Dictionary with model information
        """
        info = {
            "sentence-transformers/all-mpnet-base-v2": {
                "description": "High-quality, general-purpose sentence embeddings (768d)",
                "dimension": 768,
                "type": "sentence-transformer",
                "note": "Recommended for general use"
            },
            "emilyalsentzer/Bio_ClinicalBERT": {
                "description": "Clinical BERT for biomedical and clinical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": {
                "description": "PubMedBERT for biomedical and medical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "sentence-transformers/all-MiniLM-L6-v2": {
                "description": "Fast, lightweight sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for speed-sensitive applications"
            },
            "sentence-transformers/multilingual-MiniLM-L12-v2": {
                "description": "Fast multilingual sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Supports 50+ languages"
            },
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": {
                "description": "Multilingual paraphrase embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for paraphrase detection"
            },
            "allenai/specter": {
                "description": "Embeddings for academic papers and citations",
                "dimension": 768,
                "type": "biomedical",
                "note": "Optimized for scientific literature"
            },
            "ProsusAI/finbert": {
                "description": "BERT model fine-tuned for financial domain NLP tasks",
                "dimension": 768,
                "type": "financial",
                "note": "Optimized for financial documents, reports, and SEC filings"
            },
            "gemini-embedding-001": {
                "description": "Google Gemini embedding model via API",
                "dimension": 768,
                "type": "gemini",
                "url": "https://ai.google.dev/gemini-api/docs/embeddings",
                "note": "Requires GEMINI_API_KEY environment variable"
            },
            "nlpaueb/legal-bert-base-uncased": {
                "description": "Legal BERT pre-trained on a large corpus of legal documents",
                "dimension": 768,
                "type": "law",
                "note": "Optimized for contracts, statutes, and legal documents"
            },
            "sentence-transformers/all-mpnet-base-v2-legal": {
                "description": "Sentence Transformer fine-tuned for legal domain",
                "dimension": 768,
                "type": "law",
                "note": "High-quality embeddings for legal text similarity and retrieval"
            },
            "sentence-transformers/paraphrase-mpnet-base-v2-customer-service": {
                "description": "Specialized embeddings for customer service queries and responses",
                "dimension": 768,
                "type": "customer-service",
                "note": "Optimized for FAQs, support tickets, and customer interactions"
            },
            "sentence-transformers/all-MiniLM-L6-v2-customer-service": {
                "description": "Lightweight customer service embeddings",
                "dimension": 384,
                "type": "customer-service",
                "note": "Fast and efficient for real-time customer service applications"
            }
        }
        return info.get(model_name, {"description": "Unknown model", "dimension": 768})

    @classmethod
    def get_embedding_dimension(cls, model_name: str) -> int:
        """Get embedding dimension for a model.
        Args:
            model_name: Name of the model

        Returns:
            Embedding dimension
        """
        # Default dimensions (adjust based on actual models)
        dimensions = {
            "sentence-transformers/all-mpnet-base-v2": 768,
            "emilyalsentzer/Bio_ClinicalBERT": 768,
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": 768,
            "sentence-transformers/all-MiniLM-L6-v2": 384,
            "sentence-transformers/multilingual-MiniLM-L12-v2": 384,
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": 384,
            "allenai/specter": 768,
            "ProsusAI/finbert": 768,
            "gemini-embedding-001": 768,
            "nlpaueb/legal-bert-base-uncased": 768,
            "sentence-transformers/all-mpnet-base-v2-legal": 768,
            "sentence-transformers/paraphrase-mpnet-base-v2-customer-service": 768,
            "sentence-transformers/all-MiniLM-L6-v2-customer-service": 384
        }
        return dimensions.get(model_name, 768)
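

# Illustrative usage sketch, not part of the module's original API: it assumes the
# all-MiniLM-L6-v2 weights can be downloaded and simply exercises EmbeddingFactory end to end
# with made-up sample texts.
if __name__ == "__main__":
    # Pick a small, fast model for the smoke test
    model = EmbeddingFactory.create_embedding_model("sentence-transformers/all-MiniLM-L6-v2")

    docs = [
        "The patient was prescribed 20mg of atorvastatin daily.",
        "Quarterly revenue grew 12% year over year.",
    ]
    doc_vectors = model.embed_documents(docs)          # shape: (len(docs), dim)
    query_vector = model.embed_query("What medication was prescribed?")  # shape: (dim,)

    # SentenceTransformerEmbedding does not normalize its outputs, so normalize here and
    # compute cosine similarity as a dot product.
    doc_norms = doc_vectors / np.linalg.norm(doc_vectors, axis=1, keepdims=True)
    query_norm = query_vector / np.linalg.norm(query_vector)
    print("Expected dimension:", EmbeddingFactory.get_embedding_dimension(model.model_name))
    print("Cosine similarities:", doc_norms @ query_norm)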