File size: 1,421 Bytes
51349bc
b098829
 
 
40fb62e
51349bc
 
 
 
 
cf0f337
51349bc
 
 
 
b098829
c48531b
51349bc
 
 
 
 
40fb62e
51349bc
 
40fb62e
 
 
 
b098829
40fb62e
b098829
40fb62e
b098829
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import torch
from typing import List
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

class EmbeddingManager:
    """Load a SentenceTransformer model and produce L2-normalized text embeddings."""

    def __init__(self, model_name: str = "pritamdeka/S-BioBERT-snli-multinli-stsb"):
        """Initialize the manager and eagerly load the embedding model.

        Args:
            model_name: Hugging Face model id or local path handed to
                ``SentenceTransformer``.
        """
        self.model_name = model_name
        self.model = None  # populated by load_model() below
        # Prefer GPU when one is visible; SentenceTransformer accepts a device string.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.load_model()

    def load_model(self):
        """Load (or reload) the SentenceTransformer model onto ``self.device``."""
        print("Loading embedding model:", self.model_name)
        print('Using device', self.device)
        self.model = SentenceTransformer(model_name_or_path=self.model_name, device=self.device)
        print("Model loaded.")

    def get_model(self):
        """Return the underlying SentenceTransformer instance (``None`` if unloaded)."""
        return self.model

    def embed_texts(self, texts: List[str], batch_size: int = 16) -> np.ndarray:
        """Embed texts into a ``(len(texts), dim)`` array of unit-norm vectors.

        Args:
            texts: Texts to embed; may be empty.
            batch_size: Number of texts encoded per model call.

        Returns:
            2-D float array whose rows are L2-normalized embeddings.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        # Bug fix: with an empty input the loop never runs and np.vstack([])
        # raises ValueError; return a correctly-shaped empty array instead.
        if not texts:
            dim = self.model.get_sentence_embedding_dimension()
            return np.empty((0, dim), dtype=np.float32)

        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding texts"):
            batch = texts[i:i + batch_size]
            emb = self.model.encode(batch, batch_size=batch_size, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)
            embeddings.extend(emb)
        return np.vstack(embeddings)

    def embed_query(self, text: str) -> np.ndarray:
        """Embed a single query string into a flat 1-D unit-norm vector.

        Raises:
            RuntimeError: If the model has not been loaded (same contract as
                ``embed_texts``; previously this raised an opaque AttributeError).
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        return self.model.encode(text, convert_to_numpy=True, normalize_embeddings=True).flatten()