CUDA kernels incompatible with standard PyTorch device movement with 4bit/8bit, necessitating device-specific handling
#416
by
madhavanvenkatesh
- opened
- geneformer/perturber_utils.py +60 -70
geneformer/perturber_utils.py
CHANGED
|
@@ -117,83 +117,73 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
|
|
| 117 |
model_type = "MTLCellClassifier"
|
| 118 |
quantize = True
|
| 119 |
|
| 120 |
-
|
| 121 |
-
output_hidden_states = True
|
| 122 |
-
elif mode == "train":
|
| 123 |
-
output_hidden_states = False
|
| 124 |
|
| 125 |
-
|
|
|
|
| 126 |
if model_type == "MTLCellClassifier":
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
"bnb_config": BitsAndBytesConfig(
|
| 130 |
-
load_in_8bit=True,
|
| 131 |
-
),
|
| 132 |
-
}
|
| 133 |
else:
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
output_hidden_states=output_hidden_states,
|
| 180 |
-
output_attentions=False,
|
| 181 |
-
quantization_config=quantize["bnb_config"],
|
| 182 |
-
)
|
| 183 |
-
# if eval mode, put the model in eval mode for fwd pass
|
| 184 |
if mode == "eval":
|
| 185 |
model.eval()
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
model = model.to(
|
| 192 |
-
|
|
|
|
| 193 |
model.enable_input_require_grads()
|
| 194 |
-
model = get_peft_model(model,
|
| 195 |
-
return model
|
| 196 |
|
|
|
|
| 197 |
|
| 198 |
def quant_layers(model):
|
| 199 |
layer_nums = []
|
|
|
|
| 117 |
model_type = "MTLCellClassifier"
|
| 118 |
quantize = True
|
| 119 |
|
| 120 |
+
output_hidden_states = (mode == "eval")
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
# Quantization logic
|
| 123 |
+
if quantize:
|
| 124 |
if model_type == "MTLCellClassifier":
|
| 125 |
+
quantize_config = BitsAndBytesConfig(load_in_8bit=True)
|
| 126 |
+
peft_config = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
else:
|
| 128 |
+
quantize_config = BitsAndBytesConfig(
|
| 129 |
+
load_in_4bit=True,
|
| 130 |
+
bnb_4bit_use_double_quant=True,
|
| 131 |
+
bnb_4bit_quant_type="nf4",
|
| 132 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 133 |
+
)
|
| 134 |
+
peft_config = LoraConfig(
|
| 135 |
+
lora_alpha=128,
|
| 136 |
+
lora_dropout=0.1,
|
| 137 |
+
r=64,
|
| 138 |
+
bias="none",
|
| 139 |
+
task_type="TokenClassification",
|
| 140 |
+
)
|
| 141 |
+
else:
|
| 142 |
+
quantize_config = None
|
| 143 |
+
peft_config = None
|
| 144 |
+
|
| 145 |
+
# Model class selection
|
| 146 |
+
model_classes = {
|
| 147 |
+
"Pretrained": BertForMaskedLM,
|
| 148 |
+
"GeneClassifier": BertForTokenClassification,
|
| 149 |
+
"CellClassifier": BertForSequenceClassification,
|
| 150 |
+
"MTLCellClassifier": BertForMaskedLM
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
model_class = model_classes.get(model_type)
|
| 154 |
+
if not model_class:
|
| 155 |
+
raise ValueError(f"Unknown model type: {model_type}")
|
| 156 |
+
|
| 157 |
+
# Model loading
|
| 158 |
+
model_args = {
|
| 159 |
+
"pretrained_model_name_or_path": model_directory,
|
| 160 |
+
"output_hidden_states": output_hidden_states,
|
| 161 |
+
"output_attentions": False,
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
if model_type != "Pretrained":
|
| 165 |
+
model_args["num_labels"] = num_classes
|
| 166 |
+
|
| 167 |
+
if quantize_config:
|
| 168 |
+
model_args["quantization_config"] = quantize_config
|
| 169 |
+
|
| 170 |
+
# Load the model
|
| 171 |
+
model = model_class.from_pretrained(**model_args)
|
| 172 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
if mode == "eval":
|
| 174 |
model.eval()
|
| 175 |
+
|
| 176 |
+
# Handle device placement and PEFT
|
| 177 |
+
if not quantize:
|
| 178 |
+
# Only move non-quantized models
|
| 179 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 180 |
+
model = model.to(device)
|
| 181 |
+
elif peft_config:
|
| 182 |
+
# Apply PEFT for quantized models (except MTLCellClassifier)
|
| 183 |
model.enable_input_require_grads()
|
| 184 |
+
model = get_peft_model(model, peft_config)
|
|
|
|
| 185 |
|
| 186 |
+
return model
|
| 187 |
|
| 188 |
def quant_layers(model):
|
| 189 |
layer_nums = []
|