import argparse
import json
import os
import pickle
import random
import shutil

import scanpy as sc
import torch
import torch.nn.functional as F
import yaml
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from src.data.dataset_ds import DatasetMultiPad
from src.models.chromfd_mixer import PretrainModelMambaLM
from src.utils.model_utils import ModelUtils
from src.utils.tb_utils import setup_logging


def warmup_lambda(current_step, warmup_steps=1000):
    """Linear learning-rate warm-up followed by a constant multiplier of 1.0."""
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0


def load_data(file_path):
    if file_path.endswith('.h5ad'):
        print(f"Reading h5ad file from {file_path}")
        adata = sc.read_h5ad(file_path)
        return adata
    raise ValueError("Unsupported file format. Please provide a .h5ad file.")


def init_weight(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        torch.nn.init.zeros_(m.bias)


class FocalLoss(torch.nn.Module):
    """
    Focal Loss as described in https://arxiv.org/abs/1708.02002
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # Balance factor
        self.gamma = gamma  # Modulating factor
        self.reduction = reduction  # Reduction method: 'mean', 'sum', 'none'

    def forward(self, logits, targets):
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # Probability of the target class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


class FinetuneModelMambaCellType(PretrainModelMambaLM):
    """Pre-trained Mamba backbone with a cell-type classification head on top."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        max_length = self.model_args["max_length"]
        self.post_backbone_dropout = torch.nn.Dropout(p=0.3)
        # Collapse the embedding dimension to a single scalar per position.
        self.feature_projection = torch.nn.Sequential(
            torch.nn.Linear(self.model_args["embedding_dim"], 256),
            torch.nn.GELU(),
            torch.nn.Dropout(p=0.3),
            torch.nn.Linear(256, 1),
            torch.nn.GELU()
        )
        self.feature_projection.apply(init_weight)
        # Map the per-position features (length = max_length) to cell-type logits.
        in_feature = max_length
        self.ft_cell_type_projection = torch.nn.Sequential(
            torch.nn.Linear(in_feature, 1024),
            torch.nn.GELU(),
            torch.nn.Dropout(p=0.3),
            torch.nn.Linear(1024, 512),
            torch.nn.GELU(),
            torch.nn.Dropout(p=0.3),
            torch.nn.Linear(512, 128),
            torch.nn.GELU(),
            torch.nn.Dropout(p=0.3),
            torch.nn.Linear(128, self.model_args["cell_type_num"])
        )
        self.ft_cell_type_projection.apply(init_weight)
        # Freeze the pre-training mask-token head; it is not used for cell-type classification.
        for param in self.mask_token_prediction.parameters():
            param.requires_grad = False

    def forward(self, value, chromosome, hg38_start, hg38_end, **kwargs):
        x = self.embedding(value, chromosome.long(), hg38_start.long(), hg38_end.long())
        x = self.backbone(x)                    # (batch, max_length, embedding_dim)
        x = self.feature_projection(x)          # (batch, max_length, 1)
        x = torch.squeeze(x, dim=-1)            # (batch, max_length)
        x = self.post_backbone_dropout(x)
        x_cell_type_prediction = self.ft_cell_type_projection(x)  # (batch, cell_type_num)
        return x_cell_type_prediction


def evaluate_finetune_model(model, dataloader, criterion, device):
    model.eval()
    eval_loss = 0.0
    eval_steps = 0
    cell_type_label_list = []
    cell_type_pred_list = []
    with torch.no_grad():
        for val_batch in dataloader:
            value, chromosome, pos_start, pos_end, cell_type = val_batch
            value = value.to(device)
            chromosome = chromosome.to(device)
            cell_type_gpu = cell_type.to(device)
            pos_start = pos_start.to(device)
            pos_end = pos_end.to(device)
            cell_type_output = model(value, chromosome, pos_start, pos_end)
            tmp_loss_cell_type_prediction = criterion(cell_type_output, cell_type_gpu)
            cell_type_pred = torch.argmax(cell_type_output, dim=-1)
            cell_type_label_list.extend(cell_type.detach().cpu().numpy().tolist())
            cell_type_pred_list.extend(cell_type_pred.detach().cpu().numpy().tolist())
            eval_loss += tmp_loss_cell_type_prediction.item()
            eval_steps += 1
    # Compute metrics over the full evaluation set rather than averaging per-batch scores.
    eval_loss = eval_loss / eval_steps
    eval_f1_score = f1_score(cell_type_label_list, cell_type_pred_list, average='macro')
    accuracy = accuracy_score(cell_type_label_list, cell_type_pred_list)
    return eval_loss, eval_f1_score, accuracy, cell_type_label_list, cell_type_pred_list


def cell_type_finetune(
        model,
        finetune_args,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        lr_scheduler,
        device,
        logger
):
    model = model.to(device)
    cell_type_criterion = FocalLoss(alpha=1, gamma=2, reduction='mean')
    step = 0
    best_f1_score = 0.0
    for eph in range(finetune_args.get("epoch")):
        for batch in train_dataloader:
            model.train()
            value, chromosome, pos_start, pos_end, cell_type = batch
            value = value.to(device)
            chromosome = chromosome.to(device)
            pos_start = pos_start.to(device)
            pos_end = pos_end.to(device)
            cell_type = cell_type.to(device)
            cell_type_output = model(value, chromosome, pos_start, pos_end)

            # Compute Focal Loss and update the model.
            loss_cell_type_prediction = cell_type_criterion(cell_type_output, cell_type)
            loss_cell_type_prediction.backward()
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

            if step % finetune_args.get("loss_evaluate", 10) == 0:
                accuracy = torch.sum(torch.argmax(cell_type_output, dim=-1) == cell_type).item() / cell_type.size(0)
                logger.info(
                    f"[Train] loss at epoch {eph} step {step}: {loss_cell_type_prediction.item()}, "
                    f"accuracy: {accuracy:.4f}, lr: {optimizer.param_groups[0]['lr']}"
                )

            if step % finetune_args.get("val_evaluate", 10) == 0:
                eval_loss, eval_f1_score, eval_accuracy, eval_cell_type_label_list, eval_cell_type_pred_list = \
                    evaluate_finetune_model(model, val_dataloader, cell_type_criterion, device)
                test_loss, test_f1_score, test_accuracy, test_cell_type_label_list, test_cell_type_pred_list = \
                    evaluate_finetune_model(model, test_dataloader, cell_type_criterion, device)
                logger.info(
                    f"[Evaluate] loss at epoch {eph} step {step}: {eval_loss}, "
                    f"cell type accuracy: {eval_accuracy:.4f}, f1 score: {eval_f1_score:.4f}, "
                    f"lr: {optimizer.param_groups[0]['lr']:.6f}"
                )
                logger.info(
                    f"[Test] loss at epoch {eph} step {step}: {test_loss}, "
                    f"cell type accuracy: {test_accuracy:.4f}, f1 score: {test_f1_score:.4f}, "
                    f"lr: {optimizer.param_groups[0]['lr']:.6f}"
                )
                if eval_f1_score > best_f1_score:
                    best_f1_score = eval_f1_score
                    # Persist the test-set predictions corresponding to the best validation F1.
                    with open(os.path.join(
                            finetune_args["log_path"], "cell_type_label_pred.pkl"), "wb") as f:
                        pickle.dump((test_cell_type_label_list, test_cell_type_pred_list), f)
                    logger.info(
                        f"[Evaluate] best validation f1_score: {best_f1_score:.4f} at epoch {eph} step {step}, "
                        f"test accuracy: {test_accuracy:.4f}, f1_score: {test_f1_score:.4f}"
                    )
                    torch.save(model.state_dict(), os.path.join(finetune_args["log_path"], "best_model.pt"))
            step += 1
        torch.save(model.state_dict(), os.path.join(finetune_args["log_path"], f"epoch_{eph}.pt"))


def main_finetune():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0,
                        help="local rank passed from distributed launcher")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size for training")
    parser.add_argument("--learning_rate", type=float, required=True, help="learning rate for finetune")
    parser.add_argument("--pretrain_checkpoint_path", type=str, required=True, help="path to pretrain checkpoint")
    parser.add_argument("--pretrain_model_file", type=str, required=True, help="file name of pre-trained model")
    parser.add_argument("--pretrain_config_file", type=str, required=True, help="file name of pre-trained config")
    parser.add_argument("--cell_type_col", type=str, required=True, help="cell type column name")
    parser.add_argument("--epoch", type=int, required=True, help="number of training epochs")
    parser.add_argument("--train_file_path", type=str, required=True, help="train file path")
    parser.add_argument("--test_file_path", type=str, required=True, help="test file path")
    parser.add_argument("--log_path", type=str, required=True, help="log path")
    parser.add_argument("--load_pretrain_ckpt", action=argparse.BooleanOptionalAction, default=True,
                        help="load pre-trained model weights (disable with --no-load_pretrain_ckpt)")
    args = parser.parse_args()

    with open(os.path.join(args.pretrain_checkpoint_path, args.pretrain_config_file), 'r') as file:
        pretrain_config = yaml.safe_load(file)

    device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    pretrain_data_args = pretrain_config["data_args"]
    pretrain_model_args = pretrain_config["model_args"]
    log_path = args.log_path

    chromosome_vocab = ModelUtils.get_chromosome_vocab(
        os.path.join(args.pretrain_checkpoint_path, "chromosome_vocab.yaml")
    )
    pretrain_data_args["chromosome_vocab"] = chromosome_vocab

    # Concatenate train/val and test data so both splits share the same feature space.
    adata_train_val = load_data(args.train_file_path)
    adata_test = load_data(args.test_file_path)
    adata_train_val.obs["tag"] = "train"
    adata_test.obs["tag"] = "test"
    adata_concat = sc.AnnData.concatenate(adata_train_val, adata_test)
    adata_train_val = adata_concat[adata_concat.obs["tag"] == "train"]
    adata_test = adata_concat[adata_concat.obs["tag"] == "test"]
    max_length = adata_concat.shape[1]

    # Build a stable label-to-index mapping over all cell types seen in either split.
    cell_types = sorted(set(adata_train_val.obs[args.cell_type_col].unique().tolist() +
                            adata_test.obs[args.cell_type_col].unique().tolist()))
    cell_type_map = {cell_type: idx for idx, cell_type in enumerate(cell_types)}

    os.makedirs(log_path, exist_ok=True)
    # Keep copies of the pre-training config and chromosome vocabulary alongside the finetune logs.
    shutil.copy(os.path.join(args.pretrain_checkpoint_path, args.pretrain_config_file), log_path)
    shutil.copy(os.path.join(args.pretrain_checkpoint_path, "chromosome_vocab.yaml"), log_path)

    log_file_path = os.path.join(log_path, "finetune.log")
    finetune_logger = setup_logging(log_file_path)
    finetune_logger.info("Finetune logger is configured and ready.")
    finetune_logger.info(f"args from parser: {args}")
    finetune_logger.info(f"max length for cell type finetune: {max_length}")

    with open(os.path.join(log_path, "cell_type_map.json"), "w") as f:
        json.dump(cell_type_map, f)

    # Propagate finetune-specific settings into the pre-training data/model configs.
    pretrain_data_args["cell_type_map"] = cell_type_map
    pretrain_model_args["cell_type_num"] = len(cell_type_map)
    pretrain_data_args["cell_type_col"] = args.cell_type_col
    pretrain_data_args["feature_num"] = adata_train_val.shape[1]
    pretrain_model_args["feature_num"] = adata_train_val.shape[1]
    pretrain_model_args["batch_size"] = args.batch_size
    pretrain_data_args["max_length"] = max_length
pretrain_model_args["max_length"] = max_length pretrain_model_args["device"] = device pretrain_model_args["mask_ratio"] = 0.0 pretrain_data_args["return_batch_label"] = False idx_list = [i for i in range(adata_train_val.X.shape[0])] random.shuffle(idx_list) split_idx = int(len(idx_list) * 0.9) train_idx = idx_list[:split_idx] val_idx = idx_list[split_idx:] adata_train = adata_train_val[train_idx] adata_val = adata_train_val[val_idx] train_dataset = DatasetMultiPad(*[adata_train], **pretrain_data_args) val_dataset = DatasetMultiPad(*[adata_val], **pretrain_data_args) test_dataset = DatasetMultiPad(*[adata_test], **pretrain_data_args) # Print dataset lengths print(f"Train Dataset Length: {len(train_dataset)}") print(f"Validation Dataset Length: {len(val_dataset)}") print(f"Test Dataset Length: {len(test_dataset)}") train_dataloader = DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True ) val_dataloader = DataLoader( val_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True ) test_dataloader = DataLoader( test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True ) model = FinetuneModelMambaCellType(**pretrain_model_args) model = model.to(device) finetune_logger.info(f'Model parameters: {model}') optimizer_params = { "lr": args.learning_rate, "betas": (0.8, 0.999), "eps": 1e-8, "weight_decay": 1e-6 } optimizer = torch.optim.AdamW(model.parameters(), **optimizer_params) lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: warmup_lambda(step, 200)) if args.load_pretrain_ckpt: state_dict = torch.load(str(os.path.join(args.pretrain_checkpoint_path, args.pretrain_model_file))) missing_keys, unexpected_keys = model.load_state_dict(state_dict['module'], strict=False) if missing_keys: print("Missing keys (not found in checkpoint):") for key in missing_keys: print(f" {key}") if unexpected_keys: print("Unexpected keys (found in checkpoint but not in model):") for key in unexpected_keys: print(f" {key}") finetune_config = { "pretrain_checkpoint_path": args.pretrain_checkpoint_path, "pretrain_model_name": args.pretrain_model_file, "pretrain_config_file": args.pretrain_config_file, "loss_evaluate": 20, "val_evaluate": 20, "log_path": log_path, "epoch": args.epoch, } cell_type_finetune( model, finetune_config, train_dataloader, val_dataloader, test_dataloader, optimizer, lr_scheduler, device, finetune_logger ) if __name__ == '__main__': main_finetune()