Add rope fp32 (#43)

* Log model
* Add flag for computing the RoPE outer product in fp32

---------

Co-authored-by: Srini Iyer <sviyer@meta.com>
- bytelatent/base_transformer.py +29 -4
- bytelatent/model/blt.py +3 -5
- bytelatent/model/local_models.py +1 -0
- bytelatent/train.py +1 -0
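
The second bullet is the substance of the change: `precompute_freqs_cis` gains a flag that casts the position indices `t` to fp32 before the `torch.outer(t, freqs)` call. The commit message doesn't spell out the motivation, so the sketch below shows the usual one as a hedged reading: bf16 carries only 8 significant bits, so not every integer position above 256 is representable, and low-precision position indices can collapse distinct positions onto identical rotation angles. All values here are illustrative.

import torch

# Illustrative numbers only: RoPE-style inverse frequencies, as in the diff.
dim, end, theta = 128, 4096, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

t_fp32 = torch.arange(end, dtype=torch.float32)
t_bf16 = t_fp32.to(torch.bfloat16)  # what a low-precision cast does to positions

# bf16 rounds 257 down to 256, so two adjacent positions share one angle.
print(t_bf16[256] == t_bf16[257])  # tensor(True)

angles_fp32 = torch.outer(t_fp32, inv_freq)
angles_bf16 = torch.outer(t_bf16.float(), inv_freq)
print((angles_fp32 - angles_bf16).abs().max())  # nonzero drift at large positions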
bytelatent/base_transformer.py
CHANGED
@@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
     norm_eps: float = 1e-5
 
     rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False
 
     init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED
@@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
     )
 
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_use_fp32_in_outer_product: bool = False,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
 
@@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     """
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)
+    if rope_use_fp32_in_outer_product:
+        t = t.to(torch.float32)
+
     freqs = torch.outer(t, freqs).float()
 
     cos, sin = freqs.cos(), freqs.sin()
@@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
     RotaryEmbedding Module
     """
 
-    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+    def __init__(
+        self,
+        theta: float,
+        head_dim: int,
+        max_seqlen: int = 1024,
+        rope_use_fp32_in_outer_product: bool = False,
+    ):
         super().__init__()
 
         self.theta = theta
         self.head_dim = head_dim
         self.max_seqlen = max_seqlen
+        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
 
         self.register_buffer(
             "freqs_cis",
-            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            precompute_freqs_cis(
+                dim=head_dim,
+                end=max_seqlen,
+                theta=theta,
+                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+            ),
             persistent=False,
         )
 
     def reset_parameters(self):
         self.freqs_cis[...] = precompute_freqs_cis(
-            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+            dim=self.head_dim,
+            end=self.max_seqlen,
+            theta=self.theta,
+            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
         )
 
     def forward(
@@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.eos_id = args.eos_id
 
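A hypothetical end-to-end sketch of the new plumbing. The field names all come from the diff above; the `dim`, `n_heads`, and `max_seqlen` values are made up for illustration.

from bytelatent.base_transformer import BaseTransformerArgs, RotaryEmbedding

# Hypothetical config values; only the new flag is the point here.
args = BaseTransformerArgs(
    dim=512,
    n_heads=8,
    max_seqlen=4096,
    rope_theta=10000.0,
    rope_use_fp32_in_outer_product=True,  # the new flag
)

# Mirrors how BaseTransformer.__init__ constructs its rotary embedding.
rope = RotaryEmbedding(
    theta=args.rope_theta,
    head_dim=args.head_dim or args.dim // args.n_heads,
    max_seqlen=args.max_seqlen,
    rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
)
print(rope.freqs_cis.shape)  # angle table precomputed with fp32 positions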
bytelatent/model/blt.py
CHANGED
@@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     patch_in_forward: bool = False
 
     # Architecture and dimensions
-    dim_token: int =
+    dim_token: int | None = None
     dim_global: int = 512
     dim_local_decoder: int = 512
     dim_local_encoder: int = 512
@@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     use_fsdp: bool = True
     attn_to_keep: str = "all"
 
-    # RoPE parameters
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
-
     # Parameter mixing
     pm_size: int = 0
 
@@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
@@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
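Worth noting (our inference from the class header, not stated in the commit): `ByteLatentTransformerArgs` subclasses `BaseTransformerArgs`, so the `rope_theta` and `rope_use_fp32_in_outer_product` fields deleted here are not lost; they are now inherited from the base class, keeping a single source of truth for RoPE settings. A quick sanity check, assuming pydantic v2's `model_fields`:

from bytelatent.model.blt import ByteLatentTransformerArgs

# Both RoPE fields should still resolve on the subclass via inheritance.
print("rope_theta" in ByteLatentTransformerArgs.model_fields)                      # True
print("rope_use_fp32_in_outer_product" in ByteLatentTransformerArgs.model_fields)  # True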
bytelatent/model/local_models.py
CHANGED
@@ -86,6 +86,7 @@ class LocalModelBase(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.pos_embeddings = None
 
bytelatent/train.py
CHANGED
@@ -325,6 +325,7 @@ def train(args: TrainArgs):
 
     # log model size
 
+    logger.info(model)
     logger.info(f"Model size: {model_param_count:,} total parameters")
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
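
The other half of the commit message (`* Log model`): `logger.info(model)` writes PyTorch's `repr` of the full module tree to the training log just before the parameter count. A standalone illustration with a toy module (not BLT's actual model):

import logging

import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Logging an nn.Module stringifies it into the nested module tree.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
logger.info(model)
# INFO:__main__:Sequential(
#   (0): Linear(in_features=4, out_features=8, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=8, out_features=2, bias=True)
# )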