---
license: mit
tags:
- gemma3
- safetensors
- transformers
- tinygemma
- tinystories
- validation
- test-suite
---

# TinyStories Gemma3 Text Validation Artifact

This directory contains a tiny Gemma 3 text-only model trained with official
Hugging Face Transformers classes.
It is intended for inference-engine validation, not for production language
quality.

## Official classes used

- `Gemma3TextConfig`
- `Gemma3ForCausalLM`
- `Trainer`

No custom Gemma 3 modeling code is used.

## Key validation targets

- `model_type = gemma3_text`
- `architectures = Gemma3ForCausalLM`
- local/global attention pattern through `layer_types`
- sliding-window attention
- full attention
- GQA
- per-head `q_norm` / `k_norm`
- Gemma3 four-norm decoder layer structure
- gated MLP: `silu(gate_proj(x)) * up_proj(x)`
- tied output head through `model.embed_tokens.weight`
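
The attention-related targets above can be illustrated with plain boolean masks. The sketch below is not taken from the artifact: all function names are hypothetical, and it assumes a sliding-window layer lets position `i` see positions `j` with `j <= i` and `i - j < window`, a boundary convention that should be checked against the Transformers implementation. It also shows how GQA maps query heads onto shared key/value heads.

```python
# Hypothetical helpers illustrating full vs. sliding-window causal attention
# and the GQA head mapping. Pure Python, no model weights involved.

def causal_mask(seq_len):
    """Full attention: position i may attend to every position j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def sliding_mask(seq_len, window):
    """Sliding attention (assumed convention): j <= i and i - j < window."""
    return [
        [(j <= i) and (i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def mask_for_layer(layer_type, seq_len, window):
    """Pick the mask from a `layer_types` entry."""
    if layer_type == "full_attention":
        return causal_mask(seq_len)
    if layer_type == "sliding_attention":
        return sliding_mask(seq_len, window)
    raise ValueError(f"unknown layer type: {layer_type}")


def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """GQA: consecutive groups of query heads share one key/value head."""
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


if __name__ == "__main__":
    # Tiny example: window of 2 over 4 positions.
    for row in mask_for_layer("sliding_attention", seq_len=4, window=2):
        print("".join("x" if allowed else "." for allowed in row))
```

With 4 query heads and 1 key/value head, `kv_head_for_query_head` returns 0 for every query head, so all query heads read the same keys and values.
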

## Tiny architecture

- vocab_size: 1024
- hidden_size: 128
- intermediate_size: 512
- num_hidden_layers: 6
- num_attention_heads: 4
- num_key_value_heads: 1
- head_dim: 32
- sliding_window: 32
- layer_types: ['sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
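
As a rough sanity check, the parameter count implied by these values can be estimated by hand. The sketch below is an estimate, not a dump of the checkpoint: `count_params` is a hypothetical helper, and it assumes bias-free projections, a `down_proj` alongside `gate_proj` and `up_proj`, four hidden-size norms plus `q_norm` and `k_norm` per layer, one final norm, and a tied output head that adds no extra weights.

```python
# Hypothetical parameter-count estimate from the listed architecture values.
# Assumptions: no biases, tied embeddings, gate/up/down MLP projections,
# four hidden-size norms + q_norm/k_norm (head_dim each) per decoder layer.

def count_params(vocab_size, hidden_size, intermediate_size,
                 num_hidden_layers, num_attention_heads,
                 num_key_value_heads, head_dim):
    embed = vocab_size * hidden_size  # shared with the output head
    q_proj = hidden_size * num_attention_heads * head_dim
    kv_proj = 2 * hidden_size * num_key_value_heads * head_dim  # k_proj + v_proj
    o_proj = num_attention_heads * head_dim * hidden_size
    mlp = 3 * hidden_size * intermediate_size  # gate_proj, up_proj, down_proj
    norms = 4 * hidden_size + 2 * head_dim  # layer norms + q_norm/k_norm
    per_layer = q_proj + kv_proj + o_proj + mlp + norms
    final_norm = hidden_size
    return embed + num_hidden_layers * per_layer + final_norm


if __name__ == "__main__":
    total = count_params(
        vocab_size=1024,
        hidden_size=128,
        intermediate_size=512,
        num_hidden_layers=6,
        num_attention_heads=4,
        num_key_value_heads=1,
        head_dim=32,
    )
    print(total)  # 1560064 under the assumptions above
```

The tensor names and shapes in `safetensors_keys.json` remain the authoritative source; any mismatch with this estimate points to an assumption that does not hold for the actual checkpoint.
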

## Files

- `hf/`: Hugging Face model/tokenizer artifact
- `reference/reference.pt`: deterministic reference tensors
- `reference/reference.json`: JSON summary of reference logits
- `gemma3_text_config_dump.json`: normalized config dump
- `safetensors_keys.json`: tensor names and shapes
- `artifact_metadata.json`: generation metadata
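
An inference engine under test would typically compare its own logits against the reference values. The layout of `reference/reference.json` is not documented here, so the sketch below deliberately avoids reading the file and works on plain lists of floats; the helper names and the tolerance values are assumptions chosen for illustration.

```python
# Hypothetical comparison helpers for validating an engine's logits against
# reference logits. Inputs are flat lists of floats for one token position.

def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two logit vectors."""
    if len(reference) != len(candidate):
        raise ValueError("logit vectors differ in length")
    return max(abs(r - c) for r, c in zip(reference, candidate))


def logits_match(reference, candidate, atol=1e-4, rtol=1e-4):
    """True when every element satisfies |r - c| <= atol + rtol * |r|."""
    if len(reference) != len(candidate):
        return False
    return all(
        abs(r - c) <= atol + rtol * abs(r)
        for r, c in zip(reference, candidate)
    )


def argmax(values):
    """Index of the largest logit, i.e. the greedy next-token choice."""
    return max(range(len(values)), key=lambda i: values[i])


if __name__ == "__main__":
    ref = [0.10, 2.50, -1.00]
    out = [0.10001, 2.49999, -1.00002]
    print(max_abs_diff(ref, out), logits_match(ref, out), argmax(ref) == argmax(out))
```

Checking the argmax in addition to the tolerance is useful because greedy decoding only depends on which logit is largest; the tolerances usually need loosening when the engine runs in a lower precision than the reference.
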

## Usage

```python
import torch
from transformers import Gemma3ForCausalLM, PreTrainedTokenizerFast


def main():
    repo_id = "shibatch/tinygemma3-2m"

    # The model and tokenizer live in the `hf/` subfolder of the repository.
    print("Loading tokenizer...")
    tokenizer = PreTrainedTokenizerFast.from_pretrained(repo_id, subfolder="hf")

    print("Loading Gemma3 model weights...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Gemma3ForCausalLM.from_pretrained(
        repo_id,
        subfolder="hf",
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    ).to(device)
    model.eval()

    prompt = "Once upon"
    print(f"\nInput prompt: {prompt}")

    # Encode without special tokens, then prepend BOS explicitly.
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    input_ids = [tokenizer.bos_token_id] + input_ids
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Fall back to BOS only when no pad token is defined (a pad id of 0 is valid).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.bos_token_id

    # Greedy decoding keeps the output deterministic for validation.
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=100,
            do_sample=False,
            repetition_penalty=1.0,
            top_p=1.0,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Generated output: {generated_text}")


if __name__ == "__main__":
    main()
```