"""Embedding models for document vectorization."""
from typing import List, Optional
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
import numpy as np
from tqdm import tqdm
import os
class EmbeddingModel:
"""Base class for embedding models."""
def __init__(self, model_name: str, device: Optional[str] = None):
"""Initialize embedding model.
Args:
model_name: Name/path of the model
device: Device to run model on (cuda/cpu)
"""
self.model_name = model_name
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.model = None
self.tokenizer = None
def load_model(self):
"""Load the embedding model."""
raise NotImplementedError
def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
"""Embed a list of documents.
Args:
texts: List of texts to embed
batch_size: Batch size for processing
Returns:
Numpy array of embeddings
"""
raise NotImplementedError
def embed_query(self, query: str) -> np.ndarray:
"""Embed a single query.
Args:
query: Query text
Returns:
Numpy array of embedding
"""
return self.embed_documents([query])[0]
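
# Subclasses only need to implement load_model() and embed_documents();
# embed_query() above reuses embed_documents() on a single-item list.
# Minimal sketch of a custom backend (HypotheticalAPI is illustrative, not a
# real dependency):
#
#   class MyEmbedding(EmbeddingModel):
#       def load_model(self):
#           self.model = HypotheticalAPI(self.model_name)
#       def embed_documents(self, texts, batch_size=32):
#           return np.array([self.model.embed(t) for t in texts])
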
class SentenceTransformerEmbedding(EmbeddingModel):
    """Sentence Transformer based embedding model."""

    # Label used in log messages; overridden by domain-specific subclasses.
    domain_label = "SentenceTransformer"

    def load_model(self):
        """Load sentence transformer model."""
        print(f"Loading {self.domain_label} model: {self.model_name}")
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {e}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using sentence transformer."""
        if self.model is None:
            self.load_model()
        embeddings = []
        # encode() already batches internally; the outer loop just drives the
        # document-level progress bar.
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.model.encode(
                batch,
                convert_to_numpy=True,
                show_progress_bar=False,
                batch_size=batch_size
            )
            embeddings.append(batch_embeddings)
        return np.vstack(embeddings) if embeddings else np.array([])
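
# Example (sketch): embedding two documents with the wrapper above. The model
# name and texts are illustrative; all-MiniLM-L6-v2 produces 384-d vectors.
#
#   embedder = SentenceTransformerEmbedding("sentence-transformers/all-MiniLM-L6-v2")
#   vectors = embedder.embed_documents(["refund policy", "shipping times"])
#   print(vectors.shape)  # (2, 384)
#   query_vec = embedder.embed_query("how do refunds work?")  # shape (384,)
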
class BioMedicalEmbedding(EmbeddingModel):
    """Bio-medical BERT based embedding model.

    Also implements the generic Hugging Face pipeline reused by the
    domain-specific subclasses below: tokenize, forward pass, mean pooling,
    L2 normalization.
    """

    # Label used in log messages and progress-bar descriptions.
    domain_label = "Bio-Medical"

    def load_model(self):
        """Load the domain BERT model."""
        print(f"Loading {self.domain_label} model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {e}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, weighted by the attention mask."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents with mean-pooled, L2-normalized BERT outputs."""
        if self.model is None:
            self.load_model()
        embeddings = []
        with torch.no_grad():
            for i in tqdm(range(0, len(texts), batch_size), desc=f"Embedding {self.domain_label} documents"):
                batch = texts[i:i + batch_size]
                # Tokenize, truncating to BERT's 512-token limit
                encoded_input = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)
                # Forward pass
                model_output = self.model(**encoded_input)
                # Mean pooling over non-padding tokens
                batch_embeddings = self.mean_pooling(
                    model_output,
                    encoded_input['attention_mask']
                )
                # L2-normalize so dot product equals cosine similarity
                batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)
                embeddings.append(batch_embeddings.cpu().numpy())
        return np.vstack(embeddings) if embeddings else np.array([])
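
# Mean pooling, worked example (sketch): for token embeddings of shape
# (batch, seq_len, hidden) and mask [[1, 1, 0]], only the two real tokens are
# averaged; the padded position contributes nothing.
#
#   toks = torch.tensor([[[2.0, 4.0], [4.0, 8.0], [9.0, 9.0]]])  # (1, 3, 2)
#   mask = torch.tensor([[1, 1, 0]])
#   BioMedicalEmbedding.mean_pooling(None, (toks,), mask)
#   # tensor([[3., 6.]]) -- the masked [9., 9.] token is excluded
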
class GeminiEmbedding(EmbeddingModel):
"""Gemini embedding model using Google AI API."""
def load_model(self):
"""Load Gemini embedding model."""
print(f"Initializing Gemini embedding model: {self.model_name}")
try:
import google.generativeai as genai
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("GEMINI_API_KEY environment variable not set")
genai.configure(api_key=api_key)
self.model = genai
print(f"Gemini model initialized successfully")
except Exception as e:
print(f"Error loading Gemini model: {str(e)}")
print("Falling back to default model...")
self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using Gemini API."""
        if self.model is None:
            self.load_model()
        if not hasattr(self.model, 'embed_content'):
            # load_model() fell back to a SentenceTransformer; encode directly
            # so all returned vectors share one dimension.
            return self.model.encode(texts, convert_to_numpy=True, batch_size=batch_size)
        embeddings = []
        # embed_content accepts one text per call, so iterate per text;
        # batch_size is unused on this path.
        for text in tqdm(texts, desc="Embedding documents"):
            try:
                result = self.model.embed_content(
                    model="models/embedding-001",  # 768-d; self.model_name is not passed through
                    content=text,
                    task_type="retrieval_document"
                )
                embeddings.append(result['embedding'])
            except Exception as e:
                print(f"Error embedding text: {e}")
                # Zero vector keeps the output shape consistent (768-d)
                embeddings.append(np.zeros(768))
        return np.array(embeddings)

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query with the query-specific task type."""
        if self.model is None:
            self.load_model()
        if hasattr(self.model, 'embed_content'):
            result = self.model.embed_content(
                model="models/embedding-001",
                content=query,
                task_type="retrieval_query"
            )
            return np.array(result['embedding'])
        return super().embed_query(query)
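
# Example (sketch): the Gemini path requires GEMINI_API_KEY in the environment,
# e.g. `export GEMINI_API_KEY=...`, then:
#
#   embedder = GeminiEmbedding("gemini-embedding-001")
#   vec = embedder.embed_query("hypertension treatment guidelines")
#   # 768-d vector from models/embedding-001, or a 384-d MiniLM vector if the
#   # key is missing and the SentenceTransformer fallback was used
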
class FinancialEmbedding(BioMedicalEmbedding):
    """Financial domain BERT based embedding model.

    Reuses the tokenize / mean-pool / normalize pipeline from
    BioMedicalEmbedding; only the log label differs.
    """

    domain_label = "Financial"
class LawEmbedding(BioMedicalEmbedding):
    """Legal domain BERT based embedding model.

    Reuses the tokenize / mean-pool / normalize pipeline from
    BioMedicalEmbedding; only the log label differs.
    """

    domain_label = "Legal"
class CustomerServiceEmbedding(SentenceTransformerEmbedding):
    """Customer service domain specialized embedding model.

    Reuses the SentenceTransformer encode pipeline; only the log label differs.
    """

    domain_label = "Customer Service domain"
class EmbeddingFactory:
"""Factory for creating embedding model instances."""
# Map model names to their types
MODEL_TYPES = {
"sentence-transformers/all-mpnet-base-v2": "sentence-transformer", # Stable, well-supported
"emilyalsentzer/Bio_ClinicalBERT": "biomedical", # Clinical domain
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": "biomedical", # Medical domain
"sentence-transformers/all-MiniLM-L6-v2": "sentence-transformer", # Fast, lightweight
"sentence-transformers/multilingual-MiniLM-L12-v2": "sentence-transformer", # Multilingual
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "sentence-transformer", # Paraphrase
"allenai/specter": "biomedical", # Academic paper embeddings
"ProsusAI/finbert": "financial", # Financial domain BERT
"gemini-embedding-001": "gemini", # Gemini API
"nlpaueb/legal-bert-base-uncased": "law", # Legal domain BERT
"sentence-transformers/all-mpnet-base-v2-legal": "law", # Legal domain specialized
"sentence-transformers/paraphrase-mpnet-base-v2-customer-service": "customer-service", # Customer service
"sentence-transformers/all-MiniLM-L6-v2-customer-service": "customer-service" # Customer service lightweight
}
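    # Unlisted model names fall through to "sentence-transformer" in
    # create_embedding_model() below, so a new Hub model can be used without
    # editing this map. Registration sketch (model id is illustrative):
    #   EmbeddingFactory.MODEL_TYPES["my-org/finbert-finetuned"] = "financial"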
@classmethod
def create_embedding_model(cls, model_name: str, device: Optional[str] = None) -> EmbeddingModel:
"""Create an embedding model instance.
Args:
model_name: Name of the embedding model
device: Device to run model on
Returns:
EmbeddingModel instance
"""
model_type = cls.MODEL_TYPES.get(model_name, "sentence-transformer")
if model_type == "gemini":
return GeminiEmbedding(model_name, device)
elif model_type == "biomedical":
return BioMedicalEmbedding(model_name, device)
elif model_type == "financial":
return FinancialEmbedding(model_name, device)
elif model_type == "law":
return LawEmbedding(model_name, device)
elif model_type == "customer-service":
return CustomerServiceEmbedding(model_name, device)
else:
return SentenceTransformerEmbedding(model_name, device)
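    # Dispatch sketch (model ids from MODEL_TYPES above):
    #   EmbeddingFactory.create_embedding_model("nlpaueb/legal-bert-base-uncased")
    #   # -> LawEmbedding
    #   EmbeddingFactory.create_embedding_model("any/unlisted-model")
    #   # -> SentenceTransformerEmbedding (the default branch)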
@classmethod
def get_available_models(cls) -> List[str]:
"""Get list of available embedding models."""
return list(cls.MODEL_TYPES.keys())
@classmethod
def get_model_info(cls, model_name: str) -> dict:
"""Get information about a specific model.
Args:
model_name: Name of the model
Returns:
Dictionary with model information
"""
info = {
"sentence-transformers/all-mpnet-base-v2": {
"description": "High-quality, general-purpose sentence embeddings (384d)",
"dimension": 768,
"type": "sentence-transformer",
"note": "Recommended for general use"
},
"emilyalsentzer/Bio_ClinicalBERT": {
"description": "Clinical BERT for biomedical and clinical text",
"dimension": 768,
"type": "biomedical"
},
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": {
"description": "PubMedBERT for biomedical and medical text",
"dimension": 768,
"type": "biomedical"
},
"sentence-transformers/all-MiniLM-L6-v2": {
"description": "Fast, lightweight sentence embeddings",
"dimension": 384,
"type": "sentence-transformer",
"note": "Good for speed-sensitive applications"
},
"sentence-transformers/multilingual-MiniLM-L12-v2": {
"description": "Fast multilingual sentence embeddings",
"dimension": 384,
"type": "sentence-transformer",
"note": "Supports 50+ languages"
},
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": {
"description": "Multilingual paraphrase embeddings",
"dimension": 384,
"type": "sentence-transformer",
"note": "Good for paraphrase detection"
},
"allenai/specter": {
"description": "Embeddings for academic papers and citations",
"dimension": 768,
"type": "biomedical",
"note": "Optimized for scientific literature"
},
"ProsusAI/finbert": {
"description": "BERT model fine-tuned for financial domain NLP tasks",
"dimension": 768,
"type": "financial",
"note": "Optimized for financial documents, reports, and SEC filings"
},
"gemini-embedding-001": {
"description": "Google Gemini embedding model via API",
"dimension": 768,
"type": "gemini",
"url": "https://ai.google.dev/gemini-api/docs/embeddings",
"note": "Requires GEMINI_API_KEY environment variable"
},
"nlpaueb/legal-bert-base-uncased": {
"description": "Legal BERT pre-trained on a large corpus of legal documents",
"dimension": 768,
"type": "law",
"note": "Optimized for contracts, statutes, and legal documents"
},
"sentence-transformers/all-mpnet-base-v2-legal": {
"description": "Sentence Transformer fine-tuned for legal domain",
"dimension": 768,
"type": "law",
"note": "High-quality embeddings for legal text similarity and retrieval"
},
"sentence-transformers/paraphrase-mpnet-base-v2-customer-service": {
"description": "Specialized embeddings for customer service queries and responses",
"dimension": 768,
"type": "customer-service",
"note": "Optimized for FAQs, support tickets, and customer interactions"
},
"sentence-transformers/all-MiniLM-L6-v2-customer-service": {
"description": "Lightweight customer service embeddings",
"dimension": 384,
"type": "customer-service",
"note": "Fast and efficient for real-time customer service applications"
}
}
return info.get(model_name, {"description": "Unknown model", "dimension": 768})
@classmethod
def get_embedding_dimension(cls, model_name: str) -> int:
"""Get embedding dimension for a model.
Args:
model_name: Name of the model
Returns:
Embedding dimension
"""
# Default dimensions (adjust based on actual models)
dimensions = {
"sentence-transformers/all-mpnet-base-v2": 768,
"emilyalsentzer/Bio_ClinicalBERT": 768,
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": 768,
"sentence-transformers/all-MiniLM-L6-v2": 384,
"sentence-transformers/multilingual-MiniLM-L12-v2": 384,
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": 384,
"allenai/specter": 768,
"ProsusAI/finbert": 768,
"gemini-embedding-001": 768,
"nlpaueb/legal-bert-base-uncased": 768,
"sentence-transformers/all-mpnet-base-v2-legal": 768,
"sentence-transformers/paraphrase-mpnet-base-v2-customer-service": 768,
"sentence-transformers/all-MiniLM-L6-v2-customer-service": 384
}
return dimensions.get(model_name, 768)
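

if __name__ == "__main__":
    # Minimal smoke test (sketch): downloads the lightweight MiniLM model on
    # first run. Texts are illustrative.
    name = "sentence-transformers/all-MiniLM-L6-v2"
    embedder = EmbeddingFactory.create_embedding_model(name)
    docs = [
        "The quarterly report shows revenue growth.",
        "Patients responded well to the treatment.",
    ]
    vectors = embedder.embed_documents(docs)
    print(type(embedder).__name__, vectors.shape)          # SentenceTransformerEmbedding (2, 384)
    print(EmbeddingFactory.get_embedding_dimension(name))  # 384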