import torch.nn as nn


class Contriever(nn.Module):
    """Bi-encoder that mean-pools a transformer encoder's token embeddings
    into fixed-size sentence embeddings."""

    def __init__(self, encoder, emb_dim=768):
        super().__init__()
        self.encoder = encoder
        self.emb_dim = emb_dim
    def mean_pooling(self, token_embs, mask):
        # Zero out embeddings at padded positions, then average over the
        # remaining tokens. Clamping the token count to at least 1 avoids
        # division by zero on fully padded rows and stays valid for the
        # integer dtype of attention masks.
        token_embs = token_embs.masked_fill(~mask[..., None].bool(), 0.0)
        sentence_embeddings = token_embs.sum(dim=1) / mask.sum(dim=1)[..., None].clamp(min=1)
        return sentence_embeddings
    def encode_seq(self, input_ids, attention_mask, token_type_ids=None):
        # token_type_ids is optional: some encoders (e.g. RoBERTa-style
        # models) do not take segment embeddings.
        enc = {'input_ids': input_ids, 'attention_mask': attention_mask}
        if token_type_ids is not None:
            enc['token_type_ids'] = token_type_ids
        outputs = self.encoder(**enc)
        # outputs[0] is the last hidden state: (batch, seq_len, hidden_dim).
        embedded = self.mean_pooling(outputs[0], attention_mask)
        return embedded
    def get_text_emb(self, input_ids, attention_mask, token_type_ids):
        # Thin wrapper around encode_seq, kept for a uniform call site below.
        return self.encode_seq(input_ids=input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids)
    def eval_batch(self, batch):
        # Embed queries and candidate passages for evaluation.
        q_emb = self.get_text_emb(batch['q_enc_input_ids'],
                                  batch['q_enc_attention_mask'],
                                  batch['q_enc_token_type_ids'])
        c_emb = self.get_text_emb(batch['c_enc_input_ids'],
                                  batch['c_enc_attention_mask'],
                                  batch['c_enc_token_type_ids'])
        return q_emb, c_emb
    def forward(self, batch):
        # Embed the query, one positive passage, and one negative passage
        # for contrastive training.
        q_emb = self.get_text_emb(batch['q_enc_input_ids'],
                                  batch['q_enc_attention_mask'],
                                  batch['q_enc_token_type_ids'])
        p_emb = self.get_text_emb(batch['pos_enc_input_ids'],
                                  batch['pos_enc_attention_mask'],
                                  batch['pos_enc_token_type_ids'])
        n_emb = self.get_text_emb(batch['neg_enc_input_ids'],
                                  batch['neg_enc_attention_mask'],
                                  batch['neg_enc_token_type_ids'])
        return q_emb, p_emb, n_emb
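

# Hedged sketch (not part of the original module): one plausible objective for
# the (q_emb, p_emb, n_emb) triplets returned by Contriever.forward. The
# dot-product triplet margin below is an illustrative assumption, not this
# repo's confirmed training loss.
def contriever_triplet_loss(q_emb, p_emb, n_emb, margin=1.0):
    # Score query/passage pairs by inner product, then require the positive
    # score to exceed the negative score by at least `margin`.
    pos_scores = (q_emb * p_emb).sum(dim=-1)
    neg_scores = (q_emb * n_emb).sum(dim=-1)
    return (margin - pos_scores + neg_scores).clamp(min=0).mean()
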
class NodeRouter(nn.Module):
    """Two-layer MLP that maps a small vector of node scores to a single
    routing score."""

    def __init__(self, input_dim=2, output_dim=1, emb_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, emb_dim)
        self.fc2 = nn.Linear(emb_dim, output_dim)
        self.relu = nn.ReLU()
    def eval_batch(self, batch):
        # Score candidate nodes: fc1 -> ReLU -> fc2 -> ReLU. The trailing
        # ReLU keeps the emitted scores non-negative.
        scores_cand = self.relu(self.fc1(batch['c_scores']))
        scores_cand = self.relu(self.fc2(scores_cand))
        return scores_cand
    def forward(self, batch):
        # Run positives and negatives through the same MLP so their scores
        # are directly comparable.
        scores_pos = self.relu(self.fc1(batch['p_scores']))
        scores_neg = self.relu(self.fc1(batch['n_scores']))
        scores_pos = self.relu(self.fc2(scores_pos))
        scores_neg = self.relu(self.fc2(scores_neg))
        return scores_pos, scores_neg
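

# Hedged smoke test (not in the original file): wires both modules up with
# illustrative inputs. The 'bert-base-uncased' checkpoint, the sample query,
# and the tensor shapes below are assumptions for demonstration only.
if __name__ == '__main__':
    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = Contriever(AutoModel.from_pretrained('bert-base-uncased'))

    enc = tokenizer(['where is the Eiffel Tower?'], return_tensors='pt')
    with torch.no_grad():
        emb = model.encode_seq(enc['input_ids'], enc['attention_mask'],
                               enc.get('token_type_ids'))
    print(emb.shape)  # torch.Size([1, 768])

    router = NodeRouter(input_dim=2)
    batch = {'p_scores': torch.randn(4, 2),  # e.g. (retriever score, graph score)
             'n_scores': torch.randn(4, 2)}
    pos, neg = router(batch)
    print(pos.shape, neg.shape)  # torch.Size([4, 1]) torch.Size([4, 1])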