# NOTE(review): removed non-code page-scrape artifacts (viewer chrome,
# file-size line, and line-number gutter) that preceded this module.
# embeddings/embedding_models.py
"""
Multiple embedding model implementations
"""
import numpy as np
from typing import List, Optional
import torch
from sentence_transformers import SentenceTransformer
class EmbeddingManager:
    """Manager for a single sentence-transformers embedding model.

    The model named by ``model_name`` is loaded eagerly at construction;
    if loading fails, the manager falls back to ``all-MiniLM-L6-v2``.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model_name = model_name
        # Loaded SentenceTransformer instance (set by _load_model).
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the configured model, falling back to the default on failure."""
        try:
            print(f"🔧 Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            print(f"✅ Model loaded successfully: {self.model_name}")
        except Exception as e:
            print(f"❌ Failed to load model {self.model_name}: {e}")
            if self.model_name == "all-MiniLM-L6-v2":
                # The requested model *is* the fallback; retrying the same
                # load would fail again, so surface the original error.
                raise
            # Fall back to the known-good default model.
            self.model_name = "all-MiniLM-L6-v2"
            self.model = SentenceTransformer(self.model_name)
            print(f"🔄 Using fallback model: {self.model_name}")

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into a 2-D embedding array of shape (len(texts), dim).

        Args:
            texts: Texts to embed; a bare string is treated as one text.
            batch_size: Batch size forwarded to the underlying model.

        Raises:
            ValueError: If no model is loaded.
        """
        if self.model is None:
            raise ValueError("Embedding model not loaded")
        if isinstance(texts, str):
            texts = [texts]
        try:
            return self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
        except Exception as e:
            print(f"❌ Embedding encoding error: {e}")
            raise

    def get_embedding_dimensions(self) -> int:
        """Return the embedding dimensionality (probed once, then cached)."""
        if getattr(self, "_dimensions", None) is None:
            # Encode a tiny probe text rather than trusting model metadata.
            self._dimensions = int(self.encode(["test"]).shape[1])
        return self._dimensions

    def get_model_info(self) -> dict:
        """Return name, dimensionality and max sequence length of the model."""
        return {
            "model_name": self.model_name,
            "dimensions": self.get_embedding_dimensions(),
            # Not every model exposes max_seq_length; 512 is a safe default.
            "max_sequence_length": getattr(self.model, "max_seq_length", 512),
        }
class MultiEmbeddingManager:
    """Registry that lazily loads and switches between embedding models."""

    def __init__(self):
        # Cache of model_name -> EmbeddingManager, so each model loads once.
        self.models = {}
        # Most recently loaded/requested manager, or None before first load.
        self.current_model = None

    def load_model(self, model_name: str) -> "EmbeddingManager":
        """Load (or reuse) the named model and make it the current one."""
        if model_name not in self.models:
            self.models[model_name] = EmbeddingManager(model_name)
        self.current_model = self.models[model_name]
        return self.current_model

    def get_model(self, model_name: Optional[str] = None) -> "EmbeddingManager":
        """Return the named model, else the current model, else the default.

        Args:
            model_name: Specific model to fetch; None means "whatever is
                current", loading the default model if nothing is loaded yet.
        """
        if model_name:
            return self.load_model(model_name)
        if self.current_model:
            return self.current_model
        # Nothing loaded yet: fall back to the default model.
        return self.load_model("all-MiniLM-L6-v2")

    def list_loaded_models(self) -> List[str]:
        """Return the names of all models loaded so far."""
        return list(self.models.keys())
# Quick test function
def test_embedding_models():
    """Smoke-test every model listed in ``config.vector_config.EMBEDDING_MODELS``.

    For each configured model: load it, report its dimensions and max
    sequence length, and verify a round-trip encode of sample texts.
    Failures are reported per-model rather than aborting the whole run.
    """
    from config.vector_config import EMBEDDING_MODELS

    multi_manager = MultiEmbeddingManager()
    test_texts = [
        "Deep learning for medical image analysis",
        "Transformer architectures in genomics",
        "AI-driven drug discovery methods",
    ]
    print("🧪 Testing Embedding Models")
    print("=" * 50)
    for model_name in EMBEDDING_MODELS:
        try:
            print(f"\n🔬 Testing: {model_name}")
            manager = multi_manager.load_model(model_name)
            info = manager.get_model_info()
            print(f"   Dimensions: {info['dimensions']}")
            print(f"   Max sequence length: {info['max_sequence_length']}")
            # Verify encoding actually works and report the output shape.
            embeddings = manager.encode(test_texts)
            print(f"   Embedding shape: {embeddings.shape}")
            print(f"   ✅ {model_name} working correctly")
        except Exception as e:
            print(f"   ❌ {model_name} failed: {e}")
# Run the smoke test when executed as a script.
if __name__ == "__main__":
    test_embedding_models()