import torch
from transformers import AutoProcessor
from transformers.models.gemma4_unified.configuration_gemma4_unified import (
    Gemma4UnifiedConfig,
    Gemma4UnifiedTextConfig,
    Gemma4UnifiedVisionConfig,
)
from transformers.models.gemma4_unified.modeling_gemma4_unified import Gemma4UnifiedForConditionalGeneration


save_dir = "./tiny-random-gemma4-unified-it"

# Width shared by the text decoder and the vision embedder. The vision projection
# dims (mm_embed_dim / output_proj_dims) must equal the text hidden_size, so the value
# is defined once here instead of being repeated as a literal in both configs.
HIDDEN_SIZE = 32

# Tiny text config mirroring the 12B gemma4_unified architecture:
#  - PLE disabled (no hidden_size_per_layer_input field exists on unified text config)
#  - attention_k_eq_v=True so full-attention layers fuse v_proj into k_proj
#  - num_global_key_value_heads=1, global_head_dim larger than head_dim
#  - num_kv_shared_layers=0 (12B does not share KV across layers)
#  - use_bidirectional_attention="vision"
text_config = Gemma4UnifiedTextConfig(
    hidden_size=HIDDEN_SIZE,
    intermediate_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    global_head_dim=32,
    num_global_key_value_heads=1,
    vocab_size=262144,
    max_position_embeddings=512,
    rms_norm_eps=1e-6,
    hidden_activation="gelu_pytorch_tanh",
    sliding_window=64,
    # One entry per hidden layer (4 total): three sliding-window layers, then one full-attention layer.
    layer_types=["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"],
    num_kv_shared_layers=0,
    attention_k_eq_v=True,
    use_double_wide_mlp=False,
    use_bidirectional_attention="vision",
    final_logit_softcapping=30.0,
    tie_word_embeddings=True,
)

# Vision is an encoder-free embedder: model_patch_size = patch_size * pooling_kernel_size.
# mm_embed_dim / output_proj_dims must match the text hidden_size, hence HIDDEN_SIZE.
vision_config = Gemma4UnifiedVisionConfig(
    patch_size=16,
    pooling_kernel_size=3,
    mm_embed_dim=HIDDEN_SIZE,
    output_proj_dims=HIDDEN_SIZE,
    mm_posemb_size=128,
    rms_norm_eps=1e-6,
)

# Composite config: the sub-configs are handed over as plain dicts and audio is disabled.
config = Gemma4UnifiedConfig(
    text_config=text_config.to_dict(),
    vision_config=vision_config.to_dict(),
    audio_config=None,
    boi_token_id=255999,
    eoi_token_id=258882,
    image_token_id=258880,
    video_token_id=258884,
    boa_token_id=256000,
    # NOTE(review): every sibling kwarg ends in `_token_id`; confirm that
    # `eoa_token_index` is the field name Gemma4UnifiedConfig actually expects.
    eoa_token_index=258883,
    audio_token_id=258881,
    tie_word_embeddings=True,
)

# Fix the RNG state ahead of construction so the randomly initialised weights come out
# the same on every run. Per the fixture's author, this seed yields greedy generations
# with no near-ties, which keeps OV-vs-transformers token equality stable under the
# small (~1e-4) numerical differences of OpenVINO inference.
torch.manual_seed(42)
model = Gemma4UnifiedForConditionalGeneration(config).to(dtype=torch.float32)
model.eval()

# Report the total number of elements across all parameter tensors.
total_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {total_params:,}")

model.save_pretrained(save_dir)

# Reuse the reference processor but shrink the soft-token budget so the tiny
# position embedding table (mm_posemb_size=128 on the vision config) is large enough.
# Both image-processor attributes are assigned from a single value so they cannot
# drift apart.
processor = AutoProcessor.from_pretrained("google/gemma-4-12b-it")
soft_token_budget = 70
image_processor = processor.image_processor
image_processor.max_soft_tokens = soft_token_budget
image_processor.image_seq_length = soft_token_budget
processor.save_pretrained(save_dir)

print(f"Tiny Gemma4Unified model saved to {save_dir}")

# Sanity forward pass on random token ids only (no image/audio inputs are passed).
# The exclusive upper bound for the ids is read from the text config rather than
# repeating the vocab-size literal, so it follows any change to `vocab_size`.
batch_size, seq_len = 1, 10
input_ids = torch.randint(0, text_config.vocab_size, (batch_size, seq_len))
with torch.no_grad():  # inference only: no autograd graph is needed
    out = model(input_ids=input_ids)
print("logits shape:", out.logits.shape)
print("Forward pass OK!")
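# --- Model-card page residue (not part of the script; kept as comments so the file parses) ---
# Downloads last month: -
# Safetensors
# Model size: 8.69M params
# Tensor type: F32
# Inference Providers (NEW): This model isn't deployed by any Inference Provider.
# Ask for provider support.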