Update modeling_feynmodel.py
Browse files
- modeling_feynmodel.py +11 -1
modeling_feynmodel.py
CHANGED
|
@@ -1458,7 +1458,17 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
|
|
| 1458 |
device = input_ids.device
|
| 1459 |
#print(f"22222222 +-+-+-+-+-+-+-+-+-+- sequence_length = input_ids.shape {sequence_length}")
|
| 1460 |
|
| 1461 |
-
dtype = self.lm_head.weight.dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1462 |
min_dtype = torch.finfo(dtype).min
|
| 1463 |
|
| 1464 |
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
|
|
|
| 1458 |
device = input_ids.device
|
| 1459 |
#print(f"22222222 +-+-+-+-+-+-+-+-+-+- sequence_length = input_ids.shape {sequence_length}")
|
| 1460 |
|
| 1461 |
+
# dtype = self.lm_head.weight.dtype
|
| 1462 |
+
# Obtenir le dtype des poids de lm_head
|
| 1463 |
+
if hasattr(self.lm_head, 'weight'):
|
| 1464 |
+
# Vérifier si weight est un attribut ou une méthode
|
| 1465 |
+
if isinstance(self.lm_head.weight, torch.Tensor):
|
| 1466 |
+
dtype = self.lm_head.weight.dtype
|
| 1467 |
+
elif callable(self.lm_head.weight):
|
| 1468 |
+
dtype = self.lm_head.weight().dtype
|
| 1469 |
+
else:
|
| 1470 |
+
raise TypeError(f"Type inattendu pour self.lm_head.weight : {type(self.lm_head.weight)}")
|
| 1471 |
+
|
| 1472 |
min_dtype = torch.finfo(dtype).min
|
| 1473 |
|
| 1474 |
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|