PyTorch module + dequantized safetensors weights extracted

#2
Opened by mirifiuto

We extracted the MTP drafter weights from the TFLite file and built a working PyTorch module.

What has been done:

  • Parsed all 45 weight tensors from Section11_TFLiteModel_tf_lite_mtp_drafter.tflite using the tflite Python package
  • Dequantized int8 (per-channel) and int4 (per-channel, packed nibbles) to float16
  • Mapped to a clean PyTorch nn.Module with proper naming
  • Forward pass verified: produces vocab logits from dummy hidden states
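The two dequantization steps in the list above can be sketched in NumPy as follows. This is a minimal illustration, not the code used in this thread. It assumes symmetric per-channel quantization (zero point 0 unless given) and a quantized axis of 0 by default; the real axis is recorded in the TFLite schema's `quantized_dimension` field. The int4 nibble order is left as a flag because that ordering is still unverified. The function names are hypothetical.

```python
import numpy as np


def dequantize_per_channel(q, scales, zero_points=None, axis=0):
    """Dequantize an integer tensor with one scale (and zero point) per channel along `axis`."""
    q = np.asarray(q, dtype=np.float32)
    # Reshape the per-channel parameters so they broadcast along the quantized axis.
    shape = [1] * q.ndim
    shape[axis] = -1
    s = np.asarray(scales, dtype=np.float32).reshape(shape)
    zp = 0.0 if zero_points is None else np.asarray(zero_points, dtype=np.float32).reshape(shape)
    return ((q - zp) * s).astype(np.float16)


def unpack_int4(packed, low_nibble_first=True):
    """Unpack signed 4-bit values stored two per byte along the last axis."""
    packed = np.asarray(packed, dtype=np.uint8)
    low = (packed & 0x0F).astype(np.int8)
    high = (packed >> 4).astype(np.int8)
    # Sign-extend: nibbles 8..15 encode -8..-1 in two's complement.
    low = np.where(low >= 8, low - 16, low).astype(np.int8)
    high = np.where(high >= 8, high - 16, high).astype(np.int8)
    # Which nibble comes first is the unverified part, so it is a parameter here.
    first, second = (low, high) if low_nibble_first else (high, low)
    # Interleave so each byte yields two consecutive values.
    out = np.stack([first, second], axis=-1)
    return out.reshape(*packed.shape[:-1], packed.shape[-1] * 2)
```

For an int4 tensor you would unpack first and then dequantize, e.g. `dequantize_per_channel(unpack_int4(raw), scales)`. If parity later fails, flipping `low_nibble_first` is the cheapest thing to try.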

Architecture found:

  • pre_proj: Linear(5120 to 256)
  • 4 transformer layers (256 hidden, 2048 intermediate, GELU-gated MLP)
  • Layers 0-2: 4 attention heads; layer 3: 8 attention heads
  • Q-only attention (keys/values come from main-model layers 22-23)
  • lm_head: Linear(256 to 262144)
  • 78M total parameters
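"Q-only attention" means each drafter layer owns a query projection and an output projection but no key/value projections; it attends over keys and values produced by the main model. A minimal NumPy sketch of that idea follows, under assumptions not confirmed in this thread: plain 1/sqrt(head_dim) scaling, no causal mask, no RoPE or Q/K norms, and grouped-query sharing when the main model has fewer KV heads than the drafter has query heads. The function name and argument layout are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def q_only_attention(x, w_q, w_o, k, v, num_heads):
    """Attention where only the queries come from the drafter's own hidden states.

    x:    (seq, d_model)                    drafter hidden states
    w_q:  (d_model, num_heads * head_dim)   drafter query projection
    w_o:  (num_heads * head_dim, d_model)   drafter output projection
    k, v: (kv_len, num_kv_heads, head_dim)  keys/values borrowed from the main model
    """
    seq = x.shape[0]
    _, num_kv_heads, head_dim = k.shape
    q = (x @ w_q).reshape(seq, num_heads, head_dim)
    # Grouped-query sharing (assumption): each KV head serves num_heads // num_kv_heads query heads.
    rep = num_heads // num_kv_heads
    k = np.repeat(k, rep, axis=1)
    v = np.repeat(v, rep, axis=1)
    # scores[h, s, t]: query position s attending to cached position t in head h.
    scores = np.einsum("shd,thd->hst", q, k) / np.sqrt(head_dim)
    probs = softmax(scores, axis=-1)
    out = np.einsum("hst,thd->shd", probs, v).reshape(seq, num_heads * head_dim)
    return out @ w_o
```

In this layout, layers 0-2 would be called with `num_heads=4` and layer 3 with `num_heads=8`, with `k` and `v` taken from the main model's cache for layers 22-23.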

Known gaps: parity validation against the TFLite model is still needed; layer 3 has a 512-dim post_attention_layernorm (handled with padding for now); the int4 nibble ordering is unverified.
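For the parity check, one simple approach is to run the same inputs through the .tflite drafter (for example with TensorFlow's `tf.lite.Interpreter`) and through the PyTorch module, then compare the logits. The sketch below uses a hypothetical `parity_report` helper that reports the largest absolute difference, the mean per-row cosine similarity, and how often the top-1 token agrees. Which thresholds count as "matching" is a judgment call, and running it once per nibble ordering is one way to settle that question empirically.

```python
import numpy as np


def parity_report(ref_logits, test_logits):
    """Compare reference (e.g. TFLite) logits with re-implementation logits, row by row."""
    ref = np.asarray(ref_logits, dtype=np.float32)
    test = np.asarray(test_logits, dtype=np.float32)
    assert ref.shape == test.shape, "logit shapes must match"
    max_abs_diff = float(np.max(np.abs(ref - test)))
    # Cosine similarity per row (one row per position), then averaged.
    dots = np.sum(ref * test, axis=-1)
    norms = np.linalg.norm(ref, axis=-1) * np.linalg.norm(test, axis=-1)
    cosine = float(np.mean(dots / np.maximum(norms, 1e-12)))
    # Fraction of positions where both models would draft the same token.
    top1 = float(np.mean(ref.argmax(axis=-1) == test.argmax(axis=-1)))
    return {"max_abs_diff": max_abs_diff, "mean_cosine": cosine, "top1_agreement": top1}
```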

Happy to share the code and weights if useful for the effort.

Please share them.

Please share them.

We made quite a bit of progress after that, but no one seemed interested, and I wanted the owner of the discovery and of the repository to publish it. Since I got no answer, I will probably go ahead and share them myself in case someone is interested. It's 1:30 am right now, though, so it will be tomorrow.

Oops, there's news; go look here: https://huggingface.co/shadowlilac/gemma-4-e4b-mtp-extraction-effort/discussions/3#69dd6183f983626b94d1bb05
