mineself2016 committed on
Commit
59cb3e3
·
verified ·
1 Parent(s): c3f0155

Remove stale path: examples/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py

Browse files
examples/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py DELETED
@@ -1,197 +0,0 @@
# %%
# Standard library.
import argparse  # kept: the commented-out CLI section below may be re-enabled
import os
import sys

# Third-party.
# NOTE: the original imported torch, os, AutoTokenizer and TrainingArguments
# more than once; the duplicates have been removed.
import numpy as np
import pandas as pd
import pyarrow as pa
import torch
from dotmap import DotMap
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from transformers import AutoTokenizer, MambaForCausalLM, Trainer, TrainingArguments

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Project-local: make the gene_mamba package importable.
sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")

from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
from utils import permute_genes_by_expression, build_downstream_dataset, get_last_checkpoint

# Pick up local edits to the project modules without restarting the kernel.
import importlib
importlib.reload(sys.modules['models'])
importlib.reload(sys.modules['utils'])
36
# %%
import scanpy as sc

# Dataset selection. A CLI variant existed here and may be re-enabled:
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--dataset_name", type=str)
#   dataset_name = parser.parse_args().dataset_name
dataset_name = "pbmc12k"

# Only these three downstream datasets have been preprocessed.
assert dataset_name in ["pbmc12k", "perirhinal_cortex", "covid19"]

adata = sc.read_h5ad(f'/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/processed/{dataset_name}_processed.h5ad')

# The annotation task needs ground-truth cell-type labels.
assert "celltype" in adata.obs

print(adata)
57
-
58
# %%
from transformers import PretrainedConfig

# Architecture hyper-parameters of the pretrained GeneMamba2 backbone.
config = PretrainedConfig.from_dict({
    "d_model": 512,
    "mamba_layer": 24,
})


# %%
model = GeneMamba2(
    config,
    model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250",
    tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json",
    args=None,
)

# %%
# Rank genes per cell by expression and map them to token ids.
permuted_gene_ids = permute_genes_by_expression(adata, dataset_name, model.tokenizer, model.symbol2id)
permuted_gene_ids

# %%
num_samples = permuted_gene_ids.shape[0]
# NOTE(review): name misspells "available"; kept because later code reads it.
num_avaliable_gpu = torch.cuda.device_count()
77
-
78
# %%
# Hyper-parameters of the (previously run) fine-tuning job; only seq_len and
# batch_size are actually used below for embedding extraction.
args = DotMap(
    {
        "learning_rate": 5e-5,
        "batch_size": 16,
        "gradient_accumulation_steps": 1,
        "optim": "adamw_torch",
        "seq_len": 2048,
        "num_samples": num_samples,
        "num_gpus": num_avaliable_gpu,
        "output_dir": "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned",
    }
)


# %%
# The backbone was already constructed above with byte-identical arguments, so
# the second GeneMamba2(...) call was redundant (its random-init weights are
# overwritten by the fine-tuned state dict loaded below) and has been removed.
# Only the embedding resize — which accounts for added special tokens — is
# needed here.
model.resize_token_embeddings()
103
-
104
#%%
def get_last_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Shadows the ``get_last_checkpoint`` imported from ``utils`` above — this
    local version is the one used for the rest of the script.

    Parameters
    ----------
    output_dir : str
        Directory containing Hugging Face-style ``checkpoint-<step>`` folders.

    Returns
    -------
    str
        Path to the checkpoint with the largest step number.

    Raises
    ------
    FileNotFoundError
        If *output_dir* contains no ``checkpoint-*`` entries (the original
        raised a bare IndexError here).
    """
    steps = [
        int(name.split("-")[1])
        for name in os.listdir(output_dir)
        if "checkpoint" in name
    ]
    if not steps:
        raise FileNotFoundError(f"no checkpoint-* entries in {output_dir}")
    # max() suffices; no need to sort the whole list.
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")
113

ckpt_pth = f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned/{dataset_name}"

# Resolve the newest checkpoint; the fine-tuned weights live in its
# safetensors file.
last_checkpoint = get_last_checkpoint(ckpt_pth)
state_dict_pth = os.path.join(last_checkpoint, "model.safetensors")

print(state_dict_pth)

#%%
from safetensors.torch import load_file

# Overwrite the pretrained backbone weights with the fine-tuned ones.
state_dict = load_file(state_dict_pth)

model.model.load_state_dict(state_dict)
127
-
128
-
129
# %%
# Truncate each cell's ranked gene-token sequence to the model context length.
input_data = permuted_gene_ids[:, :args.seq_len]

# %%
input_data.shape

#%%
# Ensure the tokenizer has a [CLS] token; it is prepended to every cell so its
# final hidden state can serve as the cell embedding.
if model.tokenizer.cls_token_id is None:
    model.tokenizer.add_special_tokens({'cls_token': '[CLS]'})

#%%
# Prepend one [CLS] id per row. np.full replaces the original per-row list
# comprehension + reshape — same result, one vectorized call.
cls_column = np.full((input_data.shape[0], 1), model.tokenizer.cls_token_id)
input_data = np.hstack([cls_column, input_data])

#%%
input_data.shape
145
-
146
-
147
#%%
from torch.utils.data import DataLoader, Dataset


class GeneDataset(Dataset):
    """Minimal map-style dataset serving rows of a pre-built token matrix."""

    def __init__(self, data):
        # Anything indexable with a length works (ndarray, tensor, list).
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
159
-
160
-
161
#%%
# shuffle=False keeps the embedding rows aligned with adata.obs order.
all_dataset = GeneDataset(input_data)
all_loader = DataLoader(all_dataset, batch_size=args.batch_size, shuffle=False)
164
-
165
# %%
def cell_embeddings(data_loader, model):
    """Extract the per-cell [CLS] embedding for every batch in *data_loader*.

    Parameters
    ----------
    data_loader : iterable of LongTensor
        Batches of token-id matrices shaped (batch, seq_len), with the [CLS]
        token at position 0.
    model : callable
        Returns an object exposing ``hidden_states`` of shape
        (batch, seq_len, d_model); must also expose ``.device``.

    Returns
    -------
    numpy.ndarray
        Array of shape (num_cells, d_model): the position-0 hidden state of
        each cell, concatenated over all batches.
    """
    cell_repr = []

    # Inference only: no_grad avoids building autograd graphs, saving memory
    # and time (the original relied on detach() alone).
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            batch = batch.to(model.device)
            outputs = model(batch)

            # Position 0 holds the prepended [CLS] token.
            cls_representation = outputs.hidden_states[:, 0, :]
            cell_repr.append(cls_representation.detach().cpu().numpy())

            if i % 10 == 0:
                print(f"Processed {i} batches")

    return np.concatenate(cell_repr)
183
-
184
# %%
# Fall back to CPU when no GPU is present (the original hard-coded "cuda",
# which raises on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

# %%
cell_repr = cell_embeddings(all_loader, model)
cell_repr.shape

# %%
# Persist embeddings for the downstream annotation analysis; rows align with
# adata.obs because the loader did not shuffle.
np.save(f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/embeddings/fine-tuned/{dataset_name}_cell_repr.npy", cell_repr)

# %%