# embeddings/embedding_models.py
"""
Multiple embedding model implementations
"""
import numpy as np
from typing import List, Optional
import torch
from sentence_transformers import SentenceTransformer
class EmbeddingManager:
    """Load a sentence-transformers model and encode texts into embeddings.

    If the requested model cannot be loaded, falls back to the default
    ``all-MiniLM-L6-v2`` model; if the default itself fails, the load
    error is re-raised instead of pointlessly retrying the same model.
    """

    # Fallback model used when the requested model fails to load.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model_name = model_name
        # Set by _load_model(); None only if loading has not happened yet.
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the configured model, falling back to the default on failure.

        Raises:
            Exception: whatever SentenceTransformer raises, when even the
                default model cannot be loaded.
        """
        try:
            print(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            print(f"Model loaded successfully: {self.model_name}")
        except Exception as e:
            print(f"Failed to load model {self.model_name}: {e}")
            if self.model_name == self.DEFAULT_MODEL:
                # The default itself failed -- retrying the identical load
                # cannot succeed, so surface the original error.
                raise
            self.model_name = self.DEFAULT_MODEL
            self.model = SentenceTransformer(self.model_name)
            print(f"Using fallback model: {self.model_name}")

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into a ``(len(texts), dim)`` numpy array.

        A single ``str`` is also accepted and treated as a one-element batch.

        Args:
            texts: Texts to embed (list of strings, or one string).
            batch_size: Batch size passed to the underlying model.

        Raises:
            ValueError: If the embedding model is not loaded.
        """
        if not self.model:
            raise ValueError("Embedding model not loaded")
        if isinstance(texts, str):
            texts = [texts]
        try:
            return self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
        except Exception as e:
            print(f"Embedding encoding error: {e}")
            raise

    def get_embedding_dimensions(self) -> int:
        """Return the dimensionality of the embeddings."""
        # Prefer the model's own metadata over a throwaway test encode;
        # fall back to probing only when the accessor is missing or
        # returns a falsy value.
        getter = getattr(self.model, "get_sentence_embedding_dimension", None)
        if getter is not None:
            dim = getter()
            if dim:
                return dim
        return self.encode(["test"]).shape[1]

    def get_model_info(self) -> dict:
        """Return name, dimensions, and max sequence length of the model."""
        return {
            "model_name": self.model_name,
            "dimensions": self.get_embedding_dimensions(),
            # Not all models expose max_seq_length; 512 is a common default.
            "max_sequence_length": getattr(self.model, 'max_seq_length', 512),
        }
class MultiEmbeddingManager:
    """Registry of EmbeddingManager instances keyed by model name.

    Every model loaded so far stays cached, and the most recently
    requested one is tracked as the "current" model.
    """

    def __init__(self):
        self.models = {}           # model name -> EmbeddingManager
        self.current_model = None  # manager returned by the last request

    def load_model(self, model_name: str) -> EmbeddingManager:
        """Return the manager for ``model_name``, creating it on first use."""
        manager = self.models.get(model_name)
        if manager is None:
            manager = EmbeddingManager(model_name)
            self.models[model_name] = manager
        self.current_model = manager
        return manager

    def get_model(self, model_name: str = None) -> EmbeddingManager:
        """Resolve a manager: by name, else the current one, else the default."""
        if model_name:
            return self.load_model(model_name)
        if self.current_model:
            return self.current_model
        return self.load_model("all-MiniLM-L6-v2")

    def list_loaded_models(self) -> List[str]:
        """Names of all models loaded so far."""
        return list(self.models)
# Quick test function
def test_embedding_models():
    """Smoke-test every model listed in config.vector_config.EMBEDDING_MODELS.

    For each configured model: load it, report its embedding dimensions and
    max sequence length, then encode a small batch of sample texts.  Each
    model's failure is caught and reported individually so one bad model
    does not abort the whole sweep.
    """
    # Project-local config; presumably maps model name -> model metadata
    # (only the keys are used here) -- TODO confirm against config module.
    from config.vector_config import EMBEDDING_MODELS
    multi_manager = MultiEmbeddingManager()
    # Representative domain texts, used purely as an encoding smoke test.
    test_texts = [
        "Deep learning for medical image analysis",
        "Transformer architectures in genomics",
        "AI-driven drug discovery methods"
    ]
    print("π§ͺ Testing Embedding Models")
    print("=" * 50)
    for model_name in EMBEDDING_MODELS.keys():
        try:
            print(f"\n㪠Testing: {model_name}")
            manager = multi_manager.load_model(model_name)
            info = manager.get_model_info()
            print(f" Dimensions: {info['dimensions']}")
            print(f" Max sequence length: {info['max_sequence_length']}")
            # Test encoding
            embeddings = manager.encode(test_texts)
            print(f" Embedding shape: {embeddings.shape}")
            print(f" β {model_name} working correctly")
        except Exception as e:
            # Report and continue with the next configured model.
            print(f" β {model_name} failed: {e}")
# Allow running this module directly as a standalone smoke test.
if __name__ == "__main__":
    test_embedding_models()