Enable flash_attention_2 support since the underlying Mistral model supports it
#3
opened by winglian
- modeling_eurus_rm.py +2 -0
modeling_eurus_rm.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Optional, List


 class EurusRewardModel(PreTrainedModel):
     config_class = MistralConfig
+    _supports_flash_attn_2 = True
+
     def __init__(self, config):
         super().__init__(config)
         self.model = MistralModel(config)