Spaces:
Running
Running
| """Embedding models for document vectorization.""" | |
| from typing import List, Optional | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModel | |
| import numpy as np | |
| from tqdm import tqdm | |
| import os | |
class EmbeddingModel:
    """Abstract base class shared by all document-embedding backends."""

    def __init__(self, model_name: str, device: Optional[str] = None):
        """Set up common state for an embedding backend.

        Args:
            model_name: Name/path of the model.
            device: Device to run model on (cuda/cpu); auto-detected when omitted.
        """
        self.model_name = model_name
        # Honor an explicit device; otherwise prefer GPU when one is available.
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Lazily populated by load_model() in subclasses.
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the embedding model (subclass responsibility)."""
        raise NotImplementedError

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed a list of documents (subclass responsibility).

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Numpy array of embeddings, one row per input text.
        """
        raise NotImplementedError

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query by delegating to :meth:`embed_documents`.

        Args:
            query: Query text.

        Returns:
            1-D numpy array holding the query embedding.
        """
        matrix = self.embed_documents([query])
        return matrix[0]
class SentenceTransformerEmbedding(EmbeddingModel):
    """Embedding backend built on the sentence-transformers library."""

    def load_model(self):
        """Load the configured SentenceTransformer, falling back to MiniLM on failure."""
        print(f"Loading SentenceTransformer model: {self.model_name}")
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
        else:
            print(f"Model loaded successfully on {self.device}")

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* batch by batch and stack the results into one array.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Numpy array of embeddings (empty array when *texts* is empty).
        """
        if self.model is None:
            self.load_model()
        chunks = []
        for start in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            chunk = texts[start:start + batch_size]
            encoded = self.model.encode(
                chunk,
                convert_to_numpy=True,
                show_progress_bar=False,
                batch_size=batch_size,
            )
            chunks.append(encoded)
        if not chunks:
            return np.array([])
        return np.vstack(chunks)
class BioMedicalEmbedding(EmbeddingModel):
    """Bio-medical BERT based embedding model."""

    def load_model(self):
        """Load the bio-medical tokenizer/encoder pair; fall back to bert-base-uncased."""
        print(f"Loading Bio-Medical model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings while masking out padding positions."""
        tokens = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
        summed = torch.sum(tokens * mask, 1)
        # Clamp guards against a division by zero for all-padding rows.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Tokenize, encode, mean-pool and L2-normalize each batch of texts.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Numpy array of unit-normalized embeddings (empty when *texts* is empty).
        """
        if self.model is None:
            self.load_model()
        pooled_batches = []
        with torch.no_grad():
            for start in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
                batch = texts[start:start + batch_size]
                encoded = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt',
                ).to(self.device)
                outputs = self.model(**encoded)
                pooled = self.mean_pooling(outputs, encoded['attention_mask'])
                normalized = torch.nn.functional.normalize(pooled, p=2, dim=1)
                pooled_batches.append(normalized.cpu().numpy())
        if not pooled_batches:
            return np.array([])
        return np.vstack(pooled_batches)
class GeminiEmbedding(EmbeddingModel):
    """Gemini embedding model using Google AI API."""

    def load_model(self):
        """Configure the Gemini API client.

        Requires the GEMINI_API_KEY environment variable. On any failure
        (missing key, missing package, ...) falls back to a local
        SentenceTransformer so embedding still works offline.
        """
        print(f"Initializing Gemini embedding model: {self.model_name}")
        try:
            import google.generativeai as genai
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("GEMINI_API_KEY environment variable not set")
            genai.configure(api_key=api_key)
            self.model = genai
            print(f"Gemini model initialized successfully")
        except Exception as e:
            print(f"Error loading Gemini model: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using the Gemini API (or a local fallback).

        Args:
            texts: List of texts to embed.
            batch_size: Chunk size for progress reporting / fallback batching.

        Returns:
            Numpy array of embeddings; texts that fail to embed get a zero vector.
        """
        if self.model is None:
            self.load_model()
        # Bug fix: if load_model() already fell back to a local
        # SentenceTransformer, use its native batch encoding instead of the
        # per-text Gemini path (previously this instantiated a brand-new
        # fallback model for every single text).
        if isinstance(self.model, SentenceTransformer):
            batches = []
            for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
                batches.append(
                    self.model.encode(
                        texts[i:i + batch_size],
                        convert_to_numpy=True,
                        show_progress_bar=False,
                        batch_size=batch_size,
                    )
                )
            return np.vstack(batches) if batches else np.array([])
        embeddings = []
        fallback_model = None  # created at most once, not per text (bug fix)
        # Gemini API has rate limits, process with delays
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            batch = texts[i:i + batch_size]
            for text in batch:
                try:
                    if hasattr(self.model, 'embed_content'):
                        result = self.model.embed_content(
                            model="models/embedding-001",
                            content=text,
                            task_type="retrieval_document"
                        )
                        embeddings.append(result['embedding'])
                    else:
                        # Fallback if Gemini not available; cache the model so
                        # it is loaded only once for the whole corpus.
                        if fallback_model is None:
                            fallback_model = SentenceTransformer('all-MiniLM-L6-v2')
                        embeddings.append(fallback_model.encode([text])[0])
                except Exception as e:
                    print(f"Error embedding text: {str(e)}")
                    # Use zero vector as fallback.
                    # NOTE(review): 768 matches Gemini embedding-001, but the
                    # MiniLM fallback emits 384-d vectors, so mixing paths could
                    # yield ragged rows — confirm against callers.
                    embeddings.append(np.zeros(768))
        return np.array(embeddings)
class FinancialEmbedding(EmbeddingModel):
    """Financial domain BERT based embedding model."""

    def load_model(self):
        """Load the financial tokenizer/encoder; fall back to bert-base-uncased on error."""
        print(f"Loading Financial domain model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()
        else:
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")

    def mean_pooling(self, model_output, attention_mask):
        """Mask-aware mean over the token dimension -> one vector per sentence."""
        token_embeddings = model_output[0]
        expanded_mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp the denominator so all-padding rows cannot divide by zero.
        denom = torch.clamp(expanded_mask.sum(1), min=1e-9)
        return torch.sum(token_embeddings * expanded_mask, 1) / denom

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Produce L2-normalized mean-pooled BERT embeddings for *texts*.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Numpy array of embeddings (empty when *texts* is empty).
        """
        if self.model is None:
            self.load_model()
        results = []
        with torch.no_grad():
            for offset in tqdm(range(0, len(texts), batch_size), desc="Embedding financial documents"):
                window = texts[offset:offset + batch_size]
                tokens = self.tokenizer(
                    window,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt',
                ).to(self.device)
                hidden = self.model(**tokens)
                vectors = self.mean_pooling(hidden, tokens['attention_mask'])
                vectors = torch.nn.functional.normalize(vectors, p=2, dim=1)
                results.append(vectors.cpu().numpy())
        return np.vstack(results) if results else np.array([])
class LawEmbedding(EmbeddingModel):
    """Legal domain BERT based embedding model."""

    def load_model(self):
        """Load the legal-domain tokenizer/encoder; fall back to bert-base-uncased."""
        print(f"Loading Legal domain model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Compute the padding-aware mean of token embeddings."""
        hidden_states = model_output[0]
        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        numerator = torch.sum(hidden_states * weights, 1)
        # min=1e-9 prevents a zero denominator on fully-masked rows.
        return numerator / torch.clamp(weights.sum(1), min=1e-9)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* with legal BERT, mean-pool and L2-normalize per batch.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Numpy array of embeddings (empty when *texts* is empty).
        """
        if self.model is None:
            self.load_model()
        collected = []
        with torch.no_grad():
            for begin in tqdm(range(0, len(texts), batch_size), desc="Embedding legal documents"):
                segment = texts[begin:begin + batch_size]
                model_inputs = self.tokenizer(
                    segment,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt',
                ).to(self.device)
                raw_output = self.model(**model_inputs)
                sentence_vecs = self.mean_pooling(raw_output, model_inputs['attention_mask'])
                sentence_vecs = torch.nn.functional.normalize(sentence_vecs, p=2, dim=1)
                collected.append(sentence_vecs.cpu().numpy())
        if not collected:
            return np.array([])
        return np.vstack(collected)
class CustomerServiceEmbedding(EmbeddingModel):
    """Customer service domain specialized embedding model."""

    def load_model(self):
        """Load the customer-service SentenceTransformer; fall back to MiniLM on error."""
        print(f"Loading Customer Service domain model: {self.model_name}")
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode *texts* in fixed-size batches and stack into a single matrix.

        Args:
            texts: List of texts to embed.
            batch_size: Batch size for processing.

        Returns:
            Numpy array of embeddings (empty when *texts* is empty).
        """
        if self.model is None:
            self.load_model()
        parts = [
            self.model.encode(
                texts[pos:pos + batch_size],
                convert_to_numpy=True,
                show_progress_bar=False,
                batch_size=batch_size,
            )
            for pos in tqdm(range(0, len(texts), batch_size), desc="Embedding customer service documents")
        ]
        return np.vstack(parts) if parts else np.array([])
class EmbeddingFactory:
    """Factory for creating embedding model instances."""

    # Map model names to their types
    MODEL_TYPES = {
        "sentence-transformers/all-mpnet-base-v2": "sentence-transformer",  # Stable, well-supported
        "emilyalsentzer/Bio_ClinicalBERT": "biomedical",  # Clinical domain
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": "biomedical",  # Medical domain
        "sentence-transformers/all-MiniLM-L6-v2": "sentence-transformer",  # Fast, lightweight
        "sentence-transformers/multilingual-MiniLM-L12-v2": "sentence-transformer",  # Multilingual
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "sentence-transformer",  # Paraphrase
        "allenai/specter": "biomedical",  # Academic paper embeddings
        "ProsusAI/finbert": "financial",  # Financial domain BERT
        "gemini-embedding-001": "gemini",  # Gemini API
        "nlpaueb/legal-bert-base-uncased": "law",  # Legal domain BERT
        "sentence-transformers/all-mpnet-base-v2-legal": "law",  # Legal domain specialized
        "sentence-transformers/paraphrase-mpnet-base-v2-customer-service": "customer-service",  # Customer service
        "sentence-transformers/all-MiniLM-L6-v2-customer-service": "customer-service"  # Customer service lightweight
    }

    # Bug fix: these methods all take ``cls`` but were missing @classmethod,
    # so calling them on the class (EmbeddingFactory.create_embedding_model(...))
    # mis-bound the first positional argument as ``cls``.
    @classmethod
    def create_embedding_model(cls, model_name: str, device: Optional[str] = None) -> "EmbeddingModel":
        """Create an embedding model instance.

        Args:
            model_name: Name of the embedding model
            device: Device to run model on

        Returns:
            EmbeddingModel instance (defaults to a SentenceTransformer backend
            for unknown model names)
        """
        model_type = cls.MODEL_TYPES.get(model_name, "sentence-transformer")
        if model_type == "gemini":
            return GeminiEmbedding(model_name, device)
        elif model_type == "biomedical":
            return BioMedicalEmbedding(model_name, device)
        elif model_type == "financial":
            return FinancialEmbedding(model_name, device)
        elif model_type == "law":
            return LawEmbedding(model_name, device)
        elif model_type == "customer-service":
            return CustomerServiceEmbedding(model_name, device)
        else:
            return SentenceTransformerEmbedding(model_name, device)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Get list of available embedding models."""
        return list(cls.MODEL_TYPES.keys())

    @classmethod
    def get_model_info(cls, model_name: str) -> dict:
        """Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Dictionary with model information (generic placeholder for
            unknown model names)
        """
        info = {
            "sentence-transformers/all-mpnet-base-v2": {
                # Fixed: description previously said "(384d)" while the
                # dimension entry (correctly) says 768.
                "description": "High-quality, general-purpose sentence embeddings (768d)",
                "dimension": 768,
                "type": "sentence-transformer",
                "note": "Recommended for general use"
            },
            "emilyalsentzer/Bio_ClinicalBERT": {
                "description": "Clinical BERT for biomedical and clinical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": {
                "description": "PubMedBERT for biomedical and medical text",
                "dimension": 768,
                "type": "biomedical"
            },
            "sentence-transformers/all-MiniLM-L6-v2": {
                "description": "Fast, lightweight sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for speed-sensitive applications"
            },
            "sentence-transformers/multilingual-MiniLM-L12-v2": {
                "description": "Fast multilingual sentence embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Supports 50+ languages"
            },
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": {
                "description": "Multilingual paraphrase embeddings",
                "dimension": 384,
                "type": "sentence-transformer",
                "note": "Good for paraphrase detection"
            },
            "allenai/specter": {
                "description": "Embeddings for academic papers and citations",
                "dimension": 768,
                "type": "biomedical",
                "note": "Optimized for scientific literature"
            },
            "ProsusAI/finbert": {
                "description": "BERT model fine-tuned for financial domain NLP tasks",
                "dimension": 768,
                "type": "financial",
                "note": "Optimized for financial documents, reports, and SEC filings"
            },
            "gemini-embedding-001": {
                "description": "Google Gemini embedding model via API",
                "dimension": 768,
                "type": "gemini",
                "url": "https://ai.google.dev/gemini-api/docs/embeddings",
                "note": "Requires GEMINI_API_KEY environment variable"
            },
            "nlpaueb/legal-bert-base-uncased": {
                "description": "Legal BERT pre-trained on a large corpus of legal documents",
                "dimension": 768,
                "type": "law",
                "note": "Optimized for contracts, statutes, and legal documents"
            },
            "sentence-transformers/all-mpnet-base-v2-legal": {
                "description": "Sentence Transformer fine-tuned for legal domain",
                "dimension": 768,
                "type": "law",
                "note": "High-quality embeddings for legal text similarity and retrieval"
            },
            "sentence-transformers/paraphrase-mpnet-base-v2-customer-service": {
                "description": "Specialized embeddings for customer service queries and responses",
                "dimension": 768,
                "type": "customer-service",
                "note": "Optimized for FAQs, support tickets, and customer interactions"
            },
            "sentence-transformers/all-MiniLM-L6-v2-customer-service": {
                "description": "Lightweight customer service embeddings",
                "dimension": 384,
                "type": "customer-service",
                "note": "Fast and efficient for real-time customer service applications"
            }
        }
        return info.get(model_name, {"description": "Unknown model", "dimension": 768})

    @classmethod
    def get_embedding_dimension(cls, model_name: str) -> int:
        """Get embedding dimension for a model.

        Args:
            model_name: Name of the model

        Returns:
            Embedding dimension (768 for unknown models)
        """
        # Default dimensions (adjust based on actual models)
        dimensions = {
            "sentence-transformers/all-mpnet-base-v2": 768,
            "emilyalsentzer/Bio_ClinicalBERT": 768,
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": 768,
            "sentence-transformers/all-MiniLM-L6-v2": 384,
            "sentence-transformers/multilingual-MiniLM-L12-v2": 384,
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": 384,
            "allenai/specter": 768,
            "ProsusAI/finbert": 768,
            "gemini-embedding-001": 768,
            "nlpaueb/legal-bert-base-uncased": 768,
            "sentence-transformers/all-mpnet-base-v2-legal": 768,
            "sentence-transformers/paraphrase-mpnet-base-v2-customer-service": 768,
            "sentence-transformers/all-MiniLM-L6-v2-customer-service": 384
        }
        return dimensions.get(model_name, 768)