autonomousX
/

gpt
Model

Instinct-1-1.0B

Instinct-1-1.0B is a fully reproducible, from-scratch trained 1 Billion parameter language model built under the AutonomousX organization. It is trained on 20B tokens of the PILE dataset utilizing powerful TPU v4 infrastructure.

v1.0 Stable Base Release (No SFT or RLHF)

Model Specifications

1B Parameters TPU v4-8 20B Tokens JAX / Flax

Instinct-1-1B

Instinct-1-1B is a fully reproducible, from-scratch trained 1B parameter language model trained on 20B tokens of PILE using TPU v4 infrastructure.

Instinct-1-1B is a 1 Billion parameter Large Language Model built from scratch under the AutonomousX organization.

Compute for this project was supported by Google's TRC Program (TPU Research Cloud).

This model was developed by Rohit Yadav, a B.Tech 3rd year student from NIT Jalandhar, India E-mail: yrohit1825@gmail.com.


Model Overview

Attribute Value
Model Name Instinct-1-1B
Organization AutonomousX
Parameters ~1 Billion
Vocabulary Size 50,304
Training Dataset Pythia / The PILE
Tokens Seen 20 Billion
Training Hardware TPU v4-8

Validation was performed using rolling validation shards of the dataset.

image


Architecture Details

Unlike previous versions, this model utilizes standard Self-Attention without Rotary Position Embeddings (RoPE).

Hyperparameter Value
Layers 24
Model Dimension 1840
Attention Heads 16
Feed Forward Dimension 4968
Sequence Length 1024

Training Details

Instinct-1-1B was trained completely from scratch using JAX/Flax on TPU v4-8 hardware.

Training pipeline includes:

  • Dataset streaming from The PILE / Pythia Data
  • Custom tokenizer with 50,304 vocabulary size
  • TPU optimized JAX / Flax training loop with pmap
  • Checkpointing and validation during training
  • Rolling validation shard evaluation

The model was trained on 20B tokens and it's is a checkpoint of final version trained on 85B tokens in total


Reproducibility

The entire pipeline used to train the model is fully reproducible.

This includes:

  • Dataset pipeline
  • Tokenizer creation
  • Model architecture
  • TPU training loop
  • Checkpointing system

You can reproduce the complete training pipeline from scratch.

🚀Full training pipeline repository :- Github training pipeline

📊TPU Setup Guide :- Youtube


Run Inference (Model is available for Inference on Both GPUs and TPUs)

A ready-to-run Google Colab TPU/GPU inference script is provided below.

Simply open a notebook and run it with a TPU or GPU runtime. Please be patient, it may take some time to download and compile.


# Install huggingface_hub if not installed
!pip install -q huggingface_hub

from huggingface_hub import snapshot_download

repo_id = "autonomousX/Instinct-1-1B"

# Download entire repository
local_path = snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    local_dir="TPU_1b",
    local_dir_use_symlinks=False
)

print("Download complete!")
print("Saved to:", local_path)

# =========================
# FAST 1B INFERENCE CELL
# =========================

import os
import math
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state, checkpoints
import optax
from transformers import AutoTokenizer
import jax.random as random

# ---------------- CONFIG ----------------
SEQ_LEN = 1024
VOCAB_SIZE = 50304

N_LAYERS = 24
D_MODEL  = 1840   
N_HEADS  = 16
D_FF     = 4968

CKPT_PATH = os.path.abspath("TPU_1b/checkpoint_0") # Ensure your checkpoint directory matches

# ---------------- MODEL ----------------
class RMSNorm(nn.Module):
    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        scale = self.param("scale", nn.initializers.ones, (self.dim,))
        norm = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
        return x * (scale / norm)

class Block(nn.Module):
    @nn.compact
    def __call__(self, x, mask):

        # ---- Attention ----
        h = RMSNorm(D_MODEL)(x)
        h = nn.SelfAttention(
            num_heads=N_HEADS,
            dtype=jnp.bfloat16,
            use_bias=False,
            deterministic=True,
        )(h, mask=mask)
        x = x + h

        # ---- MLP ----
        h = RMSNorm(D_MODEL)(x)
        h = nn.Dense(D_FF, dtype=jnp.bfloat16)(h)
        h = nn.gelu(h)
        h = nn.Dense(D_MODEL, dtype=jnp.bfloat16)(h)

        return x + h

class GPT(nn.Module):
    @nn.compact
    def __call__(self, input_ids):

        batch, seq_len = input_ids.shape
        mask = nn.attention.make_causal_mask(
            jnp.ones((batch, seq_len), dtype=jnp.bool_)
        )

        x = nn.Embed(
            VOCAB_SIZE,
            D_MODEL,
            embedding_init=nn.initializers.normal(0.02),
            dtype=jnp.bfloat16,
        )(input_ids)

        RematBlock = nn.remat(Block)

        for _ in range(N_LAYERS):
            x = RematBlock()(x, mask)

        x = RMSNorm(D_MODEL)(x)

        logits = nn.Dense(
            VOCAB_SIZE,
            use_bias=False,
            dtype=jnp.bfloat16
        )(x)

        return logits

# ---------------- LOAD CHECKPOINT ----------------
def create_state():
    model = GPT()
    rng = jax.random.PRNGKey(0)
    params = model.init(rng, jnp.ones((1, SEQ_LEN), dtype=jnp.int32))
    tx = optax.adamw(1e-4) # Placeholder optimizer for loading
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )

state = create_state()
state = checkpoints.restore_checkpoint(CKPT_PATH, state)
params = state.params
model = GPT()

print("Checkpoint loaded.")

# ---------------- GENERATION ----------------
def generate(params, input_ids, max_new_tokens=30, temperature=0.9, top_k=40):
    rng = random.PRNGKey(0)

    for _ in range(max_new_tokens):
        logits = model.apply(params, input_ids)
        logits = logits[:, -1, :]
        logits = logits.astype(jnp.float32)

        logits = logits / temperature

        top_k_logits, top_k_indices = jax.lax.top_k(logits, top_k)
        probs = jax.nn.softmax(top_k_logits, axis=-1)

        rng, subkey = random.split(rng)
        next_token_idx = random.categorical(subkey, jnp.log(probs))

        next_token = jnp.take_along_axis(
            top_k_indices,
            next_token_idx[:, None],
            axis=-1
        )

        input_ids = jnp.concatenate([input_ids, next_token], axis=1)

    return input_ids

# ---------------- RUN ----------------
tokenizer = AutoTokenizer.from_pretrained("autonomousX/Instinct-1-1B")

prompt = "I am John,"
tokens = tokenizer(prompt, return_tensors="np")
input_ids = jnp.array(tokens["input_ids"], dtype=jnp.int32)

output_ids = generate(params, input_ids, 200)

print("\n=== GENERATED TEXT ===\n")
print(tokenizer.decode(output_ids[0].tolist()))

Sample Output

image


Author

Rohit Yadav

B.Tech 3rd Year
Dr. B.R. Ambedkar National Institute of Technology (NIT) Jalandhar, India

📧 E-mail: yrohit1825@gmail.com
🔗 LinkedIn: Rohit Yadav
💻 Github: YADAV1825

🚀 I am actively seeking Internships and Collaborations!

Research Interests

Bio_Informatics Large Language Models MultiModal Pipelines Systems Programming AI Infrastructure Distributed Training

Organization

AutonomousX

AutonomousX focuses on open-source contributions aimed at building Large Language Models from scratch using custom training pipelines.

Our work explores different training configurations including optimizers, datasets, and scalable TPU training using JAX and pmap. The goal is to provide transparent and reproducible implementations so that researchers, students, and developers can understand how modern LLMs are trained end-to-end.

Due to the current scarcity of complete beginner-friendly guides for training LLMs on TPUs, especially using JAX, AutonomousX aims to bridge this gap by publishing full training pipelines, scripts, and documentation for the open-source community.

Downloads last month
1
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including autonomousX/Instinct-1-1B