import torch
from transformers import AutoProcessor
from transformers.models.gemma4_unified.configuration_gemma4_unified import (
Gemma4UnifiedConfig,
Gemma4UnifiedTextConfig,
Gemma4UnifiedVisionConfig,
)
from transformers.models.gemma4_unified.modeling_gemma4_unified import Gemma4UnifiedForConditionalGeneration
save_dir = "./tiny-random-gemma4-unified-it"

# Miniature text backbone that mirrors the 12B gemma4_unified layout:
#   * No PLE: the unified text config has no hidden_size_per_layer_input field.
#   * attention_k_eq_v=True, so full-attention layers fuse v_proj into k_proj.
#   * Global attention uses num_global_key_value_heads=1 and a global_head_dim
#     (32) that is larger than head_dim (16).
#   * num_kv_shared_layers=0, because the 12B model does not share KV across layers.
#   * use_bidirectional_attention="vision".
# Layer pattern: three sliding-window layers, then one full-attention layer.
text_config = Gemma4UnifiedTextConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    global_head_dim=32,
    num_global_key_value_heads=1,
    vocab_size=262_144,
    max_position_embeddings=512,
    rms_norm_eps=1e-6,
    hidden_activation="gelu_pytorch_tanh",
    sliding_window=64,
    layer_types=["sliding_attention"] * 3 + ["full_attention"],
    num_kv_shared_layers=0,
    attention_k_eq_v=True,
    use_double_wide_mlp=False,
    use_bidirectional_attention="vision",
    final_logit_softcapping=30.0,
    tie_word_embeddings=True,
)
# Vision is an encoder-free embedder. The effective model patch size is
# patch_size * pooling_kernel_size (16 * 3 = 48 here).
# mm_embed_dim and output_proj_dims must match the text hidden_size. They are
# therefore derived from text_config rather than hard-coded, so the invariant
# still holds if the text width is changed.
vision_config = Gemma4UnifiedVisionConfig(
    patch_size=16,
    pooling_kernel_size=3,
    mm_embed_dim=text_config.hidden_size,
    output_proj_dims=text_config.hidden_size,
    mm_posemb_size=128,
    rms_norm_eps=1e-6,
)
# Top-level multimodal config: the text and vision sub-configs are passed as
# plain dicts, audio is disabled, and special-token ids are set explicitly.
# NOTE(review): every other token argument is spelled `*_token_id`, but the
# end-of-audio one is `eoa_token_index`. Confirm this matches the
# Gemma4UnifiedConfig signature. A misspelt keyword would likely be stored as
# an extra attribute and would not override the real default.
config = Gemma4UnifiedConfig(
    text_config=text_config.to_dict(),
    vision_config=vision_config.to_dict(),
    audio_config=None,
    boi_token_id=255_999,
    eoi_token_id=258_882,
    image_token_id=258_880,
    video_token_id=258_884,
    boa_token_id=256_000,
    eoa_token_index=258_883,
    audio_token_id=258_881,
    tie_word_embeddings=True,
)
# Seed the RNG before constructing the model so its random weights are
# reproducible. According to the fixture's author, seed 42 gives greedy
# generations with no near-ties. That keeps OpenVINO-vs-transformers token
# equality stable despite OpenVINO's small (~1e-4) numerical differences.
torch.manual_seed(42)
model = Gemma4UnifiedForConditionalGeneration(config).to(dtype=torch.float32)
model.eval()
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")
model.save_pretrained(save_dir)
# Reuse the reference 12B processor, but shrink its soft-token budget so the
# tiny vision position-embedding table (vision_config.mm_posemb_size, which is
# 128 here) is large enough. Both image-processor settings are kept in sync
# via a single value instead of two independent literals.
soft_token_budget = 70
processor = AutoProcessor.from_pretrained("google/gemma-4-12b-it")
processor.image_processor.max_soft_tokens = soft_token_budget
processor.image_processor.image_seq_length = soft_token_budget
processor.save_pretrained(save_dir)
print(f"Tiny Gemma4Unified model saved to {save_dir}")
# Sanity forward pass on random token ids drawn from the full text vocabulary.
# The vocabulary bound comes from text_config instead of a duplicated literal.
# The logits shape is checked, so "Forward pass OK!" is printed only when the
# output has the expected (batch, seq_len, vocab) layout.
batch_size, seq_len = 1, 10
vocab_size = text_config.vocab_size
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
with torch.no_grad():
    out = model(input_ids=input_ids)
print("logits shape:", out.logits.shape)
expected_shape = (batch_size, seq_len, vocab_size)
if tuple(out.logits.shape) != expected_shape:
    raise RuntimeError(
        f"Unexpected logits shape {tuple(out.logits.shape)}; "
        f"expected {expected_shape}"
    )
print("Forward pass OK!")
# (Hugging Face model-page residue captured together with the script above:)
# - Downloads last month
# - -
# Inference Providers NEW
# This model isn't deployed by any Inference Provider. 🙋 Ask for provider support