File size: 827 Bytes
888aba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from typing import List
from sentence_transformers import SentenceTransformer

class Encoder:
    """Wrapper around a SentenceTransformer text-embedding model.

    The model checkpoint is loaded once when the object is constructed and
    then reused by every call to :meth:`encode`.
    """

    def __init__(self):
        """Load the embedding checkpoint and store it on ``self.model``."""
        print("Loading embedding model...")
        self.model = SentenceTransformer(
            "KaLM-Embedding/KaLM-embedding-multilingual-mini-instruct-v2.5",
            model_kwargs={"attn_implementation": "sdpa"}
        )
        # NOTE(review): .half() presumably casts the weights to float16 to
        # cut memory use; confirm the target device supports half precision.
        self.model = self.model.half()

    def encode(
            self,
            texts: List[str],
            batch_size: int = 8,
            show_progress_bar: bool = False,
            save_path: str = None):
        """Embed *texts* with the wrapped model and return the result.

        Args:
            texts: Strings to embed.
            batch_size: Forwarded to the model's ``encode`` call.
            show_progress_bar: Forwarded to the model's ``encode`` call.
            save_path: Optional file path. When given (and non-empty), the
                embeddings are also written there with ``torch.save``.
                When omitted, nothing is written.

        Returns:
            Whatever the model's ``encode`` returns when called with
            ``convert_to_tensor=True``.
        """
        embeddings = self.model.encode(
            texts,
            convert_to_tensor=True,
            show_progress_bar=show_progress_bar,
            batch_size=batch_size,
        )

        if save_path:
            # Imported here because the module itself does not import torch;
            # only callers that actually request saving pay for the import.
            import torch

            torch.save(embeddings, save_path)

        return embeddings