Remove stale path: examples/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py
Browse files
examples/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py
DELETED
|
@@ -1,197 +0,0 @@
|
|
| 1 |
-
# %%
# Imports for zero-shot cell-embedding extraction with a fine-tuned GeneMamba2.
import torch
from transformers import Trainer
import os

import pyarrow as pa
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt

from torch.utils.data import Dataset
from transformers import AutoTokenizer, TrainingArguments

import argparse

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, MambaForCausalLM

from dotmap import DotMap

import sys
import os
import torch

# Make the project package importable from this notebook-style script.
sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")

from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
from utils import permute_genes_by_expression, build_downstream_dataset, get_last_checkpoint


# Reload project modules so in-session edits to models/utils take effect
# without restarting the kernel.
import importlib
importlib.reload(sys.modules['models'])
importlib.reload(sys.modules['utils'])
|
| 36 |
-
# %%
import scanpy as sc

# import argparse

# parser = argparse.ArgumentParser()
# parser.add_argument("--dataset_name", type=str)

# args2 = parser.parse_args()

# dataset_name = args2.dataset_name

# Downstream dataset to embed; must be one of the preprocessed h5ad files below.
dataset_name = "pbmc12k"

assert dataset_name in ["pbmc12k", "perirhinal_cortex", "covid19"]

# Load the preprocessed AnnData object for the chosen dataset.
adata = sc.read_h5ad(f'/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/processed/{dataset_name}_processed.h5ad')

# Cell-type labels are required for downstream annotation evaluation.
assert "celltype" in adata.obs

print(adata)
|
| 58 |
-
# %%
from transformers import PretrainedConfig

# Backbone hyperparameters matching the pretrained GeneMamba2 checkpoint
# (24 Mamba layers, 512-d hidden size).
config = PretrainedConfig.from_dict({
    "d_model": 512,
    "mamba_layer": 24,
})


# %%
# NOTE(review): an identical model instance is constructed again further down
# in this script; this first construction appears redundant — confirm.
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)
|
| 70 |
-
# %%
# Rank each cell's genes by expression and map them to token ids.
permuted_gene_ids = permute_genes_by_expression(adata, dataset_name, model.tokenizer, model.symbol2id)
permuted_gene_ids

# %%
num_samples = permuted_gene_ids.shape[0]
num_avaliable_gpu = torch.cuda.device_count()  # NOTE(review): "avaliable" is a typo (available); kept for compatibility

# %%
from dotmap import DotMap

# Run configuration with attribute-style access (args.seq_len etc.).
# Only seq_len and batch_size are actually used below in this script.
args = DotMap(
    {
        # "model": "state-spaces/mamba-130m-hf",
        # "tokenizer": "state-spaces/mamba-130m-hf",
        "learning_rate": 5e-5,
        "batch_size": 16,
        "gradient_accumulation_steps": 1,
        "optim": "adamw_torch",
        # "data_path": "/home/cong/study/codeSpace/VSCodeSpace/PythonWorkSpace/TCRPrediction/mamba_transformer/smiles_data.txt",
        # "num_epochs": args2.num_epochs,
        "seq_len": 2048,  # maximum number of gene tokens kept per cell
        "num_samples": num_samples,
        "num_gpus": num_avaliable_gpu,
        "output_dir": "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned",
    }
)


#%%
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)

# Grow the embedding table to match the tokenizer vocabulary (e.g. added [CLS]).
model.resize_token_embeddings()
|
| 104 |
-
#%%
def get_last_checkpoint(output_dir):
    """Return the path of the newest "checkpoint-<step>" entry in *output_dir*.

    Args:
        output_dir: directory containing Trainer-style checkpoint folders.

    Returns:
        Joined path of the checkpoint with the highest step number.

    Raises:
        FileNotFoundError: if *output_dir* holds no "checkpoint-<int>" entries
            (the original raised an opaque IndexError here, and a ValueError
            on non-numeric suffixes such as "checkpoint-best").

    NOTE(review): this shadows the ``get_last_checkpoint`` imported from
    ``utils`` at the top of the file — confirm the local override is intended.
    """
    steps = []
    for name in os.listdir(output_dir):
        prefix, sep, suffix = name.partition("-")
        # Accept only canonical "checkpoint-<int>" names; the original crashed
        # on entries like "checkpoint-best" or a bare "checkpoint".
        if prefix == "checkpoint" and sep and suffix.isdigit():
            steps.append(int(suffix))
    if not steps:
        raise FileNotFoundError(f"no checkpoint-<step> entries found in {output_dir}")
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")
|
| 114 |
-
# Directory holding the fine-tuned checkpoints for this dataset.
ckpt_pth = f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned/{dataset_name}"

last_checkpoint = get_last_checkpoint(ckpt_pth)
# Fine-tuned weights live in safetensors format inside the checkpoint folder.
state_dict_pth = os.path.join(last_checkpoint, "model.safetensors")

print(state_dict_pth)
|
| 121 |
-
#%%
from safetensors.torch import load_file

# Load the fine-tuned weights and push them into the wrapped backbone
# (model.model is the inner module the state dict was saved from).
state_dict = load_file(state_dict_pth)

model.model.load_state_dict(state_dict)
|
| 129 |
-
# %%
# Truncate each cell's gene-token sequence to the model context length.
input_data = permuted_gene_ids[:, :args.seq_len]

# %%
input_data.shape

#%%
# check if cls_token in the tokenizer:
if model.tokenizer.cls_token_id is None:
    model.tokenizer.add_special_tokens({'cls_token': '[CLS]'})

#%%
# Prepend a [CLS] token id to every row; its hidden state is later used as
# the whole-cell representation.
input_data = np.hstack([np.array([model.tokenizer.cls_token_id for _ in range(input_data.shape[0])]).reshape(-1, 1), input_data])

#%%
input_data.shape
|
| 146 |
-
|
| 147 |
-
#%%
|
| 148 |
-
from torch.utils.data import DataLoader, Dataset
|
| 149 |
-
|
| 150 |
-
class GeneDataset(Dataset):
    """Minimal map-style dataset serving rows of a pre-tokenized id matrix."""

    def __init__(self, data):
        # Keep a reference only; rows are sliced lazily in __getitem__.
        self.data = data

    def __len__(self):
        # One sample per row (cell).
        return len(self.data)

    def __getitem__(self, idx):
        # A single row of token ids.
        return self.data[idx]
|
| 160 |
-
|
| 161 |
-
#%%
# No shuffling: output embedding order must match adata row order.
all_dataset = GeneDataset(input_data)
all_loader = DataLoader(all_dataset, batch_size = args.batch_size, shuffle=False)
| 165 |
-
# %%
def cell_embeddings(data_loader, model):
    """Extract per-cell [CLS] embeddings from *model* over *data_loader*.

    Args:
        data_loader: iterable yielding LongTensor batches of token ids,
            each row starting with the [CLS] token id.
        model: callable whose output exposes a ``hidden_states`` tensor of
            shape (batch, seq_len, d_model); must expose ``.device``.

    Returns:
        np.ndarray of shape (num_cells, d_model) holding the [CLS]
        hidden state of every cell, in loader order.
    """
    cell_repr = []

    # Inference only: disable autograd so no graph is built during the
    # forward pass. The original relied on .detach() after the fact, which
    # still paid the graph-construction cost on every batch.
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            batch = batch.to(model.device)
            outputs = model(batch)

            # Position 0 is the prepended [CLS] token; its hidden state is
            # the whole-cell representation.
            cls_representation = outputs.hidden_states[:, 0, :]
            cell_repr.append(cls_representation.detach().cpu().numpy())

            if i % 10 == 0:
                print(f"Processed {i} batches")

    cell_repr = np.concatenate(cell_repr)
    return cell_repr
|
| 184 |
-
# %%
# Move to GPU and switch to inference mode before extracting embeddings.
model = model.to("cuda")
model.eval()

# %%
cell_repr = cell_embeddings(all_loader, model)
cell_repr.shape


# cell_repr = np.concatenate(cell_repr)
# %%
# Persist the per-cell [CLS] embeddings for downstream annotation analysis.
np.save(f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/embeddings/fine-tuned/{dataset_name}_cell_repr.npy", cell_repr)

# %%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|