|
|
import lightning.pytorch as pl |
|
|
from transformers import ( |
|
|
AdamW, |
|
|
AutoModel, |
|
|
AutoConfig, |
|
|
get_linear_schedule_with_warmup, |
|
|
) |
|
|
from transformers.models.bert.modeling_bert import BertLMPredictionHead |
|
|
import torch |
|
|
from torch import nn |
|
|
from loss import PCCL |
|
|
import config |
|
|
|
|
|
|
|
|
class CL_model(pl.LightningModule):
    """Bio_ClinicalBERT encoder trained with a joint contrastive (PCCL) and
    masked-language-modeling objective.

    The total loss is a convex combination:
        loss = (1 - mlm_weight) * loss_cl + mlm_weight * loss_mlm

    Args:
        n_batches: number of training batches per epoch (drives the LR schedule).
        n_epochs: number of training epochs (drives the LR schedule).
        lr: peak learning rate for AdamW.
        mlm_weight: weight of the MLM loss, expected in [0, 1].
        **kwargs: optional ``unfreeze=<float>`` — ratio in [0.0, 1.0] of BERT
            parameter tensors (counted from the end of ``parameters()``) left
            trainable; everything before them is frozen.
    """

    def __init__(
        self, n_batches=None, n_epochs=None, lr=None, mlm_weight=None, **kwargs
    ):
        super().__init__()

        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr
        self.mlm_weight = mlm_weight

        self.config = AutoConfig.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )

        # NOTE: a "layer" here is one parameter *tensor*, not one transformer
        # block — the unfreeze ratio is applied over named_parameters().
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        for key, value in kwargs.items():
            if key == "unfreeze" and isinstance(value, float):
                # raise (not assert) so the check survives `python -O`
                if not 0.0 <= value <= 1.0:
                    raise ValueError(
                        "value must be a ratio between 0.0 and 1.0"
                    )
                self.ratio_unfreeze_layer = value
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
            # Freeze everything up to the cutoff. An explicit cutoff index is
            # used because `[:-0]` (ratio rounding to 0 tensors) would freeze
            # nothing instead of everything.
            cutoff = self.bert_layer_num - self.num_unfreeze_layer
            for param in list(self.bert.parameters())[:cutoff]:
                param.requires_grad = False

        # MLM decoder head on top of the encoder's hidden states.
        self.lm_head = BertLMPredictionHead(self.config)

        print("Model Initialized!")

        self.cl_loss = PCCL()
        self.mlm_loss = nn.CrossEntropyLoss()

        # Running-average bookkeeping. Train and val use SEPARATE counters:
        # a single shared counter gets corrupted whenever validation
        # (including Lightning's sanity check) interleaves with training.
        self.num_batches = 0  # train-side counter (name kept for back-compat)
        self.val_num_batches = 0
        self.train_loss, self.val_loss = 0, 0
        self.train_loss_cl, self.val_loss_cl = 0, 0
        self.train_loss_mlm, self.val_loss_mlm = 0, 0
        self.training_step_outputs, self.validation_step_outputs = [], []

    def forward(self, input_ids, attention_mask, masked_indices, eval=False):
        """Encode a batch and build the contrastive representation.

        Args:
            input_ids: token ids for BERT.
            attention_mask: attention mask for BERT.
            masked_indices: per-example indices of masked positions
                (assumed one index tensor/list per example — TODO confirm
                against the collate function).
            eval: when True, return the raw pieces and skip the MLM head.
                (Name shadows the builtin ``eval``; kept for caller
                compatibility.)

        Returns:
            ``(cls_tokens, mask_tokens, cls_concat_mask)`` when ``eval`` is
            True, otherwise ``(cls_concat_mask, mlm_pred)`` where ``mlm_pred``
            is flattened to ``(batch * seq_len, vocab_size)`` for
            CrossEntropyLoss.
        """
        embs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_tokens = embs.pooler_output
        # Average the hidden states at each example's masked positions.
        mask_tokens = torch.stack(
            [
                embs.last_hidden_state[idx][value].mean(dim=0)
                for idx, value in enumerate(masked_indices)
            ]
        )
        cls_concat_mask = torch.cat((cls_tokens, mask_tokens), dim=1)
        if eval is True:
            return cls_tokens, mask_tokens, cls_concat_mask

        mlm_pred = self.lm_head(embs.last_hidden_state)
        return cls_concat_mask, mlm_pred.view(-1, self.config.vocab_size)

    def _compute_losses(self, batch):
        """Shared train/val forward pass.

        Returns ``(loss, loss_cl, loss_mlm)`` — ``loss`` is the weighted
        combination used for backprop.
        """
        cls_concat_mask, mlm_pred = self(
            batch["input_ids"], batch["attention_mask"], batch["masked_indices"]
        )
        loss_cl = self.cl_loss(cls_concat_mask, batch["tags"], batch["scores"])
        loss_mlm = self.mlm_loss(mlm_pred, batch["mlm_labels"].reshape(-1))
        loss = (1 - self.mlm_weight) * loss_cl + self.mlm_weight * loss_mlm
        return loss, loss_cl, loss_mlm

    @staticmethod
    def _epoch_mean(outputs, key):
        """Mean of ``key`` over collected step outputs as a numpy scalar."""
        return torch.stack([x[key] for x in outputs]).mean().detach().cpu().numpy()

    def training_step(self, batch, batch_idx):
        loss, loss_cl, loss_mlm = self._compute_losses(batch)
        # Detach everything stored across steps: accumulating the attached
        # tensors would keep every step's autograd graph alive all epoch.
        logs = {
            "loss": loss.detach(),
            "loss_cl": loss_cl.detach(),
            "loss_mlm": loss_mlm.detach(),
        }
        self.training_step_outputs.append(logs)
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)

        self.num_batches += 1
        self.train_loss_cl += logs["loss_cl"]
        self.train_loss_mlm += logs["loss_mlm"]
        self.train_loss += logs["loss"]

        # Periodically emit window-averaged losses, then reset the window.
        if self.num_batches % config.log_every_n_steps == 0:
            avg_loss_cl = self.train_loss_cl / self.num_batches
            avg_loss_mlm = self.train_loss_mlm / self.num_batches
            avg_loss = self.train_loss / self.num_batches
            self.log(
                "train_avg_cl_loss",
                avg_loss_cl,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "train_avg_mlm_loss",
                avg_loss_mlm,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "train_avg_loss", avg_loss, prog_bar=True, logger=True, sync_dist=True
            )
            self.train_loss_cl = 0
            self.train_loss_mlm = 0
            self.train_loss = 0
            self.num_batches = 0

        return loss

    def on_train_epoch_end(self):
        """Log and print epoch-level mean losses, then free step outputs."""
        e_t_avg_loss = self._epoch_mean(self.training_step_outputs, "loss")
        e_t_avg_loss_cl = self._epoch_mean(self.training_step_outputs, "loss_cl")
        e_t_avg_loss_mlm = self._epoch_mean(self.training_step_outputs, "loss_mlm")
        self.log(
            "avg_loss_train_epoch",
            e_t_avg_loss.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_cl_train_epoch",
            e_t_avg_loss_cl.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_mlm_train_epoch",
            e_t_avg_loss_mlm.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        print(
            "train_epoch:",
            self.current_epoch,
            "avg_loss:",
            e_t_avg_loss,
            "avg_cl_loss:",
            e_t_avg_loss_cl,
            "avg_mlm_loss:",
            e_t_avg_loss_mlm,
        )
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        loss, loss_cl, loss_mlm = self._compute_losses(batch)
        # Same detach discipline as training_step (see comment there).
        logs = {
            "loss": loss.detach(),
            "loss_cl": loss_cl.detach(),
            "loss_mlm": loss_mlm.detach(),
        }
        self.validation_step_outputs.append(logs)
        self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)

        # val_num_batches is distinct from the training counter so the two
        # running averages cannot corrupt each other.
        self.val_num_batches += 1
        self.val_loss_cl += logs["loss_cl"]
        self.val_loss_mlm += logs["loss_mlm"]
        self.val_loss += logs["loss"]

        if self.val_num_batches % config.log_every_n_steps == 0:
            avg_loss_cl = self.val_loss_cl / self.val_num_batches
            avg_loss_mlm = self.val_loss_mlm / self.val_num_batches
            avg_loss = self.val_loss / self.val_num_batches
            self.log(
                "val_avg_cl_loss",
                avg_loss_cl,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "val_avg_mlm_loss",
                avg_loss_mlm,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.log(
                "val_avg_loss",
                avg_loss,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )
            self.val_loss_cl = 0
            self.val_loss_mlm = 0
            self.val_loss = 0
            self.val_num_batches = 0

        return loss

    def on_validation_epoch_end(self):
        """Log and print epoch-level mean validation losses."""
        e_v_avg_loss = self._epoch_mean(self.validation_step_outputs, "loss")
        e_v_avg_loss_cl = self._epoch_mean(self.validation_step_outputs, "loss_cl")
        e_v_avg_loss_mlm = self._epoch_mean(self.validation_step_outputs, "loss_mlm")
        self.log(
            "avg_loss_val_epoch",
            e_v_avg_loss.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_cl_val_epoch",
            e_v_avg_loss_cl.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "avg_loss_mlm_val_epoch",
            e_v_avg_loss_mlm.item(),
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        print(
            "val_epoch:",
            self.current_epoch,
            "avg_loss:",
            e_v_avg_loss,
            "avg_cl_loss:",
            e_v_avg_loss_cl,
            "avg_mlm_loss:",
            e_v_avg_loss_mlm,
        )
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW over trainable params with linear warmup/decay per STEP."""
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        # torch.optim.AdamW: transformers' AdamW is deprecated (and removed
        # in recent transformers releases); the call signature here is the same.
        optimizer = torch.optim.AdamW(self.trainable_params, lr=self.lr)

        # Warm up over the first third of one epoch.
        warmup_steps = self.n_batches // 3
        # num_training_steps is the TOTAL step count, warmup included
        # (per the transformers schedule API).
        total_steps = self.n_batches * self.n_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer, warmup_steps, total_steps
        )
        # interval="step": this schedule is per-batch; Lightning's default
        # would otherwise step it once per epoch, stretching the warmup.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
|
|