|
|
"""Embedding models for document vectorization.""" |
|
|
from typing import List, Optional |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
|
|
|
|
|
|
class EmbeddingModel:
    """Abstract base class for document/query embedding backends.

    Subclasses must implement load_model() and embed_documents();
    embed_query() is derived from embed_documents().
    """

    def __init__(self, model_name: str, device: Optional[str] = None):
        """Set up common state; the model itself is loaded lazily.

        Args:
            model_name: Name/path of the model
            device: Device to run model on (cuda/cpu); auto-detected
                from CUDA availability when omitted
        """
        self.model_name = model_name
        if device is not None:
            self.device = device
        else:
            # Prefer the GPU when one is visible to torch.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Populated by load_model() in subclasses.
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the embedding model. Must be overridden."""
        raise NotImplementedError

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed a list of documents. Must be overridden.

        Args:
            texts: List of texts to embed
            batch_size: Batch size for processing

        Returns:
            Numpy array of embeddings
        """
        raise NotImplementedError

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query via embed_documents().

        Args:
            query: Query text

        Returns:
            Numpy array of embedding
        """
        result = self.embed_documents([query])
        return result[0]
|
|
|
|
|
|
|
|
class SentenceTransformerEmbedding(EmbeddingModel):
    """Embedding backend built on the sentence-transformers library."""

    def load_model(self):
        """Load the configured SentenceTransformer.

        On any load failure, prints the error and falls back to the
        lightweight 'all-MiniLM-L6-v2' default.
        """
        print(f"Loading SentenceTransformer model: {self.model_name}")
        try:
            self.model = SentenceTransformer(self.model_name, device=self.device)
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
        else:
            print(f"Model loaded successfully on {self.device}")

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts in batches and stack the results.

        Args:
            texts: List of texts to embed
            batch_size: Batch size for processing

        Returns:
            Numpy array of embeddings (empty array for empty input)
        """
        if self.model is None:
            self.load_model()

        chunks = []
        starts = range(0, len(texts), batch_size)
        for start in tqdm(starts, desc="Embedding documents"):
            chunk = texts[start:start + batch_size]
            encoded = self.model.encode(
                chunk,
                convert_to_numpy=True,
                show_progress_bar=False,
                batch_size=batch_size
            )
            chunks.append(encoded)

        if not chunks:
            return np.array([])
        return np.vstack(chunks)
|
|
|
|
|
|
|
|
class BioMedicalEmbedding(EmbeddingModel):
    """Embedding backend based on a biomedical BERT checkpoint (HF transformers)."""

    def load_model(self):
        """Load tokenizer and encoder, moving the encoder to self.device.

        On any load failure, prints the error and falls back to
        'bert-base-uncased'. The model is put in eval mode either way.
        """
        print(f"Loading Bio-Medical model: {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"Error loading model {self.model_name}: {str(e)}")
            print("Falling back to default model...")
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
            self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device)
            self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Mean-pool token embeddings, weighting by the attention mask.

        Padding positions (mask == 0) contribute nothing; the clamp keeps
        the divisor from hitting zero for an all-padding row.
        """
        tokens = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
        summed = torch.sum(tokens * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed texts with the BERT encoder + mean pooling + L2 normalization.

        Args:
            texts: List of texts to embed
            batch_size: Batch size for processing

        Returns:
            Numpy array of L2-normalized embeddings (empty array for empty input)
        """
        if self.model is None:
            self.load_model()

        chunks = []
        # Inference only: disable autograd for speed and memory.
        with torch.no_grad():
            for start in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
                chunk = texts[start:start + batch_size]

                # Tokenize, truncating to BERT's 512-token limit.
                encoded = self.tokenizer(
                    chunk,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors='pt'
                ).to(self.device)

                output = self.model(**encoded)
                pooled = self.mean_pooling(output, encoded['attention_mask'])
                # Unit-normalize so cosine similarity equals the dot product.
                pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
                chunks.append(pooled.cpu().numpy())

        if not chunks:
            return np.array([])
        return np.vstack(chunks)
|
|
|
|
|
|
|
|
class GeminiEmbedding(EmbeddingModel):
    """Gemini embedding model using Google AI API."""

    # Dimension of "models/embedding-001" vectors; used for the zero-vector
    # placeholder when a single API request fails, so the output stays
    # rectangular.
    _GEMINI_DIM = 768

    def load_model(self):
        """Initialize the Gemini API client.

        Reads the API key from the GEMINI_API_KEY environment variable.
        On any failure (missing package, missing key, configure error) it
        falls back to a local SentenceTransformer so embedding still works.
        """
        print(f"Initializing Gemini embedding model: {self.model_name}")
        try:
            import google.generativeai as genai
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("GEMINI_API_KEY environment variable not set")
            genai.configure(api_key=api_key)
            # Store the configured module; it exposes embed_content().
            self.model = genai
            print(f"Gemini model initialized successfully")
        except Exception as e:
            print(f"Error loading Gemini model: {str(e)}")
            print("Falling back to default model...")
            self.model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

    def embed_documents(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Embed documents using the Gemini API (or the local fallback).

        Args:
            texts: List of texts to embed
            batch_size: Batch size (the API path embeds one text per request;
                batching only controls progress-bar granularity there)

        Returns:
            Numpy array of embeddings
        """
        if self.model is None:
            self.load_model()

        # If load_model() fell back to a local SentenceTransformer, use it
        # directly. (The previous implementation instantiated a brand-new
        # SentenceTransformer for every single text inside the loop, and
        # padded failures with 768-d zeros even though MiniLM is 384-d,
        # which produced a ragged result.)
        if not hasattr(self.model, 'embed_content'):
            chunks = []
            for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
                batch = texts[i:i + batch_size]
                chunks.append(
                    self.model.encode(batch, convert_to_numpy=True, show_progress_bar=False)
                )
            return np.vstack(chunks) if chunks else np.array([])

        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding documents"):
            batch = texts[i:i + batch_size]
            for text in batch:
                try:
                    result = self.model.embed_content(
                        model="models/embedding-001",
                        content=text,
                        task_type="retrieval_document"
                    )
                    embeddings.append(result['embedding'])
                except Exception as e:
                    print(f"Error embedding text: {str(e)}")
                    # Placeholder with the correct dimensionality for this path.
                    embeddings.append(np.zeros(self._GEMINI_DIM))

        return np.array(embeddings)
|
|
|
|
|
|
|
|
class EmbeddingFactory:
    """Factory for creating embedding model instances."""

    # Maps model name -> backend type, used to pick the wrapper class.
    MODEL_TYPES = {
        "sentence-transformers/all-mpnet-base-v2": "sentence-transformer",
        "emilyalsentzer/Bio_ClinicalBERT": "biomedical",
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": "biomedical",
        "sentence-transformers/all-MiniLM-L6-v2": "sentence-transformer",
        "sentence-transformers/multilingual-MiniLM-L12-v2": "sentence-transformer",
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "sentence-transformer",
        "allenai/specter": "biomedical",
        "gemini-embedding-001": "gemini"
    }

    # Single source of truth for per-model metadata; get_model_info() and
    # get_embedding_dimension() both read from here so the two can never
    # drift apart. (The old per-method copies disagreed: mpnet was
    # described as "384d" while its actual dimension is 768.)
    _MODEL_INFO = {
        "sentence-transformers/all-mpnet-base-v2": {
            "description": "High-quality, general-purpose sentence embeddings (768d)",
            "dimension": 768,
            "type": "sentence-transformer",
            "note": "Recommended for general use"
        },
        "emilyalsentzer/Bio_ClinicalBERT": {
            "description": "Clinical BERT for biomedical and clinical text",
            "dimension": 768,
            "type": "biomedical"
        },
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": {
            "description": "PubMedBERT for biomedical and medical text",
            "dimension": 768,
            "type": "biomedical"
        },
        "sentence-transformers/all-MiniLM-L6-v2": {
            "description": "Fast, lightweight sentence embeddings",
            "dimension": 384,
            "type": "sentence-transformer",
            "note": "Good for speed-sensitive applications"
        },
        "sentence-transformers/multilingual-MiniLM-L12-v2": {
            "description": "Fast multilingual sentence embeddings",
            "dimension": 384,
            "type": "sentence-transformer",
            "note": "Supports 50+ languages"
        },
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": {
            "description": "Multilingual paraphrase embeddings",
            "dimension": 384,
            "type": "sentence-transformer",
            "note": "Good for paraphrase detection"
        },
        "allenai/specter": {
            "description": "Embeddings for academic papers and citations",
            "dimension": 768,
            "type": "biomedical",
            "note": "Optimized for scientific literature"
        },
        "gemini-embedding-001": {
            "description": "Google Gemini embedding model via API",
            "dimension": 768,
            "type": "gemini",
            "url": "https://ai.google.dev/gemini-api/docs/embeddings",
            "note": "Requires GEMINI_API_KEY environment variable"
        }
    }

    # Dimension reported for models not present in _MODEL_INFO.
    _DEFAULT_DIMENSION = 768

    @classmethod
    def create_embedding_model(cls, model_name: str, device: Optional[str] = None) -> "EmbeddingModel":
        """Create an embedding model instance.

        Unknown model names default to the sentence-transformer backend.

        Args:
            model_name: Name of the embedding model
            device: Device to run model on

        Returns:
            EmbeddingModel instance
        """
        model_type = cls.MODEL_TYPES.get(model_name, "sentence-transformer")

        if model_type == "gemini":
            return GeminiEmbedding(model_name, device)
        elif model_type == "biomedical":
            return BioMedicalEmbedding(model_name, device)
        else:
            return SentenceTransformerEmbedding(model_name, device)

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Get list of available embedding models."""
        return list(cls.MODEL_TYPES.keys())

    @classmethod
    def get_model_info(cls, model_name: str) -> dict:
        """Get information about a specific model.

        Args:
            model_name: Name of the model

        Returns:
            Dictionary with model information; a fresh copy, so callers
            may mutate it without corrupting the shared metadata table
        """
        info = cls._MODEL_INFO.get(
            model_name,
            {"description": "Unknown model", "dimension": cls._DEFAULT_DIMENSION}
        )
        return dict(info)

    @classmethod
    def get_embedding_dimension(cls, model_name: str) -> int:
        """Get embedding dimension for a model.

        Args:
            model_name: Name of the model

        Returns:
            Embedding dimension (768 for unknown models)
        """
        return cls._MODEL_INFO.get(model_name, {}).get("dimension", cls._DEFAULT_DIMENSION)
|
|
|