metadata
license: mit
tags:
  - gemma3
  - safetensors
  - transformers
  - tinygemma
  - tinystories
  - validation
  - test-suite

TinyStories Gemma3 Text Validation Artifact

This directory contains a tiny Gemma 3 text-only model trained with official Hugging Face Transformers classes.

It is intended for validating inference engines, not for generating high-quality text.

Official classes used

  • Gemma3TextConfig
  • Gemma3ForCausalLM
  • Trainer

No custom Gemma 3 modeling code is used.

Key validation targets

  • model_type = gemma3_text
  • architectures = Gemma3ForCausalLM
  • local/global attention pattern through layer_types
  • sliding-window attention
  • full attention
  • GQA
  • per-head q_norm / k_norm
  • Gemma3 four-norm decoder layer structure (input_layernorm, post_attention_layernorm, pre_feedforward_layernorm, post_feedforward_layernorm)
  • gated MLP: silu(gate_proj(x)) * up_proj(x)
  • tied output head through model.embed_tokens.weight
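
To make the local/global attention and GQA targets concrete, here is a minimal pure-Python sketch of which key positions a query may attend to in each layer type, and how query heads map to key/value heads. The helper names `allowed_keys` and `kv_head_for` are hypothetical and not part of Transformers. The sketch assumes the sliding window includes the query position itself, so the off-by-one behavior should be checked against the Transformers version used as reference.

```python
def allowed_keys(query_pos, layer_type, sliding_window=32):
    """Return the key positions a query at `query_pos` may attend to (causal)."""
    if layer_type == "full_attention":
        start = 0
    elif layer_type == "sliding_attention":
        # Assumption: the window counts the query position itself,
        # so the query sees at most `sliding_window` positions.
        start = max(0, query_pos - sliding_window + 1)
    else:
        raise ValueError(f"unknown layer type: {layer_type}")
    return range(start, query_pos + 1)


def kv_head_for(query_head, num_attention_heads=4, num_key_value_heads=1):
    """GQA: map a query head index to the key/value head it shares."""
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size


# Layer pattern from this model card: five local layers, then one global layer.
layer_types = ["sliding_attention"] * 5 + ["full_attention"]

for layer_idx, layer_type in enumerate(layer_types):
    keys = allowed_keys(40, layer_type)
    print(layer_idx, layer_type, f"keys {keys[0]}..{keys[-1]}", len(keys))
```

With 4 query heads and 1 key/value head, every query head maps to key/value head 0, so an engine that mishandles head grouping should diverge from the reference on the first layer.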

Tiny architecture

  • vocab_size: 1024
  • hidden_size: 128
  • intermediate_size: 512
  • num_hidden_layers: 6
  • num_attention_heads: 4
  • num_key_value_heads: 1
  • head_dim: 32
  • sliding_window: 32
  • layer_types: ['sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
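
As a sanity check on these dimensions, the following back-of-envelope sketch counts parameters from the values above. It assumes no bias terms, a tied output head (counted once), four hidden-size norms per layer, per-head q_norm and k_norm of size head_dim, and one final norm. These assumptions follow the Gemma 3 layout described in this card; the authoritative shapes are the ones listed in safetensors_keys.json.

```python
cfg = {
    "vocab_size": 1024,
    "hidden_size": 128,
    "intermediate_size": 512,
    "num_hidden_layers": 6,
    "num_attention_heads": 4,
    "num_key_value_heads": 1,
    "head_dim": 32,
}


def count_parameters(cfg):
    """Estimate the parameter count under the assumptions stated above."""
    h = cfg["hidden_size"]
    q_dim = cfg["num_attention_heads"] * cfg["head_dim"]
    kv_dim = cfg["num_key_value_heads"] * cfg["head_dim"]

    embed = cfg["vocab_size"] * h                 # shared with the output head
    attn = h * q_dim + 2 * h * kv_dim + q_dim * h  # q, k, v, o projections
    qk_norm = 2 * cfg["head_dim"]                  # per-head q_norm and k_norm
    mlp = 3 * h * cfg["intermediate_size"]         # gate, up, down projections
    layer_norms = 4 * h                            # four norms per decoder layer

    per_layer = attn + qk_norm + mlp + layer_norms
    final_norm = h
    return embed + cfg["num_hidden_layers"] * per_layer + final_norm


print(f"{count_parameters(cfg):,} parameters (estimate)")
```

Under these assumptions the estimate comes to roughly 1.56 million parameters, most of them in the MLP projections.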

Files

  • hf/: Hugging Face model/tokenizer artifact
  • reference/reference.pt: deterministic reference tensors
  • reference/reference.json: JSON summary of reference logits
  • gemma3_text_config_dump.json: normalized config dump
  • safetensors_keys.json: tensor names and shapes
  • artifact_metadata.json: generation metadata
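
One way to use the reference logits is to compare them element-wise against the output of the engine under test. The layout of reference/reference.json is not documented here, so the sketch below works on two flat lists of floats for a single position. The function name `compare_logits` and the tolerance values are illustrative assumptions, not part of this artifact.

```python
def compare_logits(reference, candidate, atol=1e-3, rtol=1e-3):
    """Compare two logit vectors; tolerances are illustrative defaults."""
    if not reference or len(reference) != len(candidate):
        raise ValueError("logit vectors must be non-empty and of equal length")

    max_abs_diff = 0.0
    within_tolerance = True
    for ref, cand in zip(reference, candidate):
        diff = abs(ref - cand)
        max_abs_diff = max(max_abs_diff, diff)
        if diff > atol + rtol * abs(ref):
            within_tolerance = False

    # Greedy decoding only depends on the argmax, so report it separately.
    ref_argmax = max(range(len(reference)), key=reference.__getitem__)
    cand_argmax = max(range(len(candidate)), key=candidate.__getitem__)

    return {
        "max_abs_diff": max_abs_diff,
        "within_tolerance": within_tolerance,
        "argmax_match": ref_argmax == cand_argmax,
    }


report = compare_logits([0.1, 2.0, -1.0], [0.1005, 2.0005, -1.0])
print(report)
```

Reporting the argmax separately helps distinguish small numeric drift (for example from bfloat16) from errors that would change greedy generation.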

Usage

import torch
from transformers import Gemma3ForCausalLM, PreTrainedTokenizerFast

def main():
    repo_id = "shibatch/tinygemma3-2m"

    print("Loading tokenizer...")
    tokenizer = PreTrainedTokenizerFast.from_pretrained(repo_id, subfolder="hf")

    print("Loading Gemma3 model weights...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = Gemma3ForCausalLM.from_pretrained(
        repo_id,
        subfolder="hf",
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()

    prompt = "Once upon"
    print(f"\nInput prompt: {prompt}")

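    # Encode without special tokens, then prepend BOS explicitly so the
    # prompt layout does not depend on the tokenizer's post-processor.
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    input_ids = [tokenizer.bos_token_id] + input_ids
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)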

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=100,
            do_sample=False,
            repetition_penalty=1.0,
            top_p=1.0,
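            # Fall back to BOS only when no pad token is defined; a pad id of 0 is valid.
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.bos_token_id
            ),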
            eos_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Generated output: {generated_text}")

if __name__ == "__main__":
    main()