File size: 848 Bytes
28b14ff b3e9a96 0cde2a5 28b14ff b3e9a96 28b14ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
# --- Model configuration -----------------------------------------------------
# Hugging Face model identifiers loaded below. A lighter embedding alternative
# previously considered here is "all-MiniLM-L6-v2".
EMBEDDING_MODEL_M3 = "BAAI/bge-m3"
EMBEDDING_MODEL_LARGE = "BAAI/bge-large-en-v1.5"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# One device is chosen for every model below, so the printed device is where
# all of them actually run. Only CUDA vs. CPU is considered here.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# NOTE: all models are loaded eagerly at import time. The first import may
# download weights from the Hugging Face Hub and allocate memory on `device`.
embedding_model_m3 = SentenceTransformer(EMBEDDING_MODEL_M3, device=device)
embedding_model_large = SentenceTransformer(EMBEDDING_MODEL_LARGE, device=device)

# Output vector size of each embedding model, as reported by the model itself.
embedding_dim_m3 = embedding_model_m3.get_sentence_embedding_dimension()
embedding_dim_large = embedding_model_large.get_sentence_embedding_dimension()

# Cross-encoder used for reranking. The device is passed explicitly so it
# matches the embedding models instead of depending on CrossEncoder's own
# device auto-selection (which may choose a different backend).
reranker = CrossEncoder(RERANKER_MODEL, device=device)