Dataset: microsoft/ms_marco
This repository contains a Dual Encoder (Two-Tower) model trained on the Microsoft MS MARCO dataset for information retrieval tasks.
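import torch
from model import QryTower, DocTower

# Load the models (map_location="cpu" lets the checkpoints load on machines without a GPU)
embedding_dim = 128
qry_model = QryTower(embedding_dim)
doc_model = DocTower(embedding_dim)
qry_model.load_state_dict(torch.load("qry_tower.pth", map_location="cpu"))
doc_model.load_state_dict(torch.load("doc_tower.pth", map_location="cpu"))

# Switch to inference mode so layers such as dropout behave deterministically
qry_model.eval()
doc_model.eval()

# Get embeddings for query and document.
# preprocessed_query and preprocessed_document are the input tensors produced by your own preprocessing.
with torch.no_grad():
    query_embedding = qry_model(preprocessed_query)
    document_embedding = doc_model(preprocessed_document)

# Calculate similarity along the embedding dimension
similarity = torch.cosine_similarity(query_embedding, document_embedding, dim=-1)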
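For retrieval you usually score one query against many documents, not a single pair. The sketch below is one way to do that, under assumptions: `rank_documents` is a hypothetical helper, not part of this repository, and the random tensors stand in for the outputs of `qry_model` and `doc_model`.

```python
import torch
import torch.nn.functional as F


def rank_documents(query_embedding, document_embeddings, top_k=3):
    """Return (scores, indices) of the top_k documents by cosine similarity.

    query_embedding: tensor of shape (embedding_dim,)
    document_embeddings: tensor of shape (num_docs, embedding_dim)
    """
    # L2-normalize so that a dot product equals cosine similarity
    q = F.normalize(query_embedding, dim=-1)
    d = F.normalize(document_embeddings, dim=-1)
    scores = d @ q  # shape: (num_docs,)
    k = min(top_k, scores.shape[0])
    top = torch.topk(scores, k)
    return top.values, top.indices


# Stand-in embeddings (assumption: real ones would come from the two towers)
torch.manual_seed(0)
query_embedding = torch.randn(128)
document_embeddings = torch.randn(10, 128)
# Make document 4 point in the same direction as the query
document_embeddings[4] = 2.0 * query_embedding

scores, indices = rank_documents(query_embedding, document_embeddings)
print(indices.tolist(), scores.tolist())
```

Because document embeddings do not depend on the query, they can be computed once and cached. Only the query tower needs to run at search time.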
This model was trained for 5 epochs with a batch size of 32 and a learning rate of 0.001.
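The sketch below shows how these hyperparameters might fit into a two-tower training loop. Only the epoch count, batch size, and learning rate come from this description. Everything else is an assumption made for illustration: the `TinyTower` class is a hypothetical stand-in for `QryTower` and `DocTower`, the data is synthetic, and the Adam optimizer and cosine-based triplet loss may not match what was actually used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class TinyTower(nn.Module):
    """Hypothetical stand-in for QryTower / DocTower (not the repository's code)."""

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64), nn.ReLU(), nn.Linear(64, embedding_dim)
        )

    def forward(self, x):
        return self.net(x)


def triplet_loss(q, pos, neg, margin=0.2):
    """Assumed loss: push the positive's cosine similarity above the negative's by a margin."""
    pos_sim = F.cosine_similarity(q, pos, dim=-1)
    neg_sim = F.cosine_similarity(q, neg, dim=-1)
    return F.relu(margin - pos_sim + neg_sim).mean()


# Synthetic (query, relevant doc, irrelevant doc) feature triples
torch.manual_seed(0)
input_dim, embedding_dim = 32, 128
queries = torch.randn(256, input_dim)
positives = queries + 0.1 * torch.randn(256, input_dim)
negatives = torch.randn(256, input_dim)
loader = DataLoader(
    TensorDataset(queries, positives, negatives), batch_size=32, shuffle=True
)

qry_model = TinyTower(input_dim, embedding_dim)
doc_model = TinyTower(input_dim, embedding_dim)
params = list(qry_model.parameters()) + list(doc_model.parameters())
optimizer = torch.optim.Adam(params, lr=0.001)  # optimizer choice is an assumption

epoch_losses = []
for epoch in range(5):
    total = 0.0
    for q, pos, neg in loader:
        optimizer.zero_grad()
        loss = triplet_loss(qry_model(q), doc_model(pos), doc_model(neg))
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total / len(loader))
    print(f"epoch {epoch + 1}: mean loss {epoch_losses[-1]:.4f}")
```

Both towers are updated by the same optimizer, so query and document embeddings are trained to share one space where cosine similarity is meaningful.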
License: MIT