import lightning.pytorch as pl from transformers import ( AdamW, AutoModel, AutoConfig, get_linear_schedule_with_warmup, ) from transformers.models.bert.modeling_bert import BertLMPredictionHead import torch from torch import nn from loss import PCCL import config class CL_model(pl.LightningModule): def __init__( self, n_batches=None, n_epochs=None, lr=None, mlm_weight=None, **kwargs ): super().__init__() ## Params self.n_batches = n_batches self.n_epochs = n_epochs self.lr = lr self.mlm_weight = mlm_weight self.config = AutoConfig.from_pretrained("emilyalsentzer/Bio_ClinicalBERT") ## Encoder self.bert = AutoModel.from_pretrained( "emilyalsentzer/Bio_ClinicalBERT", return_dict=True ) # Unfreeze layers self.bert_layer_num = sum(1 for _ in self.bert.named_parameters()) self.num_unfreeze_layer = self.bert_layer_num self.ratio_unfreeze_layer = 0.0 if kwargs: for key, value in kwargs.items(): if key == "unfreeze" and isinstance(value, float): assert ( value >= 0.0 and value <= 1.0 ), "ValueError: value must be a ratio between 0.0 and 1.0" self.ratio_unfreeze_layer = value if self.ratio_unfreeze_layer > 0.0: self.num_unfreeze_layer = int( self.bert_layer_num * self.ratio_unfreeze_layer ) for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]: param.requires_grad = False self.lm_head = BertLMPredictionHead(self.config) # self.projector = nn.Linear(self.bert.config.hidden_size, 128) print("Model Initialized!") ## Losses self.cl_loss = PCCL() self.mlm_loss = nn.CrossEntropyLoss() ## Logs self.num_batches = 0 self.train_loss, self.val_loss = 0, 0 self.train_loss_cl, self.val_loss_cl = 0, 0 self.train_loss_mlm, self.val_loss_mlm = 0, 0 self.training_step_outputs, self.validation_step_outputs = [], [] def forward(self, input_ids, attention_mask, masked_indices, eval=False): embs = self.bert(input_ids=input_ids, attention_mask=attention_mask) cls_tokens = embs.pooler_output mask_tokens = [] for idx, value in enumerate(masked_indices): masks = embs.last_hidden_state[idx][value] avg_mask = torch.mean(masks, dim=0) mask_tokens.append(avg_mask) mask_tokens = torch.stack(mask_tokens) cls_concat_mask = torch.cat((cls_tokens, mask_tokens), dim=1) if eval is True: return cls_tokens, mask_tokens, cls_concat_mask mlm_pred = self.lm_head(embs.last_hidden_state) mlm_pred = mlm_pred.view(-1, self.config.vocab_size) return cls_concat_mask, mlm_pred def training_step(self, batch, batch_idx): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] mlm_labels = batch["mlm_labels"] masked_indices = batch["masked_indices"] tags = batch["tags"] scores = batch["scores"] cls_concat_mask, mlm_pred = self(input_ids, attention_mask, masked_indices) loss_cl = self.cl_loss(cls_concat_mask, tags, scores) loss_mlm = self.mlm_loss(mlm_pred, mlm_labels.reshape(-1)) loss = (1 - self.mlm_weight) * loss_cl + self.mlm_weight * loss_mlm logs = {"loss": loss, "loss_cl": loss_cl, "loss_mlm": loss_mlm} self.training_step_outputs.append(logs) self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True) self.num_batches += 1 self.train_loss_cl += loss_cl self.train_loss_mlm += loss_mlm self.train_loss += loss if self.num_batches % config.log_every_n_steps == 0: avg_loss_cl = self.train_loss_cl / self.num_batches avg_loss_mlm = self.train_loss_mlm / self.num_batches avg_loss = self.train_loss / self.num_batches self.log( "train_avg_cl_loss", avg_loss_cl, prog_bar=True, logger=True, sync_dist=True, ) self.log( "train_avg_mlm_loss", avg_loss_mlm, prog_bar=True, logger=True, sync_dist=True, ) self.log( "train_avg_loss", avg_loss, prog_bar=True, logger=True, sync_dist=True ) self.train_loss_cl = 0 self.train_loss_mlm = 0 self.train_loss = 0 self.num_batches = 0 return loss def on_train_epoch_end(self): e_t_avg_loss = ( torch.stack([x["loss"] for x in self.training_step_outputs]) .mean() .detach() .cpu() .numpy() ) self.log( "avg_loss_train_epoch", e_t_avg_loss.item(), on_step=False, on_epoch=True, sync_dist=True, ) e_t_avg_loss_cl = ( torch.stack([x["loss_cl"] for x in self.training_step_outputs]) .mean() .detach() .cpu() .numpy() ) self.log( "avg_loss_cl_train_epoch", e_t_avg_loss_cl.item(), on_step=False, on_epoch=True, sync_dist=True, ) e_t_avg_loss_mlm = ( torch.stack([x["loss_mlm"] for x in self.training_step_outputs]) .mean() .detach() .cpu() .numpy() ) self.log( "avg_loss_mlm_train_epoch", e_t_avg_loss_mlm.item(), on_step=False, on_epoch=True, sync_dist=True, ) print( "train_epoch:", self.current_epoch, "avg_loss:", e_t_avg_loss, "avg_cl_loss:", e_t_avg_loss_cl, "avg_mlm_loss:", e_t_avg_loss_mlm, ) self.training_step_outputs.clear() def validation_step(self, batch, batch_idx): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] mlm_labels = batch["mlm_labels"] masked_indices = batch["masked_indices"] tags = batch["tags"] scores = batch["scores"] cls_concat_mask, mlm_pred = self(input_ids, attention_mask, masked_indices) loss_cl = self.cl_loss(cls_concat_mask, tags, scores) loss_mlm = self.mlm_loss(mlm_pred, mlm_labels.reshape(-1)) loss = (1 - self.mlm_weight) * loss_cl + self.mlm_weight * loss_mlm logs = {"loss": loss, "loss_cl": loss_cl, "loss_mlm": loss_mlm} self.validation_step_outputs.append(logs) self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True) self.num_batches += 1 self.val_loss_cl += loss_cl self.val_loss_mlm += loss_mlm self.val_loss += loss if self.num_batches % config.log_every_n_steps == 0: avg_loss_cl = self.val_loss_cl / self.num_batches avg_loss_mlm = self.val_loss_mlm / self.num_batches avg_loss = self.val_loss / self.num_batches self.log( "val_avg_cl_loss", avg_loss_cl, prog_bar=True, logger=True, sync_dist=True, ) self.log( "val_avg_mlm_loss", avg_loss_mlm, prog_bar=True, logger=True, sync_dist=True, ) self.log( "val_avg_loss", avg_loss, prog_bar=True, logger=True, sync_dist=True, ) self.val_loss_cl = 0 self.val_loss_mlm = 0 self.val_loss = 0 self.num_batches = 0 return loss def on_validation_epoch_end(self): e_v_avg_loss = ( torch.stack([x["loss"] for x in self.validation_step_outputs]) .mean() .detach() .cpu() .numpy() ) self.log( "avg_loss_val_epoch", e_v_avg_loss.item(), on_step=False, on_epoch=True, sync_dist=True, ) e_v_avg_loss_cl = ( torch.stack([x["loss_cl"] for x in self.validation_step_outputs]) .mean() .detach() .cpu() .numpy() ) self.log( "avg_loss_cl_val_epoch", e_v_avg_loss_cl.item(), on_step=False, on_epoch=True, sync_dist=True, ) e_v_avg_loss_mlm = ( torch.stack([x["loss_mlm"] for x in self.validation_step_outputs]) .mean() .detach() .cpu() .numpy() ) self.log( "avg_loss_mlm_val_epoch", e_v_avg_loss_mlm.item(), on_step=False, on_epoch=True, sync_dist=True, ) print( "val_epoch:", self.current_epoch, "avg_loss:", e_v_avg_loss, "avg_cl_loss:", e_v_avg_loss_cl, "avg_mlm_loss:", e_v_avg_loss_mlm, ) self.validation_step_outputs.clear() def configure_optimizers(self): # Optimizer self.trainable_params = [ param for param in self.parameters() if param.requires_grad ] optimizer = AdamW(self.trainable_params, lr=self.lr) # Scheduler warmup_steps = self.n_batches // 3 total_steps = self.n_batches * self.n_epochs - warmup_steps scheduler = get_linear_schedule_with_warmup( optimizer, warmup_steps, total_steps ) return [optimizer], [scheduler]