PyTorch module + dequantized safetensors weights extracted
We extracted the MTP drafter weights from the TFLite file and built a working PyTorch module.
What is done:
- Parsed all 45 weight tensors from Section11_TFLiteModel_tf_lite_mtp_drafter.tflite using the tflite Python package
- Dequantized int8 (per-channel) and int4 (per-channel, packed nibbles) to float16
- Mapped to a clean PyTorch nn.Module with proper naming
- Forward pass verified: produces vocab logits from dummy hidden states
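For anyone following along, the dequantization step can be sketched roughly as below. This is a minimal NumPy sketch, not the code we actually ran: the function names `dequantize_per_channel` and `unpack_int4` are hypothetical, symmetric quantization is assumed when no zero points are given, and the low-nibble-first packing order is an assumption (nibble ordering is still unverified), so it is exposed as a flag.

```python
import numpy as np

def dequantize_per_channel(q, scales, zero_points=None, axis=0):
    """Map integer weights to float16: w = (q - zero_point) * scale, per channel."""
    q = np.asarray(q).astype(np.float32)
    bshape = [1] * q.ndim
    bshape[axis] = -1  # broadcast scale / zero point along the channel axis
    scale = np.asarray(scales, dtype=np.float32).reshape(bshape)
    if zero_points is None:
        zp = np.float32(0.0)  # assumption: symmetric quantization
    else:
        zp = np.asarray(zero_points, dtype=np.float32).reshape(bshape)
    return ((q - zp) * scale).astype(np.float16)

def unpack_int4(packed, low_nibble_first=True):
    """Split each byte into two signed 4-bit values along the last axis.

    low_nibble_first is an assumption about the packing order; flip it if
    the dequantized weights look wrong.
    """
    packed = np.asarray(packed, dtype=np.uint8)
    low = (packed & 0x0F).astype(np.int8)
    high = (packed >> 4).astype(np.int8)
    # Sign-extend 4-bit two's complement: 8..15 -> -8..-1
    low = np.where(low >= 8, low - 16, low)
    high = np.where(high >= 8, high - 16, high)
    pair = (low, high) if low_nibble_first else (high, low)
    # Interleave the two nibbles so the last axis doubles in length
    return np.stack(pair, axis=-1).reshape(*packed.shape[:-1], -1)
```

An int4 tensor would go through `unpack_int4` first and then through `dequantize_per_channel` with its per-channel scales.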
Architecture found:
- pre_proj: Linear(5120 to 256)
- 4 transformer layers (256 hidden, 2048 intermediate, GELU gated MLP)
- Layers 0-2: 4 attn heads, Layer 3: 8 attn heads
- Q-only attention (KV from main model layers 22-23)
- lm_head: Linear(256 to 262144)
- 78M total parameters
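To make "Q-only attention" concrete, here is a hedged PyTorch sketch of one such block: the drafter layer owns only a query and an output projection, while keys and values are passed in from outside (in our case they would come from the main model). The class name `QOnlyAttention`, the `head_dim=64` default, the bias-free linears, and the K/V layout are all assumptions for illustration. Masking, positional encoding, and normalization are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QOnlyAttention(nn.Module):
    """Attention block with its own Q and output projections only.

    K and V are supplied by the caller (e.g. taken from another model),
    so this module has no k_proj or v_proj weights.
    """

    def __init__(self, hidden=256, num_heads=4, head_dim=64):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim  # hypothetical value, not read from the TFLite file
        self.q_proj = nn.Linear(hidden, num_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden, bias=False)

    def forward(self, x, k, v):
        # x: (batch, q_len, hidden)
        # k, v: (batch, num_heads, kv_len, head_dim), assumed already in this layout
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)  # no mask in this sketch
        out = out.transpose(1, 2).reshape(b, t, self.num_heads * self.head_dim)
        return self.o_proj(out)
```

With these defaults, `QOnlyAttention()(x, k, v)` maps a `(batch, q_len, 256)` input back to `(batch, q_len, 256)`; layer 3 would use `num_heads=8`.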
Known gaps: parity validation against the TFLite model is still needed; layer 3 has a 512-dim post_attention_layernorm (handled with padding for now); and the int4 nibble ordering is unverified.
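For the parity check, one possible approach is to feed the same inputs to both the TFLite model and the PyTorch module and compare the resulting logits. The helper below is a minimal sketch of the comparison only; `parity_report` is a hypothetical name, and how the two logit arrays are obtained is left out.

```python
import numpy as np

def parity_report(ref_logits, test_logits):
    """Compare two logit arrays of identical shape (..., vocab)."""
    ref = np.asarray(ref_logits, dtype=np.float64)
    test = np.asarray(test_logits, dtype=np.float64)
    if ref.shape != test.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {test.shape}")
    # Flatten all leading dimensions so each row is one position's logits
    ref2 = ref.reshape(-1, ref.shape[-1])
    test2 = test.reshape(-1, test.shape[-1])
    max_abs = float(np.abs(ref2 - test2).max())
    cosine = float((ref2 * test2).sum()
                   / (np.linalg.norm(ref2) * np.linalg.norm(test2)))
    # Fraction of positions where both models pick the same top token
    top1 = float((ref2.argmax(-1) == test2.argmax(-1)).mean())
    return {"max_abs_diff": max_abs, "cosine": cosine, "top1_agreement": top1}
```

A wrong nibble order or a wrong channel axis would most likely show up as a cosine far from 1 and low top-1 agreement, which makes this a cheap way to test those two guesses.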
Happy to share the code and weights if useful for the effort.
please share them
We made quite a bit of progress after that, but no one seemed interested, and I wanted the owner of the discovery and the repository to be the one to publish it. Since I got no answer, I may go ahead and share them myself in case someone is interested. It's 1:30am right now, so it will have to wait until tomorrow.
Oops, there is news, take a look here: https://huggingface.co/shadowlilac/gemma-4-e4b-mtp-extraction-effort/discussions/3#69dd6183f983626b94d1bb05