| """ |
| Gemma 4 E2B - clean PyTorch forward pass (text model only). |
| |
| Architecture: |
| - 35 decoder layers, hidden_size=1536, vocab=262144 |
| - 8 Q heads, 1 KV head (MQA) |
| - Sliding attention layers (0-3, 5-8, 10-13, 15-18, 20-23, 25-28, 30-33): |
| head_dim=256, sliding_window=512, rope_theta=10000 |
| - Full attention layers (every 5th: 4,9,14,19,24,29,34): |
| head_dim=512, partial_rotary_factor=0.25 (64 of the 256 rotate-half frequency |
| pairs rotate, i.e. 128 of the 512 dims: [0:64] and [256:320]; the rest are unrotated), |
| rope_theta=1000000 |
| - MLP: GeGLU (tanh-approximate GELU), intermediate_size=6144 for layers 0-14 and |
| 12288 (double-wide) for layers 15-34 |
| - Per-layer auxiliary stream (full details below) |
| - layer_scalar: per-layer learned scalar multiplied onto residual contributions |
| - QK RMSNorm before RoPE, attn_scale=1.0 |
| - Final: RMSNorm + tied lm_head + logit softcapping at 30.0 |
| |
| Per-layer auxiliary stream: |
| Model-level (computed once, before all layers): |
| 1. embed_tokens_per_layer(input_ids) -> [B, T, 35*256] (vocab lookup), |
| scaled by sqrt(256) |
| 2. per_layer_model_projection(x_embed) -> [B, T, 35*256] (project hidden -> aux), |
| scaled by hidden_size**-0.5 |
| 3. per_layer_projection_norm (RMSNorm(256)) on the projection slice per layer |
| 4. Combine: per_layer_inputs = (embed_aux + proj_aux) * (1/sqrt(2)) |
| reshaped to [B, T, 35, 256] |
| |
| Per-layer (at layer i, applied before the attention block): |
| per_layer_input_i = per_layer_inputs[:, :, i, :] # [B, T, 256] |
| x_normed = input_layernorm(x) |
| gate = gelu_tanh(per_layer_input_gate(x_normed)) # [B, T, 256] |
| gated = gate * per_layer_input_i # [B, T, 256] |
| out = per_layer_projection(gated) # [B, T, 1536] (256 -> 1536) |
| x = x + layer_scalar * post_per_layer_input_norm(out) |
| |
| Weight shapes in checkpoint: |
| per_layer_model_projection.weight : [8960, 1536] (Linear 1536 -> 8960) |
| per_layer_projection_norm.weight : [256] (RMSNorm on 256-dim slices) |
| layers.i.per_layer_input_gate.weight : [256, 1536] (Linear 1536 -> 256) |
| layers.i.per_layer_projection.weight : [1536, 256] (Linear 256 -> 1536) |
| layers.i.post_per_layer_input_norm.weight : [1536] (RMSNorm on 1536-dim output) |
| """ |
|
|
| import math |
| import os |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from safetensors import safe_open |
| from transformers import AutoTokenizer |
|
|
| |
# Run on the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Parameter dtype the model is cast to after its weights are loaded.
DTYPE = torch.bfloat16
|
|
| |
| |
| |
def _default_hub_root() -> Path:
    """Return the Hugging Face hub cache directory, honouring the usual env vars.

    Resolution order (first non-empty wins):
      1. ``HF_HUB_CACHE``                 -> used as-is
      2. ``HF_HOME``                      -> ``$HF_HOME/hub``
      3. ``XDG_CACHE_HOME``               -> ``$XDG_CACHE_HOME/huggingface/hub``
      4. default                          -> ``~/.cache/huggingface/hub``
    """
    hub_cache = os.environ.get("HF_HUB_CACHE")
    if hub_cache:
        return Path(hub_cache).expanduser()
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return Path(hf_home).expanduser() / "hub"
    xdg_cache = os.environ.get("XDG_CACHE_HOME")
    cache_base = Path(xdg_cache).expanduser() if xdg_cache else Path("~/.cache").expanduser()
    return cache_base / "huggingface" / "hub"


# Root of the local HF hub cache, resolved once at import time.
_HUB_ROOT = _default_hub_root()

# Repos searched in order; GEMMA4_HF_REPO (if set and non-empty) takes priority.
# Empty entries are skipped by _resolve_model_paths.
_REPO_CANDIDATES = (
    os.environ.get("GEMMA4_HF_REPO", ""),
    "gg-hf-gg/gemma-4-E2B",
    "google/gemma-4-e2b-it",
)
|
|
|
|
| def _resolve_model_paths(): |
| """Return (snapshot_dir, safetensors_path). Picks first available repo+snapshot |
| that actually contains a .safetensors file. Iterates ALL snapshots per repo |
| before moving to the next repo β iterdir() order is not deterministic and HF |
| may keep multiple snapshots where only one has weights blob-resolved. |
| """ |
| for repo in _REPO_CANDIDATES: |
| if not repo: |
| continue |
| repo_cache = _HUB_ROOT / ("models--" + repo.replace("/", "--")) |
| snap_root = repo_cache / "snapshots" |
| if not snap_root.is_dir(): |
| continue |
| for snap in sorted(p for p in snap_root.iterdir() if p.is_dir()): |
| |
| sft = snap / "model.safetensors" |
| if not sft.exists(): |
| candidates = sorted(snap.glob("*.safetensors")) |
| if not candidates: |
| continue |
| sft = candidates[0] |
| return snap, sft |
| raise FileNotFoundError( |
| "No Gemma-4 E2B HF cache found. Tried: " + ", ".join(r for r in _REPO_CANDIDATES if r) |
| + ". Run `hf download google/gemma-4-e2b-it` or set GEMMA4_HF_REPO." |
| ) |
|
|
|
|
# Locate the cached checkpoint once at import time; _resolve_model_paths raises
# FileNotFoundError when no usable snapshot is found.
_model_dir, _weights_file = _resolve_model_paths()
MODEL_DIR = _model_dir
SAFETENSORS_BLOB = _weights_file
del _model_dir, _weights_file
|
|
| |
# --- Model geometry ---------------------------------------------------------
N_LAYERS = 35
HIDDEN_SIZE = 1536
VOCAB_SIZE = 262144
N_Q_HEADS = 8
N_KV_HEADS = 1            # one K/V head shared by all query heads (MQA)
HEAD_DIM_SLIDE = 256      # head dim on sliding-window layers
HEAD_DIM_FULL = 512       # head dim on full-attention layers
PER_LAYER_DIM = 256       # width of each layer's slice of the auxiliary stream

# --- MLP widths: layers >= DOUBLE_WIDE_START use the wide variant -----------
INTERMEDIATE = 6144
INTERMEDIATE_WIDE = 12288
DOUBLE_WIDE_START = 15

# --- Attention / positional encoding ----------------------------------------
SLIDING_WINDOW = 512
ROPE_THETA_SLIDE = 10000.0
ROPE_THETA_FULL = 1e6
PARTIAL_ROT_FULL = 0.25   # fraction of the full-layer head dim that is rotated
ATTN_SCALE = 1.0          # softmax scale handed to scaled_dot_product_attention

# --- Numerics ---------------------------------------------------------------
RMS_EPS = 1e-6
LOGIT_CAP = 30.0          # logits -> LOGIT_CAP * tanh(logits / LOGIT_CAP)

# Scale on the per_layer_model_projection output.
PER_LAYER_PROJ_SCALE = pow(HIDDEN_SIZE, -0.5)
# Scale on (embedding half + projection half) of the per-layer stream.
PER_LAYER_INPUT_SCALE = math.sqrt(0.5)

# Every 5th layer (4, 9, ..., 34) uses full attention; the others slide.
FULL_ATTN_LAYERS = frozenset(idx for idx in range(N_LAYERS) if idx % 5 == 4)
|
|
|
|
def is_full_attention(layer_idx: int) -> bool:
    """Report whether decoder layer ``layer_idx`` uses full (global) attention.

    Layers listed in FULL_ATTN_LAYERS are full-attention; all others use
    sliding-window attention.
    """
    uses_global_attention = layer_idx in FULL_ATTN_LAYERS
    return uses_global_attention
|
|
|
|
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``x * rsqrt(mean(x**2) + RMS_EPS) * weight`` in float32 and casts
    the result back to the input dtype. ``weight`` starts at ones and is
    applied directly (no ``1 + weight`` offset).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        h = x.float()
        mean_sq = h.pow(2).mean(-1, keepdim=True)
        scaled = h * torch.rsqrt(mean_sq + RMS_EPS)
        return (scaled * self.weight.float()).to(in_dtype)
|
|
|
|
| |
|
|
| def build_rope_freqs( |
| head_dim: int, |
| max_seq: int, |
| theta: float, |
| device: torch.device, |
| n_rot_pairs: int | None = None, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Build cos/sin tables of shape [max_seq, head_dim]. |
| |
| For full-attention layers with partial rotation, only the first |
| n_rot_pairs*2 positions carry actual frequencies; the rest are zeros |
| (NoPE β no positional encoding for those dims). |
| |
| Args: |
| head_dim: total head dimension |
| max_seq: maximum sequence length to precompute |
| theta: RoPE base frequency |
| device: target device |
| n_rot_pairs: if set, only compute real freqs for this many pairs; |
| remaining dims get freq=0 (cos=1, sin=0 β identity). |
| """ |
| half = head_dim // 2 |
| if n_rot_pairs is None: |
| n_rot_pairs = half |
|
|
| |
| inv_freq = 1.0 / (theta ** ( |
| torch.arange(0, n_rot_pairs, device=device).float() / half |
| )) |
|
|
| |
| if n_rot_pairs < half: |
| inv_freq = torch.cat([ |
| inv_freq, |
| torch.zeros(half - n_rot_pairs, device=device), |
| ]) |
|
|
| t = torch.arange(max_seq, device=device).float() |
| freqs = torch.outer(t, inv_freq) |
| freqs = torch.cat([freqs, freqs], dim=-1) |
| return freqs.cos(), freqs.sin() |
|
|
|
|
def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary position embeddings (rotate-half convention).

    Computes ``x * cos + rotate_half(x) * sin`` where
    ``rotate_half(x) = cat(-x[..., half:], x[..., :half])``.

    Args:
        x:   [..., T, head_dim], e.g. [B, H, T, head_dim]; the sequence axis is
             dim -2 and positions are taken to be 0..T-1.
        cos: [>= T, head_dim] table (rows [0, T) are used).
        sin: same shape as ``cos``.

    Returns:
        Tensor with the shape and dtype of ``x``. The tables are cast to
        ``x.dtype`` before the multiply.
    """
    seq_len = x.shape[-2]
    half = x.shape[-1] // 2
    rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
    # A [T, head_dim] table broadcasts against any number of leading dims of x.
    cos_t = cos[:seq_len].to(x.dtype)
    sin_t = sin[:seq_len].to(x.dtype)
    return x * cos_t + rotated * sin_t
|
|
|
|
| |
|
|
class Attention(nn.Module):
    """
    Multi-query self-attention: N_Q_HEADS query heads share N_KV_HEADS (=1)
    key/value head.

    The layer type comes from ``is_full_attention(layer_idx)``:
      * sliding layers: head_dim = HEAD_DIM_SLIDE; query i attends to keys j
        with j <= i and i - j < SLIDING_WINDOW;
      * full layers: head_dim = HEAD_DIM_FULL; plain causal attention.

    Q and K are RMS-normalised per head, then rotated with RoPE; the softmax
    scale is ATTN_SCALE.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.full_attn = is_full_attention(layer_idx)
        if self.full_attn:
            self.head_dim = HEAD_DIM_FULL
        else:
            self.head_dim = HEAD_DIM_SLIDE

        q_width = N_Q_HEADS * self.head_dim
        kv_width = N_KV_HEADS * self.head_dim
        self.q_proj = nn.Linear(HIDDEN_SIZE, q_width, bias=False)
        self.k_proj = nn.Linear(HIDDEN_SIZE, kv_width, bias=False)
        self.v_proj = nn.Linear(HIDDEN_SIZE, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, HIDDEN_SIZE, bias=False)

        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    @staticmethod
    def _sliding_window_mask(seq_len: int, device: torch.device) -> torch.Tensor:
        """Boolean [T, T] mask, True where query i may attend key j
        (0 <= i - j < SLIDING_WINDOW)."""
        pos = torch.arange(seq_len, device=device)
        dist = pos[:, None] - pos[None, :]
        return (dist >= 0) & (dist < SLIDING_WINDOW)

    def _split_heads(self, t: torch.Tensor, batch: int, seq_len: int, n_heads: int) -> torch.Tensor:
        """[B, T, n_heads * head_dim] -> [B, n_heads, T, head_dim]."""
        return t.view(batch, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        hd = self.head_dim

        q = self.q_norm(self._split_heads(self.q_proj(x), batch, seq_len, N_Q_HEADS))
        k = self.k_norm(self._split_heads(self.k_proj(x), batch, seq_len, N_KV_HEADS))
        v = self._split_heads(self.v_proj(x), batch, seq_len, N_KV_HEADS)

        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Broadcast the shared K/V head across every query head (view, no copy).
        k = k.expand(batch, N_Q_HEADS, seq_len, hd)
        v = v.expand(batch, N_Q_HEADS, seq_len, hd)

        if self.full_attn:
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=ATTN_SCALE)
        else:
            window = self._sliding_window_mask(seq_len, x.device)
            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=window, scale=ATTN_SCALE)

        merged = attn.transpose(1, 2).contiguous().view(batch, seq_len, N_Q_HEADS * hd)
        return self.o_proj(merged)
|
|
|
|
| |
|
|
class MLP(nn.Module):
    """
    Gated feed-forward block (GeGLU with tanh-approximate GELU):

        down_proj(gelu_tanh(gate_proj(x)) * up_proj(x))

    Hidden width is INTERMEDIATE (6144) for layers below DOUBLE_WIDE_START (15)
    and INTERMEDIATE_WIDE (12288) from that layer onward.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        if layer_idx >= DOUBLE_WIDE_START:
            width = INTERMEDIATE_WIDE
        else:
            width = INTERMEDIATE
        self.gate_proj = nn.Linear(HIDDEN_SIZE, width, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, width, bias=False)
        self.down_proj = nn.Linear(width, HIDDEN_SIZE, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.gate_proj(x), approximate="tanh")
        projected_up = self.up_proj(x)
        return self.down_proj(activated * projected_up)
|
|
|
|
| |
|
|
class Gemma4TextLayer(nn.Module):
    """
    Single Gemma 4 decoder layer.

    Execution order (per forward call). Every residual contribution is
    multiplied by the learned ``layer_scalar`` before being added:
      1. Per-layer auxiliary stream injection
      2. Self-attention: input_layernorm -> attention -> post_attention_layernorm
      3. MLP: pre_feedforward_layernorm -> MLP -> post_feedforward_layernorm

    Per-layer auxiliary stream injection:
      Receives per_layer_input [B, T, 256], this layer's slice of the combined
      embedding + projection stream computed by the model.
        x_normed = input_layernorm(x)
        gate     = gelu_tanh(per_layer_input_gate(x_normed))   # [B, T, 256]
        gated    = gate * per_layer_input                      # [B, T, 256]
        out_1536 = per_layer_projection(gated)                 # [B, T, 1536]
        x        = x + layer_scalar * post_per_layer_input_norm(out_1536)
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        # Attention type and MLP width both depend on layer_idx.
        self.self_attn = Attention(layer_idx)
        self.mlp = MLP(layer_idx)

        # Norms around attention and MLP, plus the norm applied to the
        # per-layer stream's projected output.
        self.input_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_attention_layernorm = RMSNorm(HIDDEN_SIZE)
        self.pre_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_per_layer_input_norm = RMSNorm(HIDDEN_SIZE)

        # Per-layer stream: hidden -> PER_LAYER_DIM gate, PER_LAYER_DIM -> hidden projection.
        self.per_layer_input_gate = nn.Linear(HIDDEN_SIZE, PER_LAYER_DIM, bias=False)
        self.per_layer_projection = nn.Linear(PER_LAYER_DIM, HIDDEN_SIZE, bias=False)

        # Learned scalar on each residual contribution (initialised to 1).
        self.layer_scalar = nn.Parameter(torch.ones(1))

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        per_layer_input: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: hidden states [B, T, HIDDEN_SIZE]
            cos, sin: RoPE tables matching this layer's attention type
            per_layer_input: this layer's auxiliary input [B, T, PER_LAYER_DIM]

        Returns:
            Updated hidden states, same shape as ``x``.
        """
        scalar = self.layer_scalar.to(x.dtype)

        # 1. Per-layer auxiliary stream; the gate activation is tanh-approximate GELU.
        x_normed = self.input_layernorm(x)
        gate = F.gelu(self.per_layer_input_gate(x_normed), approximate="tanh")
        gated = gate * per_layer_input
        out = self.per_layer_projection(gated)
        x = x + scalar * self.post_per_layer_input_norm(out)

        # 2. Attention. NOTE(review): input_layernorm is applied again here (to the
        # updated x), so one set of norm weights serves both the per-layer gate and
        # the attention pre-norm -- confirm against the reference implementation.
        h = self.self_attn(self.input_layernorm(x), cos, sin)
        x = x + scalar * self.post_attention_layernorm(h)

        # 3. MLP
        h = self.mlp(self.pre_feedforward_layernorm(x))
        x = x + scalar * self.post_feedforward_layernorm(h)

        return x
|
|
|
|
| |
|
|
class Gemma4ForCausalLM(nn.Module):
    """
    Gemma 4 E2B text model (causal LM head, no vision/audio).

    Tied embeddings: embed_tokens.weight also serves as the output projection.
    Output logits are softcapped: LOGIT_CAP * tanh(logits / LOGIT_CAP).

    The per-layer auxiliary stream is computed once, before layer iteration:
      - embed_tokens_per_layer lookup * sqrt(PER_LAYER_DIM) -> [B, T, 35, 256]
      - per_layer_model_projection(x_embed) * hidden_size**-0.5 -> [B, T, 35, 256]
      - per_layer_projection_norm: RMSNorm(256) over each layer's projection slice
      - combine: per_layer_inputs = (embed_aux + proj_normed) * sqrt(0.5)
    """

    def __init__(self) -> None:
        super().__init__()

        # Token embeddings (tied with the LM head) and the packed per-layer table.
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.embed_tokens_per_layer = nn.Embedding(VOCAB_SIZE, N_LAYERS * PER_LAYER_DIM)

        # Final norm before the LM head.
        self.norm = RMSNorm(HIDDEN_SIZE)

        self.layers = nn.ModuleList([Gemma4TextLayer(i) for i in range(N_LAYERS)])

        # Model-level half of the per-layer stream: hidden -> N_LAYERS * PER_LAYER_DIM.
        self.per_layer_model_projection = nn.Linear(
            HIDDEN_SIZE, N_LAYERS * PER_LAYER_DIM, bias=False
        )
        self.per_layer_projection_norm = RMSNorm(PER_LAYER_DIM)

        # Lazily built float32 RoPE tables. These are plain attributes, not
        # registered buffers, so module.to() does not move them; _ensure_rope
        # rebuilds them when the requested device changes.
        self._rope_slide_cos: torch.Tensor | None = None
        self._rope_slide_sin: torch.Tensor | None = None
        self._rope_full_cos: torch.Tensor | None = None
        self._rope_full_sin: torch.Tensor | None = None
        self._rope_seq: int = 0

    @staticmethod
    def is_full_attention(layer_idx: int) -> bool:
        """Return True if layer_idx uses full (global) attention."""
        return is_full_attention(layer_idx)

    def _ensure_rope(self, seq_len: int, device: torch.device) -> None:
        """Make sure the RoPE tables cover ``seq_len`` positions on ``device``.

        Tables are (re)built when missing, too short, or on a different device
        than requested. At least 2048 positions are built so short prompts do
        not trigger repeated rebuilds.
        """
        target = torch.device(device) if isinstance(device, str) else device
        cached = self._rope_slide_cos
        if cached is not None and self._rope_seq >= seq_len and cached.device == target:
            return
        max_seq = max(seq_len, 2048)

        # Sliding layers: full rotation of HEAD_DIM_SLIDE.
        cs, sn = build_rope_freqs(HEAD_DIM_SLIDE, max_seq, ROPE_THETA_SLIDE, target)
        self._rope_slide_cos = cs
        self._rope_slide_sin = sn

        # Full layers: partial rotation, int(512 * 0.25) // 2 = 64 frequency pairs.
        n_rot = int(HEAD_DIM_FULL * PARTIAL_ROT_FULL) // 2
        cf, sf = build_rope_freqs(
            HEAD_DIM_FULL, max_seq, ROPE_THETA_FULL, target, n_rot_pairs=n_rot
        )
        self._rope_full_cos = cf
        self._rope_full_sin = sf
        self._rope_seq = max_seq

    def _compute_per_layer_inputs(
        self, input_ids: torch.Tensor, x_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Precompute per-layer auxiliary inputs for all N_LAYERS layers.

        Args:
            input_ids: [B, T] token ids
            x_embed: scaled token embeddings [B, T, HIDDEN_SIZE]

        Returns:
            per_layer_inputs: [B, T, N_LAYERS, PER_LAYER_DIM], in x_embed's dtype
        """
        B, T = input_ids.shape

        # Embedding half: per-token lookup, scaled by sqrt(PER_LAYER_DIM).
        embed_aux = self.embed_tokens_per_layer(input_ids).to(x_embed.dtype)
        embed_aux = embed_aux * math.sqrt(PER_LAYER_DIM)
        embed_aux = embed_aux.view(B, T, N_LAYERS, PER_LAYER_DIM)

        # Projection half: project the hidden embedding, scale, then RMS-normalise
        # each layer's PER_LAYER_DIM slice independently.
        proj_all = self.per_layer_model_projection(x_embed)
        proj_all = proj_all * PER_LAYER_PROJ_SCALE
        proj_all = proj_all.view(B, T, N_LAYERS, PER_LAYER_DIM)
        proj_all = self.per_layer_projection_norm(proj_all)

        return (embed_aux + proj_all) * PER_LAYER_INPUT_SCALE

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: [B, T] long tensor; positions are taken as 0..T-1

        Returns:
            logits: [B, T, VOCAB_SIZE] with softcapping applied
        """
        _, T = input_ids.shape
        self._ensure_rope(T, input_ids.device)

        # NOTE(review): the sqrt(HIDDEN_SIZE) scale is applied as a Python float;
        # reference Gemma implementations may round it to the embedding dtype
        # first -- verify if bit-level parity with a reference is required.
        x = self.embed_tokens(input_ids) * math.sqrt(HIDDEN_SIZE)

        per_layer_inputs = self._compute_per_layer_inputs(input_ids, x)

        for i, layer in enumerate(self.layers):
            if is_full_attention(i):
                cos, sin = self._rope_full_cos, self._rope_full_sin
            else:
                cos, sin = self._rope_slide_cos, self._rope_slide_sin
            x = layer(x, cos, sin, per_layer_inputs[:, :, i, :])

        x = self.norm(x)

        # Tied LM head: project onto the token-embedding matrix.
        logits = F.linear(x, self.embed_tokens.weight.to(x.dtype))
        return LOGIT_CAP * torch.tanh(logits / LOGIT_CAP)

    @classmethod
    def load_weights(
        cls,
        safetensors_path: str | Path,
        device: str = "cpu",
    ) -> "Gemma4ForCausalLM":
        """
        Build a model and load its weights from a safetensors checkpoint.

        Keys are mapped ``model.language_model.X -> X``; keys without that
        prefix are ignored. If ``model.safetensors.index.json`` sits next to
        ``safetensors_path``, every shard listed in its ``weight_map`` is loaded
        (not just the given file), so sharded checkpoints load completely.
        Missing/unexpected keys are printed but not fatal (strict=False).

        Tensors are staged on CPU (where the freshly built model lives); the
        returned model is cast to DTYPE and placed on ``device``.
        """
        import json

        path = Path(safetensors_path)
        index_file = path.parent / "model.safetensors.index.json"
        if index_file.is_file():
            with open(index_file, encoding="utf-8") as fh:
                weight_map = json.load(fh)["weight_map"]
            shard_paths = sorted({path.parent / name for name in weight_map.values()})
        else:
            shard_paths = [path]

        prefix = "model.language_model."
        state: dict[str, torch.Tensor] = {}
        for shard in shard_paths:
            with safe_open(str(shard), framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith(prefix):
                        state[key[len(prefix):]] = f.get_tensor(key)

        model = cls()
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f"[load_weights] {len(missing)} missing keys (first 5): {missing[:5]}")
        if unexpected:
            print(f"[load_weights] {len(unexpected)} unexpected keys (first 5): {unexpected[:5]}")
        del state  # release the staged copy before the dtype/device conversion

        return model.to(device=device, dtype=DTYPE)
|
|
|
|
| |
|
|
def load_gemma4(
    device: str | None = None,
) -> tuple[Gemma4ForCausalLM, AutoTokenizer]:
    """
    Load the Gemma 4 E2B model and its tokenizer from the resolved HF cache.

    Args:
        device: where to place the model; defaults to DEVICE.

    Returns:
        (model, tokenizer): the model is in eval mode on ``device``; the
        tokenizer is read from MODEL_DIR without network access.
    """
    target = DEVICE if device is None else device

    print(f"Loading Gemma 4 E2B from {SAFETENSORS_BLOB} ...")
    loaded = Gemma4ForCausalLM.load_weights(SAFETENSORS_BLOB, device=target)
    model = loaded.to(target).eval()

    print(f"Loading tokenizer from {MODEL_DIR} ...")
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR), local_files_only=True)

    return model, tokenizer
|
|
|
|
| |
|
|
| def ppl_on_text( |
| model: Gemma4ForCausalLM, |
| tokenizer: AutoTokenizer, |
| text: str, |
| device: str | None = None, |
| max_length: int = 1024, |
| ) -> float: |
| """ |
| Compute token-level perplexity on `text`. |
| |
| Args: |
| model: Gemma4ForCausalLM in eval mode |
| tokenizer: matching AutoTokenizer |
| text: input string |
| device: device for inference (defaults to DEVICE) |
| max_length: truncate to this many tokens |
| |
| Returns: |
| perplexity (float) |
| """ |
| if device is None: |
| device = DEVICE |
|
|
| enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length) |
| input_ids = enc["input_ids"].to(device) |
|
|
| with torch.no_grad(): |
| logits = model(input_ids) |
|
|
| |
| shift_logits = logits[0, :-1, :] |
| shift_labels = input_ids[0, 1:] |
|
|
| log_probs = F.log_softmax(shift_logits.float(), dim=-1) |
| nll = -log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1).mean() |
| return nll.exp().item() |
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Short English passage used as a quick perplexity smoke test.
    _WIKI_TEXT = (
        "The transformer architecture was introduced in the paper "
        "'Attention Is All You Need' by Vaswani et al. in 2017. "
        "It relies entirely on self-attention mechanisms, dispensing with "
        "recurrence and convolutions entirely. Transformers have since become "
        "the dominant architecture for natural language processing, powering "
        "models such as BERT, GPT, T5, and the Gemma family. "
        "The key innovation is the multi-head attention mechanism, which allows "
        "the model to jointly attend to information from different representation "
        "subspaces at different positions. This is complemented by position-wise "
        "feed-forward networks and residual connections with layer normalisation. "
        "Large language models built on this architecture are trained on massive "
        "corpora using next-token prediction (autoregressive language modelling) "
        "or masked language modelling. They exhibit emergent capabilities such as "
        "few-shot and zero-shot generalisation across a wide variety of tasks."
    )

    model, tokenizer = load_gemma4()

    ppl = ppl_on_text(model, tokenizer, _WIKI_TEXT)
    # Plain ASCII hyphen: the previous literal contained a mis-encoded character.
    print(f"\nPerplexity on sample text: {ppl:.2f} (target: ~17-18 for bfloat16)")
|
|