Imagroune committed on
Commit
07fc5ac
·
verified ·
1 Parent(s): c20defd

Update modeling_feynmodel.py

Browse files
Files changed (1) hide show
  1. modeling_feynmodel.py +5 -2
modeling_feynmodel.py CHANGED
@@ -678,7 +678,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
678
  #print(f"+++++++++++++++++ return it causal_mask {causal_mask.size()} !!!!!!!!! attention_mask {attention_mask.size()}")
679
  else:
680
  #print("+++++++++++++++++++++ else +++++++++++++++++")
681
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
 
682
  #print(f"++++++++++++++++ causal_mask {causal_mask.size()} ++++++++++++++++++ sequence_length = {sequence_length} ")
683
  if sequence_length != 1:
684
  causal_mask = torch.triu(causal_mask, diagonal=1)
@@ -1480,9 +1481,11 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
1480
 
1481
  # Check whether dtype is a floating-point data type
1482
  if torch.is_floating_point(torch.empty(0, dtype=dtype)):
1483
- min_dtype = torch.finfo(dtype).min
 
1484
  else:
1485
  min_dtype = torch.iinfo(dtype).min
 
1486
 
1487
 
1488
  attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
 
678
  #print(f"+++++++++++++++++ return it causal_mask {causal_mask.size()} !!!!!!!!! attention_mask {attention_mask.size()}")
679
  else:
680
  #print("+++++++++++++++++++++ else +++++++++++++++++")
681
+ # causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
682
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=torch.float32, device=device)
683
  #print(f"++++++++++++++++ causal_mask {causal_mask.size()} ++++++++++++++++++ sequence_length = {sequence_length} ")
684
  if sequence_length != 1:
685
  causal_mask = torch.triu(causal_mask, diagonal=1)
 
1481
 
1482
  # Check whether dtype is a floating-point data type
1483
  if torch.is_floating_point(torch.empty(0, dtype=dtype)):
1484
+ # min_dtype = torch.finfo(dtype).min
1485
+ min_dtype = torch.finfo(torch.float32).min
1486
  else:
1487
  min_dtype = torch.iinfo(dtype).min
1488
+
1489
 
1490
 
1491
  attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(