| |
| |
| |
| |
| @@ -320,9 +320,7 @@ def forward( |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| |
| - if hidden_states.dtype == torch.float16 and ( |
| - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
| - ): |
| + if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| |
| |
| |
| |
| @@ -631,9 +631,7 @@ def forward( |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| |
| - if hidden_states.dtype == torch.float16 and ( |
| - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
| - ): |
| + if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| |
| |
| |
| |
| @@ -580,9 +580,7 @@ def forward( |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| |
| - if hidden_states.dtype == torch.float16 and ( |
| - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
| - ): |
| + if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| |
| |
| |
| |
| @@ -321,9 +321,7 @@ def forward( |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| |
| - if hidden_states.dtype == torch.float16 and ( |
| - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
| - ): |
| + if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| |
| |
| |
| |
| @@ -427,9 +427,7 @@ def forward( |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| |
| - if hidden_states.dtype == torch.float16 and ( |
| - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
| - ): |
| + if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| |
| |
| |
| |
| @@ -386,9 +386,7 @@ def forward( |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| |
| - if hidden_states.dtype == torch.float16 and ( |
| - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
| - ): |
| + if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| |
| |
| |
| |
| @@ -637,9 +637,7 @@ def forward( |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| |
| - if hidden_states.dtype == torch.float16 and ( |
| - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() |
| - ): |
| + if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
|
|