Instinct-1-0.5B

Instinct-1-0.5B is a fully reproducible, from-scratch 0.5-billion-parameter language model built under the AutonomousX organization. It was trained on 40B tokens of The PILE dataset using TPU v4 infrastructure.

v1.0 Stable Base Release (No SFT or RLHF)

Model Specifications

0.5B Parameters · TPU v4-8 · 40B Tokens · JAX / Flax


Compute for this project was supported by Google's TPU Research Cloud (TRC) program.

This model was developed by Rohit Yadav, a third-year B.Tech student at NIT Jalandhar, India.

E-mail: yrohit1825@gmail.com.


Model Overview

| Attribute | Value |
|---|---|
| Model Name | Instinct-1-0.5B |
| Organization | AutonomousX |
| Parameters | 500 Million |
| Vocabulary Size | 50,277 |
| Training Dataset | The PILE |
| Tokens Seen | 40 Billion |
| Training Hardware | TPU v4-8 |
| Initial Loss | 10.82 |
| Final Validation Loss | ~2.6 |

Validation was performed using rolling validation shards of The PILE dataset.
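The initial loss above is no accident: a freshly initialized model predicts almost uniformly over the vocabulary, so its cross-entropy starts near ln(vocab size), and exponentiating a loss gives the perplexity. A quick check, using the padded vocab size of 50,304 from the inference script (the tokenizer itself has 50,277 entries):

```python
import math

VOCAB = 50304      # padded vocab size used by the model head
FINAL_LOSS = 2.6   # final validation loss from the table

# Uniform predictions give cross-entropy ln(V), matching the reported 10.82.
print(round(math.log(VOCAB), 2))       # 10.83
# Perplexity = exp(loss): from ~50k effective choices down to ~13.5.
print(round(math.exp(FINAL_LOSS), 2))  # 13.46
```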



Training Details

Instinct-1-0.5B was trained completely from scratch using JAX/Flax on TPU v4-8 hardware.

The training pipeline includes:

  • Dataset streaming from The PILE
  • A custom tokenizer with a 50,277-token vocabulary
  • A TPU-optimized JAX / Flax training loop
  • Checkpointing and validation during training
  • Rolling validation-shard evaluation

The model was trained on 40B tokens from The PILE dataset.
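The core of such a training iteration can be sketched in a few lines. The NumPy toy below is purely illustrative: a single linear head stands in for the full transformer, sizes are shrunk, and plain SGD replaces the AdamW used in the real run.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, D = 50304, 64  # real d_model is 1024; shrunk here to keep the toy small

# Toy "model": one linear head over random hidden states.
W = rng.normal(0.0, 0.02, size=(D, VOCAB))
h = rng.normal(size=(8, D))               # batch of 8 hidden states
targets = rng.integers(0, VOCAB, size=8)  # next-token labels

def loss_and_grad(W):
    logits = h @ W
    logits -= logits.max(axis=-1, keepdims=True)         # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    loss = -np.log(probs[np.arange(8), targets]).mean()  # cross-entropy
    dlogits = probs.copy()
    dlogits[np.arange(8), targets] -= 1.0                # softmax-CE gradient
    return loss, h.T @ dlogits / 8

loss, gW = loss_and_grad(W)
W -= 0.05 * gW  # one SGD step
print(f"loss before step: {loss:.2f}")  # ~10.8 = ln(50304): near-uniform predictions
```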


Reproducibility

The entire pipeline used to train the model is fully reproducible.

This includes:

  • Dataset pipeline
  • Tokenizer creation
  • Model architecture
  • TPU training loop
  • Checkpointing system

You can reproduce the complete training pipeline from scratch.

🚀 Full training pipeline repository: Github training pipeline

📊 TPU Setup Guide: Youtube


Run Inference (the model runs on both GPUs and TPUs)

A ready-to-run Google Colab TPU/GPU inference script is provided below.

Simply open the notebook and run it with TPU or GPU runtime.
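As a sanity check, the architecture constants used in the script below (32 layers, d_model 1024, FFN 4096, padded vocab 50,304) tally to roughly half a billion parameters. This rough count assumes biases on the two MLP Dense layers (Flax's default) and an untied output head:

```python
V, L, D, F = 50304, 32, 1024, 4096      # constants from the inference script

embed = V * D                           # token embedding table
per_layer = 3 * D * D + D * D           # fused QKV + attention output proj (no bias)
per_layer += (D * F + F) + (F * D + D)  # MLP Dense layers (with biases)
per_layer += 2 * D                      # two RMSNorm scale vectors
total = embed + L * per_layer + D + V * D  # + final norm + untied LM head
print(f"~{total / 1e6:.0f}M parameters")   # ~506M
```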

Please be patient: loading and running the model may take around 20 minutes.

#please be patient It may take 20 mins to run the model
# Install huggingface_hub if not installed
!pip install -q huggingface_hub

from huggingface_hub import snapshot_download


repo_id = "autonomousX/Instinct-1-0.5B"

# Download entire repository
local_path = snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    local_dir="TPU_500m",
    local_dir_use_symlinks=False
)

print("Download complete!")
print("Saved to:", local_path)

# =========================
# FAST 500M INFERENCE CELL
# =========================

import os
import math
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
from transformers import AutoTokenizer

# ---------------- CONFIG ----------------
SEQ_LEN = 1024
VOCAB_SIZE = 50304  # tokenizer has 50,277 tokens; likely padded to a multiple of 128 for TPU efficiency

N_LAYERS = 32
D_MODEL = 1024
N_HEADS = 16
D_HEAD = 64
D_FF = 4096
ROTARY_PCT = 0.25

CKPT_PATH = os.path.abspath("TPU_500m/checkpoint_0")

# ---------------- RoPE ----------------
def build_rope_cache(seq_len, head_dim, rotary_pct):
    dim = int(head_dim * rotary_pct)
    freqs = 1.0 / (10000 ** (jnp.arange(0, dim, 2) / dim))
    pos = jnp.arange(seq_len)
    angles = jnp.einsum("i,j->ij", pos, freqs)
    return jnp.sin(angles), jnp.cos(angles)

ROPE_SIN, ROPE_COS = build_rope_cache(SEQ_LEN, D_HEAD, ROTARY_PCT)

def apply_rope(q, k):
    dim = int(D_HEAD * ROTARY_PCT)
    T = q.shape[1]

    sin = ROPE_SIN[:T][None, :, None, :]
    cos = ROPE_COS[:T][None, :, None, :]

    q_rot, q_pass = q[..., :dim], q[..., dim:]
    k_rot, k_pass = k[..., :dim], k[..., dim:]

    q1, q2 = q_rot[..., ::2], q_rot[..., 1::2]
    k1, k2 = k_rot[..., ::2], k_rot[..., 1::2]

    q_rot = jnp.concatenate(
        [q1 * cos - q2 * sin,
         q1 * sin + q2 * cos],
        axis=-1
    )

    k_rot = jnp.concatenate(
        [k1 * cos - k2 * sin,
         k1 * sin + k2 * cos],
        axis=-1
    )

    return (
        jnp.concatenate([q_rot, q_pass], axis=-1),
        jnp.concatenate([k_rot, k_pass], axis=-1),
    )

# ---------------- MODEL ----------------
class RMSNorm(nn.Module):
    dim: int
    eps: float = 1e-6
    @nn.compact
    def __call__(self, x):
        scale = self.param("scale", nn.initializers.ones, (self.dim,))
        norm = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
        return x * (scale / norm)

class Attention(nn.Module):
    @nn.compact
    def __call__(self, x, mask):
        B, T, C = x.shape
        qkv = nn.Dense(3 * C, use_bias=False, dtype=jnp.bfloat16)(x)
        qkv = qkv.reshape(B, T, 3, N_HEADS, D_HEAD)

        q = qkv[:, :, 0]
        k = qkv[:, :, 1]
        v = qkv[:, :, 2]

        q, k = apply_rope(q, k)

        att = jnp.einsum("bthd,bshd->bhts", q, k)
        att = att / math.sqrt(D_HEAD)

        mask = mask.astype(jnp.float32)
        mask = (1.0 - mask) * -1e10
        att = att + mask

        att = nn.softmax(att.astype(jnp.float32), axis=-1)
        att = att.astype(jnp.bfloat16)

        out = jnp.einsum("bhts,bshd->bthd", att, v)
        out = out.reshape(B, T, C)

        return nn.Dense(C, use_bias=False, dtype=jnp.bfloat16)(out)

class Block(nn.Module):
    @nn.compact
    def __call__(self, x, mask):
        h = RMSNorm(D_MODEL)(x)
        h = Attention()(h, mask)
        x = x + h

        h = RMSNorm(D_MODEL)(x)
        h = nn.Dense(D_FF, dtype=jnp.bfloat16)(h)
        h = nn.gelu(h)
        h = nn.Dense(D_MODEL, dtype=jnp.bfloat16)(h)

        return x + h

class GPT(nn.Module):
    @nn.compact
    def __call__(self, input_ids):
        batch, seq_len = input_ids.shape
        mask = nn.attention.make_causal_mask(
            jnp.ones((batch, seq_len), dtype=jnp.bool_)
        )

        x = nn.Embed(
            VOCAB_SIZE,
            D_MODEL,
            embedding_init=nn.initializers.normal(0.02),
            dtype=jnp.bfloat16,
        )(input_ids)

        RematBlock = nn.remat(Block)

        for _ in range(N_LAYERS):
            x = RematBlock()(x, mask)

        x = RMSNorm(D_MODEL)(x)

        return nn.Dense(
            VOCAB_SIZE,
            use_bias=False,
            dtype=jnp.bfloat16
        )(x)
# ---------------- LOAD CHECKPOINT ----------------
def create_state():
    model = GPT()
    rng = jax.random.PRNGKey(0)
    params = model.init(rng, jnp.ones((1, SEQ_LEN), dtype=jnp.int32))
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optax.adamw(1e-4),
    )

state = create_state()
state = checkpoints.restore_checkpoint(CKPT_PATH, state)
params = state.params
model = GPT()

print("Checkpoint loaded.")

@jax.jit
def forward(params, input_ids):
    return model.apply(params, input_ids)

import jax.random as random

def generate(params, input_ids, max_new_tokens=30, temperature=0.9, top_k=40):
    rng = random.PRNGKey(0)

    for _ in range(max_new_tokens):

        logits = model.apply(params, input_ids)
        logits = logits[:, -1, :]
        logits = logits.astype(jnp.float32)

        logits = logits / temperature

        top_k_logits, top_k_indices = jax.lax.top_k(logits, top_k)
        probs = jax.nn.softmax(top_k_logits, axis=-1)

        rng, subkey = random.split(rng)
        next_token_idx = random.categorical(subkey, jnp.log(probs))

        next_token = jnp.take_along_axis(
            top_k_indices,
            next_token_idx[:, None],
            axis=-1
        )

        input_ids = jnp.concatenate([input_ids, next_token], axis=1)

    return input_ids
# ---------------- RUN ----------------
tokenizer = AutoTokenizer.from_pretrained("autonomousX/Instinct-1-0.5B")

prompt = "I am John,"
tokens = tokenizer(prompt, return_tensors="np")
input_ids = jnp.array(tokens["input_ids"], dtype=jnp.int32)

output_ids = generate(params, input_ids, 200)

print("\n=== GENERATED TEXT ===\n")
print(tokenizer.decode(output_ids[0].tolist()))
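For readers less familiar with JAX, the sampling step inside generate() (temperature scaling, top-k filtering, softmax, categorical draw) reduces to the plain-Python sketch below; top_k_sample is an illustrative helper, not part of the model code:

```python
import math
import random

def top_k_sample(logits, k=40, temperature=0.9, seed=0):
    """Sample a token id from the k highest-scoring logits."""
    scaled = [x / temperature for x in logits]  # temperature scaling
    top = sorted(range(len(scaled)), key=scaled.__getitem__, reverse=True)[:k]
    m = max(scaled[i] for i in top)             # subtract max for stable softmax
    weights = [math.exp(scaled[i] - m) for i in top]
    total = sum(weights)
    probs = [w / total for w in weights]        # renormalize over the top-k only
    return random.Random(seed).choices(top, weights=probs, k=1)[0]

# With k=2, only the two largest logits (indices 1 and 3) can ever be chosen.
print(top_k_sample([0.1, 5.0, 0.2, 4.9], k=2))
```

Higher temperature flattens the top-k distribution (more diverse output); smaller k or lower temperature makes sampling closer to greedy decoding.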

Sample Output:

(image: generated text for the prompt above)

Author

Rohit Yadav

B.Tech 3rd Year
Dr. B.R. Ambedkar National Institute of Technology (NIT) Jalandhar, India

📧 E-mail: yrohit1825@gmail.com
🔗 LinkedIn: Rohit Yadav
💻 Github: YADAV1825

🚀 I am actively seeking internships and collaborations!

Research Interests

Bioinformatics · Large Language Models · Multimodal Pipelines · Systems Programming · AI Infrastructure · Distributed Training

Organization

AutonomousX

AutonomousX focuses on open-source contributions aimed at building Large Language Models from scratch using custom training pipelines.

Our work explores different training configurations including optimizers, datasets, and scalable TPU training using JAX and pmap. The goal is to provide transparent and reproducible implementations so that researchers, students, and developers can understand how modern LLMs are trained end-to-end.

Due to the current scarcity of complete beginner-friendly guides for training LLMs on TPUs, especially using JAX, AutonomousX aims to bridge this gap by publishing full training pipelines, scripts, and documentation for the open-source community.
