# CHOPT-NEW / model.py
# Uploaded by sxtforreal ("Upload 5 files", commit 975624b, verified).
import lightning.pytorch as pl
from transformers import (
AdamW,
AutoModel,
AutoConfig,
get_linear_schedule_with_warmup,
)
from transformers.models.bert.modeling_bert import BertLMPredictionHead
import torch
from torch import nn
from loss import PCCL
import config
class CL_model(pl.LightningModule):
    """Bio_ClinicalBERT encoder trained with a contrastive (PCCL) + MLM objective.

    The training loss is ``(1 - mlm_weight) * loss_cl + mlm_weight * loss_mlm``:
    ``loss_cl`` is ``PCCL`` applied to the concatenation of the pooled output
    and the mean hidden state over each example's masked positions, and
    ``loss_mlm`` is cross-entropy of the MLM head over the vocabulary.

    Args:
        n_batches: Training batches per epoch; ``n_batches * n_epochs`` is used
            as the total number of scheduler steps.
        n_epochs: Number of training epochs (sizes the LR schedule).
        lr: Peak learning rate for AdamW.
        mlm_weight: Weight of the MLM term in the combined loss (not validated).
        **kwargs: Optional ``unfreeze`` (a ``float`` in ``[0.0, 1.0]``): the
            fraction of the encoder's parameter tensors, counted from the end
            of ``parameters()``, that stay trainable. ``0.0`` (the default)
            leaves every parameter trainable. Other keys are ignored.

    Raises:
        ValueError: If ``unfreeze`` is a float outside ``[0.0, 1.0]``.
    """

    def __init__(
        self, n_batches=None, n_epochs=None, lr=None, mlm_weight=None, **kwargs
    ):
        super().__init__()
        ## Params
        self.n_batches = n_batches
        self.n_epochs = n_epochs
        self.lr = lr
        self.mlm_weight = mlm_weight
        self.config = AutoConfig.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        ## Encoder
        self.bert = AutoModel.from_pretrained(
            "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
        )
        # Partial unfreezing. NOTE: the "layer" counts below are parameter
        # tensors as yielded by named_parameters(), not transformer layers.
        self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
        self.num_unfreeze_layer = self.bert_layer_num
        self.ratio_unfreeze_layer = 0.0
        value = kwargs.get("unfreeze")
        if isinstance(value, float):
            if not 0.0 <= value <= 1.0:
                raise ValueError("value must be a ratio between 0.0 and 1.0")
            self.ratio_unfreeze_layer = value
        if self.ratio_unfreeze_layer > 0.0:
            self.num_unfreeze_layer = int(
                self.bert_layer_num * self.ratio_unfreeze_layer
            )
        # Freeze all but the last ``num_unfreeze_layer`` tensors. Slicing by
        # the frozen count (instead of ``[:-k]``) makes k == 0 freeze
        # everything rather than silently freezing nothing.
        num_frozen = self.bert_layer_num - self.num_unfreeze_layer
        for param in list(self.bert.parameters())[:num_frozen]:
            param.requires_grad = False
        self.lm_head = BertLMPredictionHead(self.config)
        print("Model Initialized!")
        ## Losses
        self.cl_loss = PCCL()
        self.mlm_loss = nn.CrossEntropyLoss()
        ## Logs
        # Independent step counters for the train and val running averages,
        # so validation (incl. the sanity check) never skews the train window.
        self.num_train_batches, self.num_val_batches = 0, 0
        self.train_loss, self.val_loss = 0, 0
        self.train_loss_cl, self.val_loss_cl = 0, 0
        self.train_loss_mlm, self.val_loss_mlm = 0, 0
        self.training_step_outputs, self.validation_step_outputs = [], []

    def forward(self, input_ids, attention_mask, masked_indices, eval=False):
        """Encode a batch and build the contrastive representation.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            masked_indices: One entry per example, used to index that example's
                hidden states along the sequence axis (presumably a list/tensor
                of masked positions — confirm against the data module).
            eval: If exactly ``True``, skip the MLM head.

        Returns:
            ``(cls_tokens, mask_tokens, cls_concat_mask)`` when ``eval is True``;
            otherwise ``(cls_concat_mask, mlm_pred)`` with ``mlm_pred`` reshaped
            to ``(-1, vocab_size)``.
        """
        embs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = embs.last_hidden_state
        cls_tokens = embs.pooler_output
        # Average the hidden states at each example's masked positions.
        # NOTE(review): an example with no masked positions yields NaN here.
        mask_tokens = torch.stack(
            [
                torch.mean(hidden[idx][value], dim=0)
                for idx, value in enumerate(masked_indices)
            ]
        )
        cls_concat_mask = torch.cat((cls_tokens, mask_tokens), dim=1)
        if eval is True:
            return cls_tokens, mask_tokens, cls_concat_mask
        mlm_pred = self.lm_head(hidden)
        mlm_pred = mlm_pred.view(-1, self.config.vocab_size)
        return cls_concat_mask, mlm_pred

    def _compute_losses(self, batch):
        """Run the model on ``batch`` and return ``(loss, loss_cl, loss_mlm)``."""
        cls_concat_mask, mlm_pred = self(
            batch["input_ids"], batch["attention_mask"], batch["masked_indices"]
        )
        loss_cl = self.cl_loss(cls_concat_mask, batch["tags"], batch["scores"])
        loss_mlm = self.mlm_loss(mlm_pred, batch["mlm_labels"].reshape(-1))
        loss = (1 - self.mlm_weight) * loss_cl + self.mlm_weight * loss_mlm
        return loss, loss_cl, loss_mlm

    def _log_running_averages(self, stage, total, total_cl, total_mlm, n):
        """Log ``<stage>_avg_{cl_loss,mlm_loss,loss}`` as ``total / n``."""
        for name, value in (
            ("cl_loss", total_cl),
            ("mlm_loss", total_mlm),
            ("loss", total),
        ):
            self.log(
                f"{stage}_avg_{name}",
                value / n,
                prog_bar=True,
                logger=True,
                sync_dist=True,
            )

    def _log_epoch_averages(self, outputs, stage):
        """Log and print per-epoch means of the step losses in ``outputs``.

        Does nothing when ``outputs`` is empty (``torch.stack`` would raise).
        """
        if not outputs:
            return
        means = {
            key: torch.stack([x[key] for x in outputs]).mean().item()
            for key in ("loss", "loss_cl", "loss_mlm")
        }
        for metric, key in (
            (f"avg_loss_{stage}_epoch", "loss"),
            (f"avg_loss_cl_{stage}_epoch", "loss_cl"),
            (f"avg_loss_mlm_{stage}_epoch", "loss_mlm"),
        ):
            self.log(metric, means[key], on_step=False, on_epoch=True, sync_dist=True)
        print(
            f"{stage}_epoch:",
            self.current_epoch,
            "avg_loss:",
            means["loss"],
            "avg_cl_loss:",
            means["loss_cl"],
            "avg_mlm_loss:",
            means["loss_mlm"],
        )

    def training_step(self, batch, batch_idx):
        """Compute the combined loss and update the train running averages."""
        loss, loss_cl, loss_mlm = self._compute_losses(batch)
        # Only values are needed later; storing/accumulating the live tensors
        # would keep every step's autograd graph referenced for the epoch.
        loss_d, loss_cl_d, loss_mlm_d = loss.detach(), loss_cl.detach(), loss_mlm.detach()
        self.training_step_outputs.append(
            {"loss": loss_d, "loss_cl": loss_cl_d, "loss_mlm": loss_mlm_d}
        )
        self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        self.num_train_batches += 1
        self.train_loss_cl += loss_cl_d
        self.train_loss_mlm += loss_mlm_d
        self.train_loss += loss_d
        if self.num_train_batches % config.log_every_n_steps == 0:
            self._log_running_averages(
                "train",
                self.train_loss,
                self.train_loss_cl,
                self.train_loss_mlm,
                self.num_train_batches,
            )
            self.train_loss_cl = 0
            self.train_loss_mlm = 0
            self.train_loss = 0
            self.num_train_batches = 0
        return loss

    def on_train_epoch_end(self):
        """Log/print epoch-mean training losses and reset the step buffer."""
        self._log_epoch_averages(self.training_step_outputs, "train")
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        """Compute the combined loss and update the val running averages."""
        loss, loss_cl, loss_mlm = self._compute_losses(batch)
        loss_d, loss_cl_d, loss_mlm_d = loss.detach(), loss_cl.detach(), loss_mlm.detach()
        self.validation_step_outputs.append(
            {"loss": loss_d, "loss_cl": loss_cl_d, "loss_mlm": loss_mlm_d}
        )
        self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)
        self.num_val_batches += 1
        self.val_loss_cl += loss_cl_d
        self.val_loss_mlm += loss_mlm_d
        self.val_loss += loss_d
        if self.num_val_batches % config.log_every_n_steps == 0:
            self._log_running_averages(
                "val",
                self.val_loss,
                self.val_loss_cl,
                self.val_loss_mlm,
                self.num_val_batches,
            )
            self.val_loss_cl = 0
            self.val_loss_mlm = 0
            self.val_loss = 0
            self.num_val_batches = 0
        return loss

    def on_validation_epoch_end(self):
        """Log/print epoch-mean validation losses and reset val state."""
        self._log_epoch_averages(self.validation_step_outputs, "val")
        self.validation_step_outputs.clear()
        # Each validation run starts with an empty running-average window so
        # leftover batches (e.g. from the sanity check) don't carry over.
        self.num_val_batches = 0
        self.val_loss, self.val_loss_cl, self.val_loss_mlm = 0, 0, 0

    def configure_optimizers(self):
        """AdamW over trainable params with a per-step linear warmup/decay.

        Warmup lasts ``n_batches // 3`` steps; the LR then decays linearly to
        zero at the final step ``n_batches * n_epochs``.
        """
        # Optimizer
        self.trainable_params = [
            param for param in self.parameters() if param.requires_grad
        ]
        # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
        # weight_decay are set to transformers.AdamW's defaults (torch's own
        # defaults are eps=1e-8, weight_decay=1e-2).
        optimizer = torch.optim.AdamW(
            self.trainable_params, lr=self.lr, eps=1e-6, weight_decay=0.0
        )
        # Scheduler
        warmup_steps = self.n_batches // 3
        # num_training_steps is the total *including* warmup; subtracting the
        # warmup made the LR reach zero before training finished.
        total_steps = self.n_batches * self.n_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer, warmup_steps, total_steps
        )
        # The schedule is measured in optimizer steps, so it must be stepped
        # after every optimizer step; Lightning's default interval is "epoch".
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }