# NOTE(review): the lines below are Hugging Face model-page residue left in by
# extraction; commented out so the module parses as valid Python.
# Question Answering
# Transformers
# Safetensors
# GagaLey's picture
# framework
# 7bf4b88
import torch.nn as nn
# ***** Retrieval (Bi-Encoder) Model *****
class Contriever(nn.Module):
    """Bi-encoder wrapper: runs a Transformer-style encoder and mean-pools its
    last hidden states into fixed-size sentence embeddings.

    The wrapped ``encoder`` is expected to accept ``input_ids`` /
    ``attention_mask`` (optionally ``token_type_ids``) keyword arguments and
    return a tuple whose first element is the token-level hidden states.
    """

    def __init__(self, encoder, emb_dim=768):
        super(Contriever, self).__init__()
        self.encoder = encoder
        # Embedding width; stored for reference (presumably matches the
        # encoder's hidden size -- not checked here).
        self.emb_dim = emb_dim

    def mean_pooling(self, token_embs, mask):
        """Average token embeddings over positions where ``mask`` is nonzero.

        Padded positions are zeroed before summing; the divisor is clamped so
        an all-zero mask cannot cause division by zero.
        """
        keep = mask[..., None].bool()
        summed = token_embs.masked_fill(~keep, 0.0).sum(dim=1)
        denom = mask.sum(dim=1)[..., None].clamp(min=1e-9)
        return summed / denom

    def encode_seq(self, input_ids, attention_mask, token_type_ids=None):
        """Encode one tokenized batch and return mean-pooled embeddings."""
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        if token_type_ids is not None:
            kwargs['token_type_ids'] = token_type_ids
        hidden_states = self.encoder(**kwargs)[0]
        return self.mean_pooling(hidden_states, attention_mask)

    def get_text_emb(self, input_ids, attention_mask, token_type_ids):
        """Thin alias over :meth:`encode_seq` used by the batch helpers."""
        return self.encode_seq(input_ids=input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids)

    def _embed_prefix(self, batch, prefix):
        # Pull the three encoder inputs for one field family (e.g. 'q_enc_*').
        return self.get_text_emb(batch[prefix + '_enc_input_ids'],
                                 batch[prefix + '_enc_attention_mask'],
                                 batch[prefix + '_enc_token_type_ids'])

    def eval_batch(self, batch):
        """Embed queries and candidates; returns ``(q_emb, c_emb)``."""
        # q_emb: [batch_size/num_gpus, token_dim]
        # c_emb: [batch_size * num_candidates/num_gpus, token_dim]
        return self._embed_prefix(batch, 'q'), self._embed_prefix(batch, 'c')

    def forward(self, batch):
        """Embed queries, positives and negatives; returns ``(q, pos, neg)``."""
        # q_emb: [batch_size/num_gpus, token_dim]
        # p_emb: [batch_size*max_len/num_gpus, token_dim]
        # n_emb: [batch_size*max_len*num_sampled_negs/num_gpus, token_dim]
        return (self._embed_prefix(batch, 'q'),
                self._embed_prefix(batch, 'pos'),
                self._embed_prefix(batch, 'neg'))
# ***** Reranking Model *****
class NodeRouter(nn.Module):
    """Two-layer MLP that maps per-node score features to a routing score.

    Pipeline: ``fc1 -> ReLU -> fc2 -> ReLU``, applied identically to every
    score tensor it receives.

    Args:
        input_dim: Number of input score features per node (default 2).
        output_dim: Number of output scores per node (default 1).
        emb_dim: Hidden width of the MLP (default 128).
    """

    def __init__(self, input_dim=2, output_dim=1, emb_dim=128):
        super(NodeRouter, self).__init__()
        self.fc1 = nn.Linear(input_dim, emb_dim)
        self.fc2 = nn.Linear(emb_dim, output_dim)
        self.relu = nn.ReLU()

    def _score(self, x):
        # Shared scoring pipeline; previously duplicated verbatim in
        # eval_batch and (twice) in forward.
        # NOTE(review): the trailing ReLU clamps all scores to >= 0, so
        # negative logits are indistinguishable from zero -- confirm intended.
        return self.relu(self.fc2(self.relu(self.fc1(x))))

    def eval_batch(self, batch):
        """Score candidate features; reads ``batch['c_scores']``."""
        return self._score(batch['c_scores'])

    def forward(self, batch):
        """Score positive and negative features.

        Reads ``batch['p_scores']`` and ``batch['n_scores']``; returns the
        pair ``(scores_pos, scores_neg)``.
        """
        return self._score(batch['p_scores']), self._score(batch['n_scores'])