ncylich commited on
Commit
ecac12b
·
verified ·
1 Parent(s): 7a8dc75

Upload gemma4_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gemma4_hf.py +60 -0
gemma4_hf.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """gemma4_hf.py — thin HF wrapper that exposes the same interface as gemma4.py.
2
+
3
+ Replaces our custom (broken) gemma4.py with transformers' validated implementation.
4
+ Exposes:
5
+ - N_LAYERS, HIDDEN_SIZE, INTERMEDIATE, INTERMEDIATE_WIDE, DOUBLE_WIDE_START, DEVICE, DTYPE
6
+ - load_gemma4() -> (model, tokenizer) where model behaves like our custom one:
7
+ * model(input_ids) -> logits tensor (not CausalLMOutputWithPast)
8
+ * model.layers is a proxy for the underlying ModuleList of decoder layers
9
+ * all nn.Module APIs (parameters, state_dict, etc.) work
10
+ """
11
+ import os
12
+ import torch
13
+ import torch.nn as nn
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+
16
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
+ DTYPE = torch.bfloat16
18
+
19
+ # Architecture constants (match google/gemma-4-e2b-it)
20
+ N_LAYERS = 35
21
+ HIDDEN_SIZE = 1536
22
+ VOCAB_SIZE = 262144
23
+ INTERMEDIATE = 6144 # layers 0-14
24
+ INTERMEDIATE_WIDE = 12288 # layers 15-34
25
+ DOUBLE_WIDE_START = 15
26
+
27
+ HF_REPO = os.environ.get("GEMMA4_HF_REPO", "google/gemma-4-e2b-it")
28
+
29
+
30
+ class HFGemma4(nn.Module):
31
+ """Wraps HF's Gemma4ForConditionalGeneration to match our custom-model interface.
32
+
33
+ Forward returns raw logits (not the HF output struct), and `.layers` exposes
34
+ the decoder layer ModuleList for rung6_moe_g4.py's install_moe / Taylor hooks.
35
+ """
36
+
37
+ def __init__(self, inner: nn.Module):
38
+ super().__init__()
39
+ self.inner = inner
40
+
41
+ @property
42
+ def layers(self) -> nn.ModuleList:
43
+ return self.inner.model.language_model.layers
44
+
45
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
46
+ return self.inner(input_ids=input_ids).logits
47
+
48
+
49
+ def load_gemma4(device=None):
50
+ """Load HF Gemma-4 E2B-IT, wrapped in HFGemma4 for rung6 compatibility."""
51
+ if device is None:
52
+ device = DEVICE
53
+ print(f"Loading HF {HF_REPO} (dtype={DTYPE})...")
54
+ inner = AutoModelForCausalLM.from_pretrained(HF_REPO, dtype=DTYPE, device_map=device)
55
+ inner.eval()
56
+ model = HFGemma4(inner)
57
+ model.eval()
58
+ print(f"Loading tokenizer...")
59
+ tokenizer = AutoTokenizer.from_pretrained(HF_REPO)
60
+ return model, tokenizer