add flash-attn support
Browse files
configuration_chartmoe.py
CHANGED
|
@@ -53,6 +53,7 @@ class ChartMoEConfig(PretrainedConfig):
|
|
| 53 |
rope_scaling=None,
|
| 54 |
num_experts=4,
|
| 55 |
num_selected=2,
|
|
|
|
| 56 |
**kwargs,
|
| 57 |
):
|
| 58 |
self.num_experts = num_experts
|
|
@@ -77,6 +78,10 @@ class ChartMoEConfig(PretrainedConfig):
|
|
| 77 |
self.rope_theta = rope_theta
|
| 78 |
self.rope_scaling = rope_scaling
|
| 79 |
self._rope_scaling_validation()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
super().__init__(
|
| 81 |
pad_token_id=pad_token_id,
|
| 82 |
bos_token_id=bos_token_id,
|
|
|
|
| 53 |
rope_scaling=None,
|
| 54 |
num_experts=4,
|
| 55 |
num_selected=2,
|
| 56 |
+
attn_implementation=None,
|
| 57 |
**kwargs,
|
| 58 |
):
|
| 59 |
self.num_experts = num_experts
|
|
|
|
| 78 |
self.rope_theta = rope_theta
|
| 79 |
self.rope_scaling = rope_scaling
|
| 80 |
self._rope_scaling_validation()
|
| 81 |
+
|
| 82 |
+
self.attn_implementation = attn_implementation
|
| 83 |
+
if self.attn_implementation is None:
|
| 84 |
+
self.attn_implementation = "eager"
|
| 85 |
super().__init__(
|
| 86 |
pad_token_id=pad_token_id,
|
| 87 |
bos_token_id=bos_token_id,
|