Christina Theodoris
committed on
Commit
·
7b591f6
1
Parent(s):
69e6887
add quantization for pretrained model
Browse files
geneformer/in_silico_perturber.py
CHANGED
|
@@ -62,7 +62,7 @@ class InSilicoPerturber:
|
|
| 62 |
"genes_to_perturb": {"all", list},
|
| 63 |
"combos": {0, 1},
|
| 64 |
"anchor_gene": {None, str},
|
| 65 |
-
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
|
| 66 |
"num_classes": {int},
|
| 67 |
"emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
|
| 68 |
"cell_emb_style": {"mean_pool"},
|
|
@@ -132,7 +132,7 @@ class InSilicoPerturber:
|
|
| 132 |
| ENSEMBL ID of gene to use as anchor in combination perturbations.
|
| 133 |
| For example, if combos=1 and anchor_gene="ENSG00000148400":
|
| 134 |
| anchor gene will be perturbed in combination with each other gene.
|
| 135 |
-
model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
|
| 136 |
| Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
|
| 137 |
num_classes : int
|
| 138 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
|
|
|
| 62 |
"genes_to_perturb": {"all", list},
|
| 63 |
"combos": {0, 1},
|
| 64 |
"anchor_gene": {None, str},
|
| 65 |
+
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"},
|
| 66 |
"num_classes": {int},
|
| 67 |
"emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
|
| 68 |
"cell_emb_style": {"mean_pool"},
|
|
|
|
| 132 |
| ENSEMBL ID of gene to use as anchor in combination perturbations.
|
| 133 |
| For example, if combos=1 and anchor_gene="ENSG00000148400":
|
| 134 |
| anchor gene will be perturbed in combination with each other gene.
|
| 135 |
+
model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"}
|
| 136 |
| Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
|
| 137 |
num_classes : int
|
| 138 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
geneformer/perturber_utils.py
CHANGED
|
@@ -113,15 +113,22 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
|
|
| 113 |
|
| 114 |
# load model to GPU
|
| 115 |
def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
| 116 |
-
if model_type == "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
model_type = "MTLCellClassifier"
|
| 118 |
quantize = True
|
|
|
|
|
|
|
| 119 |
|
| 120 |
output_hidden_states = (mode == "eval")
|
| 121 |
|
| 122 |
# Quantization logic
|
| 123 |
if quantize:
|
| 124 |
-
if
|
| 125 |
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 126 |
peft_config = None
|
| 127 |
else:
|
|
@@ -179,7 +186,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
| 179 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 180 |
model = model.to(device)
|
| 181 |
elif peft_config:
|
| 182 |
-
# Apply PEFT for quantized models (except MTLCellClassifier)
|
| 183 |
model.enable_input_require_grads()
|
| 184 |
model = get_peft_model(model, peft_config)
|
| 185 |
|
|
|
|
| 113 |
|
| 114 |
# load model to GPU
|
| 115 |
def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
| 116 |
+
if model_type == "Pretrained-Quantized":
|
| 117 |
+
inference_only = True
|
| 118 |
+
model_type = "Pretrained"
|
| 119 |
+
quantize = True
|
| 120 |
+
elif model_type == "MTLCellClassifier-Quantized":
|
| 121 |
+
inference_only = True
|
| 122 |
model_type = "MTLCellClassifier"
|
| 123 |
quantize = True
|
| 124 |
+
else:
|
| 125 |
+
inference_only = False
|
| 126 |
|
| 127 |
output_hidden_states = (mode == "eval")
|
| 128 |
|
| 129 |
# Quantization logic
|
| 130 |
if quantize:
|
| 131 |
+
if inference_only:
|
| 132 |
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 133 |
peft_config = None
|
| 134 |
else:
|
|
|
|
| 186 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 187 |
model = model.to(device)
|
| 188 |
elif peft_config:
|
| 189 |
+
# Apply PEFT for quantized models (except the inference-only "Pretrained-Quantized" and "MTLCellClassifier-Quantized" paths, which set peft_config = None)
|
| 190 |
model.enable_input_require_grads()
|
| 191 |
model = get_peft_model(model, peft_config)
|
| 192 |
|