# MedSearchPro / embeddings / embedding_models.py
# paulhemb's picture
# Initial Backend Deployment
# 1367957
# embeddings/embedding_models.py
"""
Multiple embedding model implementations
"""
from typing import List, Optional, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
class EmbeddingManager:
    """Manager for a single sentence-transformer embedding model.

    Loads the named model at construction time and exposes helpers for
    encoding text batches and inspecting the model's properties.
    """

    # Model used as a fallback when the requested one cannot be loaded.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model_name = model_name
        self.model = None  # set by _load_model()
        self._load_model()

    def _load_model(self) -> None:
        """Load the configured model, falling back to the default on failure.

        Raises:
            Exception: propagated from SentenceTransformer when even the
                default model cannot be loaded.
        """
        try:
            print(f"🔧 Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            print(f"✅ Model loaded successfully: {self.model_name}")
        except Exception as e:
            print(f"❌ Failed to load model {self.model_name}: {e}")
            if self.model_name == self.DEFAULT_MODEL:
                # The fallback IS the model that just failed — retrying the
                # exact same load is pointless, so surface the error.
                raise
            # Fallback to default model
            self.model_name = self.DEFAULT_MODEL
            self.model = SentenceTransformer(self.model_name)
            print(f"🔄 Using fallback model: {self.model_name}")

    def encode(self, texts: Union[List[str], str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into embeddings.

        Args:
            texts: a list of strings, or a single string (wrapped into a
                one-element batch for convenience).
            batch_size: forwarded to the underlying model's encoder.

        Returns:
            A 2-D numpy array of shape (len(texts), embedding_dim).

        Raises:
            ValueError: if no model has been loaded.
        """
        if self.model is None:
            raise ValueError("Embedding model not loaded")
        if isinstance(texts, str):
            texts = [texts]
        try:
            return self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
        except Exception as e:
            print(f"❌ Embedding encoding error: {e}")
            raise

    def get_embedding_dimensions(self) -> int:
        """Get the dimensions of the embeddings."""
        # Prefer the model's own metadata — avoids a throwaway forward pass.
        dim_getter = getattr(self.model, "get_sentence_embedding_dimension", None)
        if callable(dim_getter):
            dim = dim_getter()
            if dim:
                return dim
        # Fallback: encode a probe string and inspect the output shape.
        return self.encode(["test"]).shape[1]

    def get_model_info(self) -> dict:
        """Get information about the current model (name, dims, max length)."""
        return {
            "model_name": self.model_name,
            "dimensions": self.get_embedding_dimensions(),
            # Not every model exposes max_seq_length; 512 is the assumed
            # default — TODO confirm against the deployed model set.
            "max_sequence_length": getattr(self.model, 'max_seq_length', 512),
        }
class MultiEmbeddingManager:
    """Manager that can switch between multiple embedding models.

    Keeps a cache of loaded EmbeddingManager instances keyed by model name
    and tracks the most recently requested one.
    """

    def __init__(self):
        # Cache: model name -> loaded EmbeddingManager.
        self.models = {}
        # Most recently loaded/selected manager; None until first load.
        self.current_model = None

    def load_model(self, model_name: str) -> "EmbeddingManager":
        """Load (or fetch from cache) the named model and make it current."""
        if model_name not in self.models:
            self.models[model_name] = EmbeddingManager(model_name)
        self.current_model = self.models[model_name]
        return self.current_model

    def get_model(self, model_name: Optional[str] = None) -> "EmbeddingManager":
        """Get a model instance.

        Args:
            model_name: explicit model to (lazily) load. When None, the
                current model is returned; if nothing has been loaded yet,
                the default model is loaded first.
        """
        if model_name:
            return self.load_model(model_name)
        if self.current_model is not None:
            return self.current_model
        # Nothing loaded yet — fall back to the default model.
        return self.load_model("all-MiniLM-L6-v2")

    def list_loaded_models(self) -> List[str]:
        """List the names of all currently loaded models."""
        return list(self.models.keys())
# Quick test function
def test_embedding_models():
    """Smoke-test every embedding model named in the vector config."""
    from config.vector_config import EMBEDDING_MODELS

    sample_texts = [
        "Deep learning for medical image analysis",
        "Transformer architectures in genomics",
        "AI-driven drug discovery methods",
    ]

    registry = MultiEmbeddingManager()

    print("🧪 Testing Embedding Models")
    print("=" * 50)

    for name in EMBEDDING_MODELS.keys():
        try:
            print(f"\n🔬 Testing: {name}")
            manager = registry.load_model(name)
            details = manager.get_model_info()
            print(f" Dimensions: {details['dimensions']}")
            print(f" Max sequence length: {details['max_sequence_length']}")
            # Test encoding
            vectors = manager.encode(sample_texts)
            print(f" Embedding shape: {vectors.shape}")
            print(f" ✅ {name} working correctly")
        except Exception as e:
            print(f" ❌ {name} failed: {e}")


if __name__ == "__main__":
    test_embedding_models()