Update modeling_feynmodel.py
modeling_feynmodel.py CHANGED (+15 −1)
@@ -1469,7 +1469,21 @@ class FeynModelForCausalLM(Gemma2ForCausalLM):
         else:
             raise TypeError(f"Unexpected type for self.lm_head.weight: {type(self.lm_head.weight)}")

-        min_dtype = torch.finfo(dtype).min
+        # min_dtype = torch.finfo(dtype).min
+        # Get the dtype of the lm_head weights
+        if isinstance(self.lm_head, torch.ao.nn.quantized.dynamic.Linear):
+            # For dynamically quantized modules, use _weight_bias()
+            weight, bias = self.lm_head._weight_bias()
+            dtype = weight.dtype
+        else:
+            dtype = self.lm_head.weight.dtype
+
+        # Check whether dtype is a floating-point data type
+        if torch.is_floating_point(torch.empty(0, dtype=dtype)):
+            min_dtype = torch.finfo(dtype).min
+        else:
+            min_dtype = torch.iinfo(dtype).min
+

         attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
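For context, the added branch can be exercised outside the model with a few lines of PyTorch. The sketch below is not part of the commit: the toy layer sizes, the nn.Sequential wrapper, and the int_repr() fallback are illustrative assumptions. It quantizes a small Linear dynamically, unpacks its weight with _weight_bias() as the patch does, and derives min_dtype. One caveat, hedged: the unpacked weight of a dynamically quantized module is typically a quantized tensor (torch.qint8), and quantized dtypes are generally rejected by both torch.empty() and torch.iinfo(), so the sketch uses dtype.is_floating_point and the weight's int_repr() in place of the probe-tensor check used in the patch.

import torch
import torch.nn as nn

# Dynamically quantize a toy linear layer; sizes are arbitrary and the
# Sequential wrapper is only a hypothetical stand-in for a model's lm_head.
quantized = torch.ao.quantization.quantize_dynamic(
    nn.Sequential(nn.Linear(16, 32)),
    {nn.Linear},
    dtype=torch.qint8,
)
head = quantized[0]
assert isinstance(head, torch.ao.nn.quantized.dynamic.Linear)

# Unpack the packed parameters, as the patch does; the weight comes back
# as a quantized tensor, so its dtype is torch.qint8 here.
weight, bias = head._weight_bias()
dtype = weight.dtype

if dtype.is_floating_point:
    min_dtype = torch.finfo(dtype).min
else:
    # torch.iinfo() does not accept quantized dtypes, so inspect the
    # underlying integer representation instead (an assumption on our part,
    # not something the commit does).
    min_dtype = torch.iinfo(weight.int_repr().dtype).min

print(dtype, min_dtype)  # torch.qint8 -128

Unpacking via _weight_bias() is needed because dynamically quantized modules keep their parameters in packed form and expose weight() as a method rather than a .weight tensor attribute, which is presumably what the TypeError branch above the hunk guards against.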