Fixed bug in gen_attention_mask with len > max_len
#158 opened by davidjwen

geneformer/in_silico_perturber.py CHANGED
```diff
@@ -342,7 +342,6 @@ def quant_cos_sims(model,
         max_range = min(i+forward_batch_size, total_batch_length)
 
         perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
-
         # determine if need to pad or truncate batch
         minibatch_length_set = set(perturbation_minibatch["length"])
         minibatch_lengths = perturbation_minibatch["length"]
@@ -354,12 +353,14 @@ def quant_cos_sims(model,
 
         if needs_pad_or_trunc == True:
             max_len = min(max(minibatch_length_set),model_input_size)
+            print(max_len)
             def pad_or_trunc_example(example):
                 example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
                                                                 pad_token_id,
                                                                 max_len)
                 return example
             perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
+
         perturbation_minibatch.set_format(type="torch")
 
         input_data_minibatch = perturbation_minibatch["input_ids"]
@@ -570,6 +571,8 @@ def gen_attention_mask(minibatch_encoding, max_len = None):
     original_lens = minibatch_encoding["length"]
     attention_mask = [[1]*original_len
                       +[0]*(max_len - original_len)
+                      if original_len <= max_len
+                      else [1]*max_len
                       for original_len in original_lens]
     return torch.tensor(attention_mask).to("cuda")
 
```
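For reference on the `gen_attention_mask` change: in the old list comprehension, a cell with `original_len > max_len` made `[0]*(max_len - original_len)` collapse to an empty list, so that row kept length `original_len` and no longer matched the `input_ids` that had been truncated to `max_len`. The fix caps such rows at `[1]*max_len`. A minimal sketch of the corrected construction, kept on CPU and with a hypothetical helper name for illustration (the repository version reads lengths from the minibatch encoding and moves the tensor to CUDA):

```python
import torch

def build_attention_mask(lengths, max_len):
    # Rows at or below max_len get ones padded with zeros; longer rows are
    # capped at max_len ones (previously these rows came out longer than
    # max_len and produced ragged input to torch.tensor).
    attention_mask = [[1]*original_len
                      +[0]*(max_len - original_len)
                      if original_len <= max_len
                      else [1]*max_len
                      for original_len in lengths]
    return torch.tensor(attention_mask)

# Example: the last cell exceeds max_len and is now capped at max_len ones.
print(build_attention_mask([2, 4, 6], max_len=4))
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
```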