File size: 4,416 Bytes
1367957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# embeddings/embedding_models.py
"""

Multiple embedding model implementations

"""

import numpy as np
from typing import List, Optional
import torch
from sentence_transformers import SentenceTransformer


class EmbeddingManager:
    """Manager for a single sentence-transformers embedding model.

    The model is loaded eagerly on construction; if the requested model
    fails to load, the known-good default model is loaded instead.
    """

    # Known-good model used as the fallback when a requested model fails.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(self, model_name: str = DEFAULT_MODEL):
        self.model_name = model_name
        self.model = None  # set by _load_model()
        self._dimensions: Optional[int] = None  # lazily cached embedding size
        self._load_model()

    def _load_model(self):
        """Load the configured embedding model, falling back to the default.

        Raises:
            Exception: Propagates the loader's error when the requested model
                *is* the default (retrying the same name would fail again),
                or when the fallback itself fails to load.
        """
        try:
            print(f"πŸ”§ Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            print(f"βœ… Model loaded successfully: {self.model_name}")
        except Exception as e:
            print(f"❌ Failed to load model {self.model_name}: {e}")
            if self.model_name == self.DEFAULT_MODEL:
                # The failed model already was the fallback; reloading the
                # same name would just fail again, so surface the error.
                raise
            # Fallback to default model
            self.model_name = self.DEFAULT_MODEL
            self.model = SentenceTransformer(self.model_name)
            print(f"πŸ”„ Using fallback model: {self.model_name}")

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into embeddings.

        Args:
            texts: Texts to embed. A bare string is accepted and treated
                as a single-element list.
            batch_size: Number of texts encoded per forward pass.

        Returns:
            Array of shape ``(len(texts), dimensions)``.

        Raises:
            ValueError: If no model is loaded.
            Exception: Re-raises any error from the underlying encoder.
        """
        # `is None` rather than truthiness: the model derives from
        # nn.Sequential, whose falsiness is driven by __len__, not presence.
        if self.model is None:
            raise ValueError("Embedding model not loaded")

        if isinstance(texts, str):
            texts = [texts]

        try:
            return self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
        except Exception as e:
            print(f"❌ Embedding encoding error: {e}")
            raise

    def get_embedding_dimensions(self) -> int:
        """Return the embedding dimensionality (cached after the first call)."""
        if self._dimensions is None:
            # Prefer the model's own metadata over running an encode; fall
            # back to probing with a throwaway encode if it is unavailable.
            getter = getattr(self.model, "get_sentence_embedding_dimension", None)
            dim = getter() if callable(getter) else None
            if dim is None:
                dim = self.encode(["test"]).shape[1]
            self._dimensions = int(dim)
        return self._dimensions

    def get_model_info(self) -> dict:
        """Get information about the current model.

        Returns:
            Dict with ``model_name``, ``dimensions``, and
            ``max_sequence_length`` (512 if the model exposes none).
        """
        return {
            "model_name": self.model_name,
            "dimensions": self.get_embedding_dimensions(),
            "max_sequence_length": getattr(self.model, 'max_seq_length', 512)
        }


class MultiEmbeddingManager:
    """Registry that loads and switches between multiple embedding models."""

    # Model loaded when get_model() is called before anything else.
    DEFAULT_MODEL = "all-MiniLM-L6-v2"

    def __init__(self):
        # Cache of model_name -> EmbeddingManager so each model loads once.
        self.models = {}
        self.current_model = None

    def load_model(self, model_name: str) -> "EmbeddingManager":
        """Load (or reuse a cached) model and make it the current one.

        Args:
            model_name: Name understood by sentence-transformers.

        Returns:
            The manager wrapping the requested model.
        """
        if model_name not in self.models:
            self.models[model_name] = EmbeddingManager(model_name)

        self.current_model = self.models[model_name]
        return self.current_model

    def get_model(self, model_name: Optional[str] = None) -> "EmbeddingManager":
        """Return the named model, the current one, or the default.

        Args:
            model_name: If given, that model is loaded and returned.
                Otherwise the current model is returned, and if none has
                been loaded yet the default model is loaded.
        """
        if model_name:
            return self.load_model(model_name)
        if self.current_model is not None:
            return self.current_model
        # Nothing loaded yet: fall back to the default model.
        return self.load_model(self.DEFAULT_MODEL)

    def list_loaded_models(self) -> List[str]:
        """List the names of all currently loaded models."""
        return list(self.models.keys())


# Quick test function
def test_embedding_models():
    """Smoke-test every embedding model listed in the vector config."""
    from config.vector_config import EMBEDDING_MODELS

    registry = MultiEmbeddingManager()
    sample_texts = [
        "Deep learning for medical image analysis",
        "Transformer architectures in genomics",
        "AI-driven drug discovery methods"
    ]

    print("πŸ§ͺ Testing Embedding Models")
    print("=" * 50)

    for model_name in EMBEDDING_MODELS:
        try:
            print(f"\nπŸ”¬ Testing: {model_name}")
            mgr = registry.load_model(model_name)
            details = mgr.get_model_info()

            print(f"   Dimensions: {details['dimensions']}")
            print(f"   Max sequence length: {details['max_sequence_length']}")

            # Exercise the encoder on the sample batch.
            embeddings = mgr.encode(sample_texts)
            print(f"   Embedding shape: {embeddings.shape}")
            print(f"   βœ… {model_name} working correctly")

        except Exception as e:
            print(f"   ❌ {model_name} failed: {e}")


# Run the embedding-model smoke test when executed as a script.
if __name__ == "__main__":
    test_embedding_models()