LLM_Model / embedding_model_instance.py
shreekantkalwar's picture
ensemble
0cde2a5
raw
history blame
848 Bytes
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
# --- Embedding model identifiers handed to SentenceTransformer below.
EMBEDDING_MODEL_M3 = "BAAI/bge-m3"
EMBEDDING_MODEL_LARGE = "BAAI/bge-large-en-v1.5"
# Alternative model name, currently disabled:
# EMBEDDING_MODEL = "all-MiniLM-L6-v2"

# CUDA diagnostics, currently disabled:
# print(torch.cuda.get_device_name(0))
# print("CUDA available:", torch.cuda.is_available())
# print("Current device:", torch.cuda.current_device())

# --- Device selection: use CUDA when PyTorch reports it, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- bge-m3 encoder and the embedding dimension it reports.
# NOTE: the models are instantiated at import time, so importing this module
# triggers the (potentially slow) model loading.
embedding_model_m3 = SentenceTransformer(EMBEDDING_MODEL_M3, device=device)
embedding_dim_m3 = embedding_model_m3.get_sentence_embedding_dimension()

# --- bge-large-en-v1.5 encoder and the embedding dimension it reports.
embedding_model_large = SentenceTransformer(EMBEDDING_MODEL_LARGE, device=device)
embedding_dim_large = embedding_model_large.get_sentence_embedding_dimension()

# --- Cross-encoder reranker.
# No device argument is passed here, so the library's own default device
# selection applies rather than the `device` chosen above.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")