AstroBERT Small: Domain-specialized small models

Community Article
Published July 1, 2026

This article introduces the AstroBERT Small series of models. The base model is a solid-performing small model (22.7M parameters) trained on ArXiv abstracts categorized as astro-ph and Wikipedia articles labeled as astronomy-related.

Domain-specialized small models like these often perform as well as models 10-100x larger. This demonstrates that a model narrowed to a small domain needs fewer parameters overall than a model generalized for all problems.

The following new models are released as part of this effort. All models have an Apache 2.0 license.

| Model | Description |
| --- | --- |
| AstroBERT Small | Base 22.7M parameter language model |
| AstroBERT Small Embeddings | Small Sentence Transformers model for embeddings |

Building a Strong Baseline

The first step was to build a strong base model. A 22.7M parameter BERT encoder-only model was trained on ArXiv abstracts categorized as astro-ph and Wikipedia articles labeled as astronomy-related.

The model was trained using masked language modeling with the following code.

import csv

from datasets import concatenate_datasets, load_dataset, load_from_disk
from transformers import AutoTokenizer
from transformers import BertConfig, BertForMaskedLM

from txtai.pipeline import HFTrainer

def loadids(domain):
    """Read the label file and return (ids labeled with `domain`, total row count)."""
    labels, count = set(), 0
    with open("/data/sources/wikipedia/labels/labels.csv", mode="r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row["label"] == domain:
                labels.add(row["id"])
            count += 1

    return labels, count

# Load target domain
domain = "astronomy"
uids, total = loadids(domain)

# Filter by domain labels
dataset = load_dataset("neuml/wikipedia-20260401", split="train")
dataset = dataset.filter(lambda x: x["title"] in uids)
dataset = dataset.remove_columns([col for col in dataset.column_names if col != "text"])

# Add arxiv abstracts
dataset = concatenate_datasets([dataset, load_from_disk(f"datasets/arxiv-{domain}")])

# Calculate number of epochs based on size of filtered dataset
epochs = 3 * int(total / len(dataset))
print(f"Calculated {epochs} epochs")

# Domain-specific tokenizer loaded from a local path
# (swap in AutoTokenizer.from_pretrained("bert-base-uncased") to use the standard BERT tokenizer)
tokenizer = AutoTokenizer.from_pretrained(f"tokenizers/{domain}")

# Configuration - bert small
config = BertConfig(
       hidden_size=384,
       num_hidden_layers=6,
       num_attention_heads=6,
       intermediate_size=1536
)

# Model to train
model = BertForMaskedLM(config)

print(config)
print("Total parameters:", sum(p.numel() for p in model.bert.parameters()))

train = HFTrainer()

#
# Train using MLM
#
# Settings copied from original BERT training - override when HF Trainer defaults don't match
#   - BERT Paper (pg. 13): https://arxiv.org/pdf/1810.04805
#   - BERT Tiny Paper: https://arxiv.org/pdf/1908.08962
#   - BERT Parameters: https://github.com/google-research/bert/blob/master/optimization.py#L59
#   - HF Trainer defaults: https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments
train((model, tokenizer), dataset, task="language-modeling", output_dir="output",
       fp16=True, learning_rate=1e-3, per_device_train_batch_size=64, num_train_epochs=epochs,
       warmup_steps=2500, weight_decay=0.01, adam_epsilon=1e-6,
       tokenizers=True, dataloader_num_workers=20,
       save_strategy="steps", save_steps=5000, logging_steps=500,
)
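For readers less familiar with the masked language modeling objective used above, here is a minimal sketch of BERT-style token masking as described in the BERT paper: 15% of tokens are selected, and of those 80% become [MASK], 10% become a random token and 10% are left unchanged. The `mask_tokens` function and its arguments are illustrative names, not part of txtai or transformers, and the special token ids are those of bert-base-uncased. The trainer handles this step internally, so this is only meant to show the idea.

```python
import random

IGNORE_INDEX = -100  # label value conventionally ignored by the cross-entropy loss


def mask_tokens(token_ids, mask_id, vocab_size, special_ids, mlm_probability=0.15, seed=None):
    """Return (inputs, labels) for one sequence using BERT-style masking."""
    rng = random.Random(seed)
    inputs, labels = list(token_ids), [IGNORE_INDEX] * len(token_ids)

    for i, token in enumerate(token_ids):
        # Never mask special tokens such as [CLS] or [SEP]
        if token in special_ids or rng.random() >= mlm_probability:
            continue

        # The model is trained to predict the original token at this position
        labels[i] = token

        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id                    # 80%: replace with [MASK]
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # Remaining 10%: keep the original token unchanged

    return inputs, labels


# Toy sequence: ids 101/102/103 are [CLS]/[SEP]/[MASK] in bert-base-uncased
tokens = [101] + list(range(2000, 2050)) + [102]
inputs, labels = mask_tokens(tokens, mask_id=103, vocab_size=30522, special_ids={101, 102}, seed=0)
print(sum(label != IGNORE_INDEX for label in labels), "of", len(tokens), "positions selected")
```

Only the selected positions contribute to the loss, which is why every other label is set to the ignore value.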

The model is intended to be further fine-tuned for a specific task such as Text Classification, Entity Extraction, Sentence Embeddings and so on.
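As a sanity check on the 22.7M figure, the parameter count can be worked out by hand from the configuration in the training script (hidden size 384, 6 layers, intermediate size 1536). The sketch below assumes the `BertConfig` defaults for everything not set there: a 30,522 token vocabulary, 512 positions and 2 segment types. The `bert_parameters` helper is a hypothetical name used for illustration only.

```python
def bert_parameters(hidden_size, num_hidden_layers, intermediate_size,
                    vocab_size=30522, max_position_embeddings=512, type_vocab_size=2,
                    pooler=True):
    """Count BERT encoder parameters from its configuration."""
    layer_norm = 2 * hidden_size  # scale + bias

    # Token, position and segment embeddings followed by a LayerNorm
    embeddings = (vocab_size + max_position_embeddings + type_vocab_size) * hidden_size + layer_norm

    # Self-attention: query, key, value and output projections (weights + biases) and a LayerNorm
    attention = 4 * (hidden_size * hidden_size + hidden_size) + layer_norm

    # Feed-forward: up projection, down projection and a LayerNorm
    feedforward = (
        (hidden_size * intermediate_size + intermediate_size)
        + (intermediate_size * hidden_size + hidden_size)
        + layer_norm
    )

    total = embeddings + num_hidden_layers * (attention + feedforward)

    # Optional pooler: a single dense layer over the [CLS] token
    if pooler:
        total += hidden_size * hidden_size + hidden_size

    return total


with_pooler = bert_parameters(384, 6, 1536)
without_pooler = bert_parameters(384, 6, 1536, pooler=False)
print(f"{with_pooler:,} with pooler, {without_pooler:,} without")
# 22,713,216 with pooler, 22,565,376 without
```

Under these assumptions, the token embedding matrix alone accounts for about 11.7M parameters, roughly half the model. The 22.7M figure matches the count with the pooler included; the `model.bert` count printed by the training script may come out slightly lower if the masked language modeling head omits the pooler, as is my understanding of `BertForMaskedLM`.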


Training a Small Embeddings model

Next, a sentence-transformers model was fine-tuned to generate vector embeddings. The training dataset was generated using a random sample of ArXiv abstracts labeled as astro-ph.

The model was trained by distilling embeddings from the larger Qwen3-Embedding-8B model using EmbedDistillLoss over the generated training dataset.

As noted in the paper "Well-Read Students Learn Better: On the Importance of Pre-training Compact Models", it's important that the base model is pretrained on a large corpus of relevant documents prior to distillation.
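To make the distillation idea concrete, here is a rough NumPy sketch of an embedding distillation objective: student embeddings are projected into the teacher's embedding space and the loss is the cosine distance to the teacher's vectors. This is an illustration of the general technique under my own assumptions, not the actual `EmbedDistillLoss` implementation. The `cosine_distill_loss` function is a hypothetical name, the projection here is a fixed random matrix rather than a learned layer, and toy dimensions stand in for the 384-dimensional student and the `projection_dim=4096` used in the training code.

```python
import numpy as np


def cosine_distill_loss(student, teacher, projection):
    """Mean cosine distance between projected student embeddings and teacher embeddings."""
    # Project student vectors into the teacher's embedding space
    projected = student @ projection

    # L2-normalize both sets of vectors
    projected = projected / np.linalg.norm(projected, axis=1, keepdims=True)
    teacher = teacher / np.linalg.norm(teacher, axis=1, keepdims=True)

    # Cosine distance = 1 - cosine similarity, averaged over the batch
    return float(np.mean(1.0 - np.sum(projected * teacher, axis=1)))


rng = np.random.default_rng(0)

# Toy dimensions standing in for the real student and teacher sizes
student_dim, teacher_dim, batch = 8, 32, 4

student = rng.normal(size=(batch, student_dim))
projection = rng.normal(size=(student_dim, teacher_dim))

# A teacher the student matches exactly after projection gives a loss of (numerically) zero
print(cosine_distill_loss(student, student @ projection, projection))

# An unrelated random teacher gives a much larger loss
print(cosine_distill_loss(student, rng.normal(size=(batch, teacher_dim)), projection))
```

During training, the student weights (and, in a real setup, the projection) would be updated to push this loss toward zero, so the small model learns to place texts where the large model does.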

The training code is shown below.

import json
import logging

import numpy as np

from datasets import Dataset, load_from_disk, concatenate_datasets
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    models,
    losses
)

logging.basicConfig(level=logging.INFO)

def load(path):
    rows = []
    with open(path, "r", encoding="utf-8") as inputs:
        for line in inputs:
            row = json.loads(line)
            rows.append({
                "query": row[0],
                "document": row[1],
            })

    return rows

def run(domain):
    # Training prompts
    prompts = {
        "query": "query: ",
        "document": "document: "
    }

    # Embeddings model
    embeddings = models.Transformer("astrobert-small")

    # Pooling model
    pooling = models.Pooling(embeddings.get_embedding_dimension())

    # Create sentence-transformers model
    model = SentenceTransformer(modules=[embeddings, pooling], prompts=prompts)

    # Teacher model
    teacher = SentenceTransformer("Qwen/Qwen3-Embedding-8B")

    # Load training data
    train = load(f"training/{domain}-similarity-train.jsonl") + load(f"training/{domain}-similarity-train-questions.jsonl")
    train = Dataset.from_list(train)

    def compute(batch):
        embed1 = teacher.encode(batch["query"], prompt_name="query", show_progress_bar=False, batch_size=8)
        embed2 = teacher.encode(batch["document"], prompt_name="document", show_progress_bar=False, batch_size=8)
        return {"label": np.stack([embed1, embed2], axis=1).tolist()}

    # Build separate shards
    shards = 10
    for x in range(shards):
        shard = train.shard(num_shards=shards, index=x)
        shard = shard.map(
            compute,
            batched=True,
            batch_size=1_000,
            writer_batch_size=1_000,
            new_fingerprint=f"embeddings-{x}"
        )
        shard.save_to_disk(f"training/{domain}-embeddings/shard-{x}")

    train = concatenate_datasets([
        load_from_disk(f"training/{domain}-embeddings/shard-{x}")
        for x in range(shards)
    ])

    path = f"{domain}bert-embeddings"
    args = SentenceTransformerTrainingArguments(
        output_dir=path,
        num_train_epochs=25,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        gradient_accumulation_steps=2,
        fp16=True,
        learning_rate=3e-4,
        save_steps=0,
        logging_steps=500,
        dataloader_num_workers=20,
        prompts=prompts
    )

    # Create the trainer & start training
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train,
        loss=losses.EmbedDistillLoss(model, distance_metric="cosine", projection_dim=4096),
    )
    trainer.train()
    model.save(path)

# Train model
run("astronomy")

Evaluation Results

A BEIR-compatible dataset was generated to facilitate the evaluation process. This is a separate random sample of ArXiv abstracts alongside generated user queries.

Evaluation results are shown below. NDCG is used as the evaluation metric.

| Model | Parameters | NDCG | Index Time | Search Time | Disk |
| --- | --- | --- | --- | --- | --- |
| AstroBERT Small Embeddings | 22.7M | 69.09 | 9.9s | 0.42s | 16 MB |
| all-MiniLM-L6-v2 | 22.7M | 40.45 | 12.50s | 0.38s | 16 MB |
| DenseOn | 149M | 61.46 | 67.35s | 0.77s | 31 MB |
| EmbeddingGemma | 300M | 57.44 | 86.17s | 1.43s | 31 MB |
| Qwen3-Embedding-0.6B | 600M | 65.73 | 114.17s | 2.20s | 41 MB |
| Qwen3-Embedding-4B | 4000M | 71.14 | 545.28s | 9.89s | 103 MB |
| Qwen3-Embedding-8B | 8000M | 73.84 | 941.82s | 17.24s | 164 MB |

This model is a solid performer at a small size. It beats the same-sized all-MiniLM-L6-v2 model by a wide margin (69.09 vs. 40.45 NDCG). It also beats the 600M parameter Qwen3-Embedding-0.6B model (65.73), which is over 25x larger. It scores about 5 points lower than the model it's distilled from (Qwen3-Embedding-8B, 73.84) while indexing roughly 95x faster (9.9s vs. 941.82s).

The model is a good fit for CPU-only setups without trading off much accuracy. It shows how small models can excel in specialized domains while requiring less compute and disk space.
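For readers unfamiliar with the metric in the table above, NDCG (normalized discounted cumulative gain) rewards rankings that place relevant documents near the top. The sketch below is a minimal single-query version, assuming linear gain and a cutoff of `k=10`; the article does not state the cutoff used, and the scores in the table appear to be scaled by 100. The `ndcg` function and the document ids are illustrative names.

```python
import math


def ndcg(ranked_ids, relevant, k=10):
    """NDCG@k for one query. `relevant` maps document id -> graded relevance."""
    # Discounted cumulative gain of the returned ranking
    dcg = sum(
        relevant.get(doc_id, 0) / math.log2(rank + 2)
        for rank, doc_id in enumerate(ranked_ids[:k])
    )

    # Ideal DCG: the best possible ordering of the relevant documents
    ideal = sorted(relevant.values(), reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))

    return dcg / idcg if idcg > 0 else 0.0


qrels = {"doc-a": 1}
print(ndcg(["doc-a", "doc-b", "doc-c"], qrels))            # 1.0
print(round(ndcg(["doc-b", "doc-a", "doc-c"], qrels), 4))  # 0.6309
```

A benchmark score is then the mean of this value over all queries, so a single relevant document slipping from rank 1 to rank 2 already costs that query more than a third of its score.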


Wrapping up

This article introduced the AstroBERT Small series of models. It demonstrates that a model narrowed to a small domain needs fewer parameters overall than a model generalized for all problems.

If you're interested in building custom models like this for your data or domain area, feel free to reach out!

NeuML is the company behind txtai and we provide AI consulting services around our stack. Schedule a meeting or send a message to learn more.

We're also building an easy and secure way to run hosted txtai applications with txtai.cloud.
