Update modeling_feynmodel.py
modeling_feynmodel.py  CHANGED  +5 −2
@@ -678,7 +678,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             #print(f"+++++++++++++++++ return it causal_mask {causal_mask.size()} !!!!!!!!! attention_mask {attention_mask.size()}")
         else:
             #print("+++++++++++++++++++++ else +++++++++++++++++")
-            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+            # causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=torch.float32, device=device)
             #print(f"++++++++++++++++ causal_mask {causal_mask.size()} ++++++++++++++++++ sequence_length = {sequence_length} ")
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
@@ -1480,9 +1481,11 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
 
         # Check whether dtype is a floating-point data type
         if torch.is_floating_point(torch.empty(0, dtype=dtype)):
-            min_dtype = torch.finfo(dtype).min
+            # min_dtype = torch.finfo(dtype).min
+            min_dtype = torch.finfo(torch.float32).min
         else:
             min_dtype = torch.iinfo(dtype).min
+
 
 
         attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
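Taken together, the two hunks pin the additive-mask arithmetic to float32: the mask tensor itself and the minimum fill value used for masked positions. Below is a minimal, self-contained sketch of that pattern. The helper names dtype_min and build_causal_mask are illustrative, not part of the file, and the stated motivation (keeping the fill value and the tensor dtype consistent, since e.g. the float32 minimum written into a float16 tensor would overflow to -inf) is an assumption about the author's intent.

import torch

def dtype_min(dtype: torch.dtype):
    # Hypothetical helper mirroring the second hunk: torch.finfo only accepts
    # floating-point dtypes and torch.iinfo only integer dtypes, so dispatch
    # by probing an empty tensor of the requested dtype.
    if torch.is_floating_point(torch.empty(0, dtype=dtype)):
        return torch.finfo(dtype).min
    return torch.iinfo(dtype).min

def build_causal_mask(sequence_length: int, target_length: int, device="cpu"):
    # Hypothetical helper mirroring the first hunk: the mask is created in
    # float32 and filled with the float32 minimum, so the fill value and the
    # tensor dtype always agree regardless of the model's compute dtype.
    causal_mask = torch.full(
        (sequence_length, target_length),
        fill_value=dtype_min(torch.float32),
        dtype=torch.float32,
        device=device,
    )
    if sequence_length != 1:
        # triu(..., diagonal=1) zeroes everything on or below the diagonal,
        # leaving only strictly-future positions at the masking value.
        causal_mask = torch.triu(causal_mask, diagonal=1)
    return causal_mask

print(build_causal_mask(3, 3))
# row i keeps 0.0 at positions j <= i and -3.4028e+38 at positions j > i

One consequence of pinning the dtype this way is that the mask no longer follows the model dtype; when the float32 mask is added to half-precision attention scores, PyTorch's type promotion upcasts the result to float32.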