SimpleLM SFT + RAG

A custom decoder-only Transformer, further fine-tuned on MegaScience in a retrieval-augmented format: each training example is a (passage, question, short extractive answer) triple. The architecture is defined in modeling_simple_lm.py (bundled in this repo) and loaded via trust_remote_code=True.

  • RAG SFT source checkpoint: models/sft_full_science_rag_tune.pt
  • Started from: /home/etan/simple_llm/models/sft_full_science.pt (itself an SFT checkpoint, fine-tuned on MegaScience from a pretrained base)
  • Training data: /home/etan/simple_llm/datasets/MegaScience/data
  • subject_filter: None
  • subject_exclude: ['math']
  • RAG SFT epochs: 2 at learning_rate 1.5e-05
  • Passage budget: 60-250 tokens
  • Summative answer budget: 15-180 tokens (the first sentence plus the top N-1 sentences by overlap relevance, N=3)
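The answer construction above can be sketched as follows. This is a hypothetical reconstruction, not the training code. The helper name summative_answer, the regex sentence splitter, the source_text argument, and scoring overlap as shared lowercase words with the question are all assumptions: the card does not say which text the sentences come from or what the overlap is measured against. The 15-180 token budget needs the real tokenizer, so it is left out.

```python
import re


def summative_answer(source_text: str, question: str, n: int = 3) -> str:
    """Hypothetical sketch: first sentence + the top (n-1) other sentences by word overlap."""
    # Naive split after terminal punctuation (assumption; the real splitter is unknown).
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", source_text.strip()) if s.strip()]
    if not sentences:
        return ""
    q_words = set(re.findall(r"\w+", question.lower()))

    def overlap(sentence: str) -> int:
        # Count lowercase word types shared with the question.
        return len(q_words & set(re.findall(r"\w+", sentence.lower())))

    # Rank the non-first sentences by overlap; sorted() is stable, so ties keep document order.
    ranked = sorted(range(1, len(sentences)), key=lambda i: -overlap(sentences[i]))
    chosen = [0] + sorted(ranked[: n - 1])  # always keep sentence 0, then restore document order
    return " ".join(sentences[i] for i in chosen)


text = (
    "Photosynthesis converts light into chemical energy. It occurs in chloroplasts. "
    "Leaves are green. Chlorophyll absorbs light for photosynthesis."
)
print(summative_answer(text, "How does chlorophyll absorb light?"))
```

With N=3, the first sentence is always kept. The two remaining slots go to the highest-overlap sentences, and ties resolve by position.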

Prompt format

The model was fine-tuned on a fixed prompt scaffold with three named slots (context, question, answer):

Context: Photosynthesis is the process by which plants convert light energy into chemical energy stored in glucose.

Question: What is photosynthesis?
Answer: <answer></s>

The bundled chat_template.jinja maps chat roles onto the three slots, so tokenizer.apply_chat_template(...) reproduces the training string above byte-for-byte:

chat role    training slot
system       {{context}} (the retrieved passage)
user         {{question}} (the user's query)
assistant    the answer text (loss-bearing, ends with EOS)
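To make the role-to-slot mapping concrete, here is a hypothetical pure-Python mirror of it. It is not the bundled chat_template.jinja. The name build_prompt is an assumption, and so is treating an empty system message like a missing one, which follows the card's "Empty/no context" note. The two templates and the "</s>" suffix are taken from prompt_template, rag.prompt_template and completion_suffix in the training settings.

```python
RAG_TEMPLATE = "Context: {context}\n\nQuestion: {question}\nAnswer: "
PLAIN_TEMPLATE = "Question: {question}\nAnswer: "
COMPLETION_SUFFIX = "</s>"


def build_prompt(messages: list) -> str:
    """Hypothetical mirror of the role -> slot mapping (not the bundled Jinja template)."""
    context = next((m["content"] for m in messages if m["role"] == "system"), "")
    question = next(m["content"] for m in messages if m["role"] == "user")
    answer = next((m["content"] for m in messages if m["role"] == "assistant"), None)

    # Empty or missing system message falls back to the earlier non-RAG SFT format (assumption).
    if context:
        prompt = RAG_TEMPLATE.format(context=context, question=question)
    else:
        prompt = PLAIN_TEMPLATE.format(question=question)

    # An assistant turn, when present, is the loss-bearing answer followed by EOS.
    if answer is not None:
        prompt += answer + COMPLETION_SUFFIX
    return prompt


msgs = [
    {"role": "system", "content": "Photosynthesis converts light energy into chemical energy."},
    {"role": "user", "content": "What is photosynthesis?"},
]
print(repr(build_prompt(msgs)))
```

To check the mapping, compare its output with tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True). Depending on the template, special tokens such as a leading BOS may differ.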

Distractor / refusal contract

During RAG fine-tuning, 20% of training triples (distractor_fraction: 0.2) received a deliberately unrelated passage paired with a real question, with the target answer fixed to:

"The passage does not provide enough information to answer this question."

This was the strongest anti-hallucination signal in the run: the model learned to read the supplied passage and to refuse, rather than answer from its parametric knowledge, when the context is off-topic. As a result:

  • Retrieve aggressively, not selectively. It's better to give the model a vaguely related passage and let it refuse than to drop the passage and let it guess.
  • Empty/no context degrades to plain QA mode. With no system message, the chat template emits "Question: {{q}}\nAnswer: " (the format of the earlier, non-RAG SFT stage). The prior science SFT behavior is therefore still reachable, but without the distractor-refusal contract.
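The distractor mix described above might be produced along these lines. This is a minimal sketch, not the author's pipeline. The function name make_rag_examples and the strategy of borrowing another example's passage are assumptions. A randomly borrowed passage is not guaranteed to be unrelated, so this sketch does not fully enforce the "deliberately unrelated" property.

```python
import random

DISTRACTOR_ANSWER = "The passage does not provide enough information to answer this question."


def make_rag_examples(triples, distractor_fraction=0.2, seed=42):
    """Hypothetical sketch: swap in another example's passage for ~distractor_fraction of triples."""
    rng = random.Random(seed)
    out = []
    for i, (passage, question, answer) in enumerate(triples):
        if len(triples) > 1 and rng.random() < distractor_fraction:
            # Pick a different index j != i uniformly from the other len-1 examples.
            j = rng.randrange(len(triples) - 1)
            j = j + 1 if j >= i else j
            # Keep the real question; fix the target to the refusal string.
            out.append((triples[j][0], question, DISTRACTOR_ANSWER))
        else:
            out.append((passage, question, answer))
    return out


data = [(f"p{i}", f"q{i}", f"a{i}") for i in range(1000)]
examples = make_rag_examples(data)
print(sum(a == DISTRACTOR_ANSWER for _, _, a in examples), "distractors out of", len(examples))
```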

Usage

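import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "etanlightstone/simple-lm-rag-science"
tok   = AutoTokenizer.from_pretrained(repo)
# trust_remote_code=True loads the architecture from the bundled modeling_simple_lm.py.
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True).eval()

CONTEXT_LENGTH = 512  # model context window (context_length in the Architecture table)

passage = (
    "Photosynthesis is the process by which plants convert light energy "
    "into chemical energy stored in glucose, releasing oxygen as a "
    "byproduct."
)
# system -> "Context:" slot, user -> "Question:" slot
messages = [
    {"role": "system", "content": passage},
    {"role": "user",   "content": "What is photosynthesis?"},
]
inputs = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,  # prompt ends with "Answer: " so the model writes the answer
    return_tensors="pt",
    return_dict=True,            # returns input_ids and attention_mask
)
prompt_len = inputs["input_ids"].shape[1]
# Keep prompt + generated tokens within the 512-token context window;
# passages can be up to ~250 tokens, so 256 new tokens could otherwise overflow.
max_new_tokens = min(256, CONTEXT_LENGTH - prompt_len)
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.3,
        top_p=0.9,
        repetition_penalty=1.1,
    )
# Decode only the newly generated tokens, dropping the prompt.
answer = tok.decode(out[0, prompt_len:], skip_special_tokens=True)
print(answer)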

To test the distractor/refusal behavior, replace the system content with something unrelated. A well-trained checkpoint should respond with a near-verbatim copy of the refusal string above.
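Because the expected output is only near-verbatim, an exact string comparison can miss refusals. One hedged option is a fuzzy match with Python's difflib. The helper name is_refusal and the 0.9 similarity threshold are illustrative assumptions; tune them on your own outputs.

```python
import difflib

REFUSAL = "The passage does not provide enough information to answer this question."


def is_refusal(text: str, threshold: float = 0.9) -> bool:
    """Heuristic: True if text is a near-verbatim match to the trained refusal string."""

    def norm(s: str) -> str:
        # Lowercase and collapse whitespace before comparing.
        return " ".join(s.lower().split())

    ratio = difflib.SequenceMatcher(None, norm(text), norm(REFUSAL)).ratio()
    return ratio >= threshold


print(is_refusal("The passage does not provide enough information to answer this question"))
```

Apply it to the decoded answer from the usage snippet to flag refusals automatically.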

Architecture

field                 value
vocab_size            32000
context_length        512
d_model               768
n_layers              12
n_heads               8
d_ff                  2048
activation            gelu
bias                  True
tie_word_embeddings   True

Tokenizer source: TinyLlama/TinyLlama-1.1B-Chat-v1.0
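As a sanity check on the table, the parameter count can be worked out by hand. The layout below is an assumption; modeling_simple_lm.py is the authority. It assumes learned absolute position embeddings of size context_length x d_model, two LayerNorms per block plus a final one, biased linear layers, and a bias-free tied output head. Under that layout it exactly reproduces the total_params (91,138,560), the two-block frozen_params (11,027,968) and the trainable_params (80,110,592) reported in the training settings. n_heads does not affect the count.

```python
def simple_lm_param_count(vocab=32000, ctx=512, d=768, layers=12, d_ff=2048):
    """Parameter count under an assumed GPT-2-style layout; returns (per_block, total)."""
    attn = (d * 3 * d + 3 * d) + (d * d + d)   # QKV projection (fused or separate) + output proj, with biases
    mlp = (d * d_ff + d_ff) + (d_ff * d + d)    # two biased linear layers
    norms = 2 * (2 * d)                         # two LayerNorms (weight + bias) per block
    block = attn + mlp + norms
    embeddings = vocab * d + ctx * d            # token + learned position embeddings (assumed)
    final_norm = 2 * d
    # Tied LM head (tie_word_embeddings=True, assumed bias-free) adds no parameters.
    return block, layers * block + embeddings + final_norm


block, total = simple_lm_param_count()
print(f"per block: {block:,}  total: {total:,}  trainable with 2 frozen blocks: {total - 2 * block:,}")
```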

Training settings

{
  "mode": "sft",
  "source_pretrain_checkpoint": "/home/etan/simple_llm/models/sft_full_science.pt",
  "source_pretrain_train_settings": {
    "mode": "sft",
    "source_pretrain_checkpoint": "/home/etan/simple_llm/checkpoints/lm_checkpoint_008_shutdown.pt",
    "source_pretrain_train_settings": {
      "batch_size": 10,
      "batch_size_note": "per GPU when using torchrun",
      "world_size": 1,
      "learning_rate": 0.0003,
      "weight_decay": 0.01,
      "num_epochs": 3,
      "max_steps": null,
      "grad_clip": 1.0,
      "seed": 42,
      "docs_dir": "/home/etan/simple_llm/docs",
      "block_size": 512,
      "stride": 448,
      "stride_overlap_tokens": 64
    },
    "data_dir": "/home/etan/simple_llm/datasets/MegaScience/data",
    "data_glob": "*.parquet",
    "subject_filter": null,
    "subject_exclude": [
      "math"
    ],
    "question_regex_filter": null,
    "batch_size": 10,
    "world_size": 1,
    "learning_rate": 3e-05,
    "min_lr": 3e-06,
    "warmup_steps": 200,
    "weight_decay": 0.0,
    "num_epochs": 1,
    "max_steps": null,
    "grad_clip": 1.0,
    "seed": 42,
    "block_size": 512,
    "eval_fraction": 0.005,
    "eval_every": 500,
    "max_train_examples": null,
    "freezing": {
      "freeze_embeddings": false,
      "freeze_lm_head": false,
      "freeze_blocks_below": 0,
      "tie_word_embeddings": true,
      "trainable_params": 91138560,
      "total_params": 91138560,
      "frozen_params": 0,
      "frozen_blocks": 0,
      "total_blocks": 12
    },
    "prompt_template": "Question: {question}\nAnswer: ",
    "completion_suffix": "</s>"
  },
  "data_dir": "/home/etan/simple_llm/datasets/MegaScience/data",
  "data_glob": "*.parquet",
  "subject_filter": null,
  "subject_exclude": [
    "math"
  ],
  "question_regex_filter": null,
  "batch_size": 10,
  "world_size": 1,
  "learning_rate": 1.5e-05,
  "min_lr": 3e-06,
  "warmup_steps": 200,
  "weight_decay": 0.01,
  "num_epochs": 2,
  "max_steps": null,
  "grad_clip": 1.0,
  "seed": 42,
  "block_size": 512,
  "eval_fraction": 0.02,
  "eval_every": 500,
  "max_train_examples": 50000,
  "freezing": {
    "freeze_embeddings": false,
    "freeze_lm_head": false,
    "freeze_blocks_below": 2,
    "tie_word_embeddings": true,
    "trainable_params": 80110592,
    "total_params": 91138560,
    "frozen_params": 11027968,
    "frozen_blocks": 2,
    "total_blocks": 12
  },
  "prompt_template": "Question: {question}\nAnswer: ",
  "completion_suffix": "</s>",
  "rag_mode": true,
  "rag": {
    "prompt_template": "Context: {context}\n\nQuestion: {question}\nAnswer: ",
    "passage_min_tokens": 60,
    "passage_max_tokens": 250,
    "answer_min_tokens": 15,
    "answer_max_tokens": 180,
    "answer_num_sentences": 3,
    "distractor_fraction": 0.2,
    "distractor_answer": "The passage does not provide enough information to answer this question."
  }
}
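The settings JSON nests each earlier stage under source_pretrain_train_settings. A small helper can flatten that chain into stages, oldest first. The function name training_lineage is hypothetical, and the dict literal below is a trimmed copy of the JSON above; in practice you would json.loads the full text. Labeling the innermost stage "pretrain" is an assumption based on it having no "mode" key.

```python
def training_lineage(settings: dict) -> list:
    """Follow nested source_pretrain_train_settings and return stages oldest-first."""
    stages = []
    node = settings
    while node is not None:
        stages.append(node)
        node = node.get("source_pretrain_train_settings")
    return stages[::-1]


# Trimmed copy of the training settings above (only the keys used here).
settings = {
    "mode": "sft", "learning_rate": 1.5e-05, "num_epochs": 2, "rag_mode": True,
    "source_pretrain_train_settings": {
        "mode": "sft", "learning_rate": 3e-05, "num_epochs": 1,
        "source_pretrain_train_settings": {"learning_rate": 0.0003, "num_epochs": 3},
    },
}

for i, stage in enumerate(training_lineage(settings)):
    label = "rag-sft" if stage.get("rag_mode") else stage.get("mode", "pretrain")
    print(i, label, stage["learning_rate"], stage["num_epochs"])
# prints:
# 0 pretrain 0.0003 3
# 1 sft 3e-05 1
# 2 rag-sft 1.5e-05 2
```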
Weights: Safetensors, 91.1M params, tensor type F32