Fix precision error
Browse files- modeling_chatglm.py +27 -8
modeling_chatglm.py
CHANGED
|
@@ -5,7 +5,7 @@ import copy
|
|
| 5 |
import warnings
|
| 6 |
import re
|
| 7 |
import sys
|
| 8 |
-
|
| 9 |
import torch
|
| 10 |
import torch.utils.checkpoint
|
| 11 |
import torch.nn.functional as F
|
|
@@ -177,15 +177,21 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
|
|
| 177 |
|
| 178 |
|
| 179 |
class RMSNorm(torch.nn.Module):
|
| 180 |
-
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
| 181 |
super().__init__()
|
| 182 |
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
| 183 |
self.eps = eps
|
|
|
|
| 184 |
|
| 185 |
def forward(self, hidden_states: torch.Tensor):
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
return (self.weight * hidden_states).to(input_dtype)
|
| 191 |
|
|
@@ -515,10 +521,17 @@ class GLMBlock(torch.nn.Module):
|
|
| 515 |
|
| 516 |
self.fp32_residual_connection = config.fp32_residual_connection
|
| 517 |
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
# Layernorm on the input data.
|
| 520 |
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 521 |
-
|
| 522 |
|
| 523 |
# Self attention.
|
| 524 |
self.self_attention = SelfAttention(config, layer_number, device=device)
|
|
@@ -593,7 +606,13 @@ class GLMTransformer(torch.nn.Module):
|
|
| 593 |
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
| 594 |
|
| 595 |
if self.post_layer_norm:
|
| 596 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
# Final layer norm before output.
|
| 598 |
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 599 |
dtype=config.torch_dtype)
|
|
|
|
| 5 |
import warnings
|
| 6 |
import re
|
| 7 |
import sys
|
| 8 |
+
import functools
|
| 9 |
import torch
|
| 10 |
import torch.utils.checkpoint
|
| 11 |
import torch.nn.functional as F
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
class RMSNorm(torch.nn.Module):
|
| 180 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, quantized=False, **kwargs):
|
| 181 |
super().__init__()
|
| 182 |
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
| 183 |
self.eps = eps
|
| 184 |
+
self.quantized = quantized
|
| 185 |
|
| 186 |
def forward(self, hidden_states: torch.Tensor):
|
| 187 |
+
if not self.quantized:
|
| 188 |
+
norm_x = torch.mean(hidden_states * hidden_states, dim=-1, keepdim=True)
|
| 189 |
+
x_normed = hidden_states * torch.rsqrt(norm_x + self.eps)
|
| 190 |
+
return self.weight * x_normed
|
| 191 |
+
else:
|
| 192 |
+
input_dtype = hidden_states.dtype
|
| 193 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 194 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
| 195 |
|
| 196 |
return (self.weight * hidden_states).to(input_dtype)
|
| 197 |
|
|
|
|
| 521 |
|
| 522 |
self.fp32_residual_connection = config.fp32_residual_connection
|
| 523 |
|
| 524 |
+
if config.rmsnorm:
|
| 525 |
+
if config.quantization_bit != 0:
|
| 526 |
+
LayerNormFunc = functools.partial(RMSNorm, quantized=True)
|
| 527 |
+
else:
|
| 528 |
+
LayerNormFunc = RMSNorm
|
| 529 |
+
else:
|
| 530 |
+
LayerNormFunc = LayerNorm
|
| 531 |
+
|
| 532 |
# Layernorm on the input data.
|
| 533 |
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 534 |
+
dtype=config.torch_dtype)
|
| 535 |
|
| 536 |
# Self attention.
|
| 537 |
self.self_attention = SelfAttention(config, layer_number, device=device)
|
|
|
|
| 606 |
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
| 607 |
|
| 608 |
if self.post_layer_norm:
|
| 609 |
+
if config.rmsnorm:
|
| 610 |
+
if config.quantization_bit != 0:
|
| 611 |
+
LayerNormFunc = functools.partial(RMSNorm, quantized=True)
|
| 612 |
+
else:
|
| 613 |
+
LayerNormFunc = RMSNorm
|
| 614 |
+
else:
|
| 615 |
+
LayerNormFunc = LayerNorm
|
| 616 |
# Final layer norm before output.
|
| 617 |
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 618 |
dtype=config.torch_dtype)
|