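"""Hugging Face encoder wrappers, tokenizer/tensorizer helpers, and factory
functions for the second-pass bi-encoder (BERT and DistilBERT variants)."""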
import logging
import os
import sys
from typing import Tuple

import torch
from torch import Tensor as T
from torch import nn
from transformers import (
    BertConfig,
    BertModel,
    BertTokenizer,
    DistilBertConfig,
    DistilBertModel,
    DistilBertTokenizer,
)

try:
    # transformers' own AdamW is deprecated (and removed in recent releases);
    # fall back to the PyTorch implementation when it is unavailable.
    from transformers.optimization import AdamW
except ImportError:
    from torch.optim import AdamW

# Make the parent directory importable so Data_utils_inf resolves when this
# module is loaded from inside the models/ package.
current_dir = os.path.dirname(__file__)
data_utils_path = os.path.join(current_dir, "..")
sys.path.append(data_utils_path)

from Data_utils_inf import Tensorizer
from .biencoder import BiEncoder, DistilBertBiEncoder

logger = logging.getLogger(__name__)

def count_parameters(model):
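    """Count the trainable (requires_grad) parameters of ``model``."""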
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_bert_biencoder_components(args, inference_only: bool = False, **kwargs):
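    """Build the BERT bi-encoder components: a question encoder, a context
    encoder (optionally frozen), a tensorizer, and, unless ``inference_only``,
    an AdamW optimizer configured from ``args``."""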
    dropout = getattr(args, "dropout", 0.0)
question_encoder = HFBertEncoder.init_encoder(
args.pretrained_model_cfg,
projection_dim=args.projection_dim,
dropout=dropout,
**kwargs
)
ctx_encoder = HFBertEncoder.init_encoder(
args.pretrained_model_cfg,
projection_dim=args.projection_dim,
dropout=dropout,
**kwargs
)
    fix_ctx_encoder = getattr(args, "fix_ctx_encoder", False)
biencoder = BiEncoder(
question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder
)
optimizer = (
get_optimizer(
biencoder,
learning_rate=args.learning_rate,
adam_eps=args.adam_eps,
weight_decay=args.weight_decay,
)
if not inference_only
else None
)
tensorizer = get_bert_tensorizer(args)
return tensorizer, biencoder, optimizer
def get_distilbert_biencoder_components(args, inference_only: bool = False, **kwargs):
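    """Build the DistilBERT bi-encoder components: a question encoder, a
    context encoder (optionally frozen), a tensorizer, and, unless
    ``inference_only``, an AdamW optimizer configured from ``args``."""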
    dropout = getattr(args, "dropout", 0.0)
question_encoder = HFDistilBertEncoder.init_encoder(
args.pretrained_model_cfg,
projection_dim=args.projection_dim,
dropout=dropout,
**kwargs
)
ctx_encoder = HFDistilBertEncoder.init_encoder(
args.pretrained_model_cfg,
projection_dim=args.projection_dim,
dropout=dropout,
**kwargs
)
    fix_ctx_encoder = getattr(args, "fix_ctx_encoder", False)
    biencoder = DistilBertBiEncoder(
        question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder
    )
optimizer = (
get_optimizer(
biencoder,
learning_rate=args.learning_rate,
adam_eps=args.adam_eps,
weight_decay=args.weight_decay,
)
if not inference_only
else None
)
tensorizer = get_distilbert_tensorizer(args)
return tensorizer, biencoder, optimizer
def get_bert_tensorizer(args, tokenizer=None):
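    """Create a BertTensorizer from ``args``, building a tokenizer if none is supplied."""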
if not tokenizer:
tokenizer = get_bert_tokenizer(
args.pretrained_model_cfg, do_lower_case=args.do_lower_case
)
return BertTensorizer(tokenizer, args.sequence_length)
def get_distilbert_tensorizer(args, tokenizer=None):
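    """Create a DistilBertTensorizer from ``args``, building a tokenizer if none is supplied."""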
if not tokenizer:
tokenizer = get_distilbert_tokenizer(
args.pretrained_model_cfg, do_lower_case=args.do_lower_case
)
return DistilBertTensorizer(tokenizer, args.sequence_length)
def get_optimizer(
model: nn.Module,
learning_rate: float = 1e-5,
adam_eps: float = 1e-8,
weight_decay: float = 0.0,
) -> torch.optim.Optimizer:
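    """Build an AdamW optimizer with two parameter groups: weight decay is
    applied everywhere except biases and LayerNorm weights, the standard
    practice when fine-tuning BERT-style models."""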
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps)
return optimizer
def get_bert_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True):
return BertTokenizer.from_pretrained(
pretrained_cfg_name, do_lower_case=do_lower_case
)
def get_distilbert_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True):
    # DistilBERT shares BERT's WordPiece vocabulary, so the stock HF
    # tokenizer is reused unchanged.
return DistilBertTokenizer.from_pretrained(
pretrained_cfg_name, do_lower_case=do_lower_case
)
class HFDistilBertEncoder(DistilBertModel):
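    """DistilBertModel with an optional linear projection head applied to
    the [CLS] token representation."""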
def __init__(self, config, project_dim: int = 0):
        super().__init__(config)
assert config.hidden_size > 0, "Encoder hidden_size can't be zero"
self.encode_proj = (
nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
)
self.init_weights()
    @classmethod
    def init_encoder(
        cls, cfg_name: str, projection_dim: int = 0, dropout: float = 0.1, **kwargs
    ) -> DistilBertModel:
        cfg_name = cfg_name if cfg_name else "distilbert-base-uncased"
        cfg = DistilBertConfig.from_pretrained(cfg_name)
        if dropout != 0:
            # DistilBertConfig names its dropout fields differently from
            # BertConfig: "dropout" and "attention_dropout".
            cfg.dropout = dropout
            cfg.attention_dropout = dropout
        return cls.from_pretrained(
            cfg_name, config=cfg, project_dim=projection_dim, **kwargs
        )
    def forward(self, input_ids: T, attention_mask: T) -> Tuple[T, ...]:
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        sequence_output = outputs.last_hidden_state
        # DistilBERT has no pooler, so the [CLS] token embedding serves as
        # the pooled representation.
        pooled_output = sequence_output[:, 0, :]
        hidden_states = (
            outputs.hidden_states if self.config.output_hidden_states else None
        )
        if self.encode_proj:
            pooled_output = self.encode_proj(pooled_output)
        return sequence_output, pooled_output, hidden_states
def get_out_size(self):
if self.encode_proj:
return self.encode_proj.out_features
return self.config.hidden_size
class HFBertEncoder(BertModel):
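    """BertModel with an optional linear projection head applied to the
    pooled ([CLS]) representation."""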
def __init__(self, config, project_dim: int = 0):
        super().__init__(config)
assert config.hidden_size > 0, "Encoder hidden_size can't be zero"
self.encode_proj = (
nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
)
self.init_weights()
    @classmethod
    def init_encoder(
        cls, cfg_name: str, projection_dim: int = 0, dropout: float = 0.1, **kwargs
    ) -> BertModel:
        cfg_name = cfg_name if cfg_name else "bert-base-uncased"
        cfg = BertConfig.from_pretrained(cfg_name)
        if dropout != 0:
            cfg.attention_probs_dropout_prob = dropout
            cfg.hidden_dropout_prob = dropout
        return cls.from_pretrained(
            cfg_name, config=cfg, project_dim=projection_dim, **kwargs
        )
    def forward(
        self, input_ids: T, token_type_ids: T, attention_mask: T
    ) -> Tuple[T, ...]:
        outputs = super().forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = outputs.pooler_output
        hidden_states = (
            outputs.hidden_states if self.config.output_hidden_states else None
        )
        if self.encode_proj:
            pooled_output = self.encode_proj(pooled_output)
        return sequence_output, pooled_output, hidden_states
def get_out_size(self):
if self.encode_proj:
return self.encode_proj.out_features
return self.config.hidden_size
class DistilBertTensorizer(Tensorizer):
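    """Converts raw text (optionally paired with a title) into fixed-length
    token-id tensors for DistilBERT."""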
def __init__(
self, tokenizer: DistilBertTokenizer, max_length: int, pad_to_max: bool = True
):
self.tokenizer = tokenizer
self.max_length = max_length
self.pad_to_max = pad_to_max
def text_to_tensor(
self, text: str, title: str = None, add_special_tokens: bool = True
):
        # Missing text can arrive from pandas as float("nan"); normalize it.
        if isinstance(text, float):
            text = "nan"
        text = text.strip()
        # The tokenizer's automatic padding is disabled here because its
        # behavior is inconsistent across versions; padding is applied
        # manually below.
        if title:
            token_ids = self.tokenizer.encode(
                title,
                text_pair=text,
                add_special_tokens=add_special_tokens,
                max_length=self.max_length,
                padding=False,
                truncation=True,
            )
        else:
            token_ids = self.tokenizer.encode(
                text,
                add_special_tokens=add_special_tokens,
                max_length=self.max_length,
                padding=False,
                truncation=True,
            )
        # Pad up to, or truncate down to, the fixed sequence length.
        seq_len = self.max_length
        if self.pad_to_max and len(token_ids) < seq_len:
            token_ids = token_ids + [self.tokenizer.pad_token_id] * (
                seq_len - len(token_ids)
            )
        if len(token_ids) > seq_len:
            token_ids = token_ids[0:seq_len]
            # Keep a [SEP] terminator even after hard truncation.
            token_ids[-1] = self.tokenizer.sep_token_id
        return torch.tensor(token_ids)
class BertTensorizer(Tensorizer):
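    """Converts raw text (optionally paired with a title) into fixed-length
    token-id tensors for BERT and exposes small tokenizer helpers."""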
def __init__(
self, tokenizer: BertTokenizer, max_length: int, pad_to_max: bool = True
):
self.tokenizer = tokenizer
self.max_length = max_length
self.pad_to_max = pad_to_max
def text_to_tensor(
self, text: str, title: str = None, add_special_tokens: bool = True
):
        # Missing text can arrive from pandas as float("nan"); normalize it.
        if isinstance(text, float):
            text = "nan"
        text = text.strip()
        # The tokenizer's automatic padding is disabled here because its
        # behavior is inconsistent across versions; padding is applied
        # manually below.
        if title:
            token_ids = self.tokenizer.encode(
                title,
                text_pair=text,
                add_special_tokens=add_special_tokens,
                max_length=self.max_length,
                padding=False,
                truncation=True,
            )
        else:
            token_ids = self.tokenizer.encode(
                text,
                add_special_tokens=add_special_tokens,
                max_length=self.max_length,
                padding=False,
                truncation=True,
            )
        # Pad up to, or truncate down to, the fixed sequence length.
        seq_len = self.max_length
        if self.pad_to_max and len(token_ids) < seq_len:
            token_ids = token_ids + [self.tokenizer.pad_token_id] * (
                seq_len - len(token_ids)
            )
        if len(token_ids) > seq_len:
            token_ids = token_ids[0:seq_len]
            # Keep a [SEP] terminator even after hard truncation.
            token_ids[-1] = self.tokenizer.sep_token_id
        return torch.tensor(token_ids)
def get_pair_separator_ids(self) -> T:
return torch.tensor([self.tokenizer.sep_token_id])
def get_pad_id(self) -> int:
return self.tokenizer.pad_token_id
def get_attn_mask(self, tokens_tensor: T) -> T:
return tokens_tensor != self.get_pad_id()
def is_sub_word_id(self, token_id: int):
token = self.tokenizer.convert_ids_to_tokens([token_id])[0]
return token.startswith("##") or token.startswith(" ##")
    def to_string(self, token_ids, skip_special_tokens=True):
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def set_pad_to_max(self, do_pad: bool):
self.pad_to_max = do_pad
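
# Minimal usage sketch (not part of the original pipeline): it shows how the
# factory functions above fit together. The argument values below are
# illustrative assumptions; run as a module (python -m ...) so the relative
# import of .biencoder resolves.
if __name__ == "__main__":
    from argparse import Namespace

    logging.basicConfig(level=logging.INFO)
    args = Namespace(
        pretrained_model_cfg="bert-base-uncased",
        projection_dim=0,
        dropout=0.1,
        sequence_length=256,
        do_lower_case=True,
        learning_rate=1e-5,
        adam_eps=1e-8,
        weight_decay=0.0,
        fix_ctx_encoder=False,
    )
    tensorizer, biencoder, optimizer = get_bert_biencoder_components(args)
    logger.info("bi-encoder trainable parameters: %d", count_parameters(biencoder))
    question_ids = tensorizer.text_to_tensor("what is a bi-encoder?")
    logger.info("tokenized question shape: %s", tuple(question_ids.shape))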