# SimpleLM SFT + RAG
Custom decoder-only Transformer, continued-fine-tuned on MegaScience in a retrieval-augmented format: each training example is a (passage, question, short-extractive-answer) triple. The architecture is defined in `modeling_simple_lm.py` (bundled in this repo) and loaded via `trust_remote_code=True`.

- RAG SFT source checkpoint: `models/sft_full_science_rag_tune.pt`
- Started from: `/home/etan/simple_llm/models/sft_full_science.pt` (which was itself SFT-tuned)
- Training data: `/home/etan/simple_llm/datasets/MegaScience/data` (`subject_filter: None`, `subject_exclude: ['math']`)
- RAG SFT epochs: 2 at learning rate 1.5e-05
- Passage budget: 60-250 tokens
- Summative answer budget: 15-180 tokens (first sentence + top-(N-1) overlap-relevant sentences, N=3; see the sketch below)
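The "first sentence + top-(N-1) overlap-relevant sentences" selection could be implemented roughly as below. This is a hypothetical sketch, not the training repo's actual code: the function name, the regex sentence splitter, and scoring overlap against the question (rather than the passage) are all assumptions; only N=3 and the first-sentence-plus-overlap idea come from the description above.

```python
import re

def summative_answer(full_answer: str, question: str, n_sentences: int = 3) -> str:
    """Keep the first sentence, then the n_sentences-1 remaining sentences
    with the highest word overlap with the question, in original order."""
    sentences = re.split(r"(?<=[.!?])\s+", full_answer.strip())
    if len(sentences) <= n_sentences:
        return " ".join(sentences)
    q_words = set(question.lower().split())
    rest = sentences[1:]
    # Rank the remaining sentences by question-word overlap.
    ranked = sorted(rest, key=lambda s: len(q_words & set(s.lower().split())), reverse=True)
    keep = set(ranked[: n_sentences - 1])
    # Emit kept sentences in their original order, first sentence always included.
    return " ".join([sentences[0]] + [s for s in rest if s in keep])
```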
## Prompt format
The model was fine-tuned on a fixed scaffold with three named slots:

```
Context: Photosynthesis is the process by which plants convert light energy into chemical energy stored in glucose.

Question: What is photosynthesis?
Answer: <answer></s>
```
The bundled `chat_template.jinja` maps the three slots to chat roles, so `tokenizer.apply_chat_template(...)` produces this string byte-for-byte:

| chat role | training slot |
|---|---|
| system | `{{context}}` (the retrieved passage) |
| user | `{{question}}` (the user's query) |
| assistant | the answer text (loss-bearing, ends with EOS) |
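Because the mapping is deterministic, it can be sanity-checked directly. A minimal sketch (assuming the rendered string matches the `rag.prompt_template` from the training settings below and the template prepends no special tokens):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("etanlightstone/simple-lm-rag-science")

context = "Photosynthesis is the process by which plants convert light energy into chemical energy stored in glucose."
question = "What is photosynthesis?"

rendered = tok.apply_chat_template(
    [{"role": "system", "content": context},
     {"role": "user", "content": question}],
    tokenize=False,
    add_generation_prompt=True,
)
# Expect the training scaffold, byte-for-byte:
assert rendered == f"Context: {context}\n\nQuestion: {question}\nAnswer: "
```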
## Distractor / refusal contract
During RAG fine-tuning, 20% of training triples (`distractor_fraction: 0.2`) received a deliberately unrelated passage paired with a real question, with the target answer fixed to:

"The passage does not provide enough information to answer this question."
This was the strongest anti-hallucination signal in the run -- the model learned to read the supplied passage and refuse, rather than answer from its parametric knowledge, when the context isn't on-topic. As a result:
- Retrieve aggressively, not selectively. It's better to give the model a vaguely related passage and let it refuse than to drop the passage and let it guess.
- Empty/no context degrades to plain QA mode. With no `system` message, the chat template emits `Question: {{q}}\nAnswer: ` -- the model's earlier (non-RAG) SFT format -- so the prior science SFT behavior is still reachable, but without the distractor-refusal contract.
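In training-data terms, the contract might look like the sketch below. This is illustrative, not the repo's actual code: `make_rag_example` and its arguments are hypothetical; only `distractor_fraction` and the refusal string come from the training settings.

```python
import random

DISTRACTOR_FRACTION = 0.2
REFUSAL = "The passage does not provide enough information to answer this question."

def make_rag_example(passage, question, answer, all_passages, rng=random):
    """Return a (context, question, target) triple; a fixed fraction gets
    an unrelated passage and the canned refusal as the target."""
    if rng.random() < DISTRACTOR_FRACTION:
        # Pair the real question with a random unrelated passage.
        distractor = rng.choice([p for p in all_passages if p != passage])
        return distractor, question, REFUSAL
    return passage, question, answer
```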
## Usage
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "etanlightstone/simple-lm-rag-science"
tok = AutoTokenizer.from_pretrained(repo)
# trust_remote_code=True loads the custom architecture from modeling_simple_lm.py
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True).eval()

passage = (
    "Photosynthesis is the process by which plants convert light energy "
    "into chemical energy stored in glucose, releasing oxygen as a "
    "byproduct."
)
# system = retrieved passage, user = question (see the prompt format above)
messages = [
    {"role": "system", "content": passage},
    {"role": "user", "content": "What is photosynthesis?"},
]
inputs = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,  # ends the prompt with "Answer: "
    return_tensors="pt",
    return_dict=True,
)
prompt_len = inputs["input_ids"].shape[1]
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.9,
        repetition_penalty=1.1,
    )
# Decode only the newly generated tokens, skipping the prompt.
answer = tok.decode(out[0, prompt_len:], skip_special_tokens=True)
print(answer)
```
To test the distractor / refusal behavior, replace the system content with something unrelated, as in the snippet below -- a well-trained checkpoint should respond near-verbatim with the refusal string above.
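For example, continuing from the usage snippet above (reuses `tok` and `model`; the off-topic passage is arbitrary, and greedy decoding is used here so the refusal is easy to spot):

```python
off_topic = (
    "The Treaty of Westphalia, signed in 1648, ended the Thirty Years' War "
    "and reshaped the political order of Europe."
)
messages = [
    {"role": "system", "content": off_topic},
    {"role": "user", "content": "What is photosynthesis?"},
]
inputs = tok.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
# Expected: "The passage does not provide enough information to answer this question."
```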
## Architecture
| field | value |
|---|---|
| vocab_size | 32000 |
| context_length | 512 |
| d_model | 768 |
| n_layers | 12 |
| n_heads | 8 |
| d_ff | 2048 |
| activation | gelu |
| bias | True |
| tie_word_embeddings | True |
Tokenizer source: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
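These dimensions account exactly for the parameter counts reported in the training settings below, assuming learned positional embeddings over the 512-token context, `bias=True` on every linear, two LayerNorms per block plus a final one, and a tied LM head (all assumptions, but they reproduce the reported numbers):

```python
d, ff, n_layers, vocab, ctx = 768, 2048, 12, 32000, 512

attn = 4 * (d * d + d)              # q, k, v, out projections with bias
ffn = (d * ff + ff) + (ff * d + d)  # up- and down-projection with bias
norms = 2 * 2 * d                   # two LayerNorms (weight + bias) per block
block = attn + ffn + norms          # 5,513,984 parameters per block

total = vocab * d + ctx * d + n_layers * block + 2 * d  # + final LayerNorm
print(total)      # 91138560 -> matches total_params below
print(2 * block)  # 11027968 -> matches frozen_params (freeze_blocks_below: 2)
```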
## Training settings
```json
{
"mode": "sft",
"source_pretrain_checkpoint": "/home/etan/simple_llm/models/sft_full_science.pt",
"source_pretrain_train_settings": {
"mode": "sft",
"source_pretrain_checkpoint": "/home/etan/simple_llm/checkpoints/lm_checkpoint_008_shutdown.pt",
"source_pretrain_train_settings": {
"batch_size": 10,
"batch_size_note": "per GPU when using torchrun",
"world_size": 1,
"learning_rate": 0.0003,
"weight_decay": 0.01,
"num_epochs": 3,
"max_steps": null,
"grad_clip": 1.0,
"seed": 42,
"docs_dir": "/home/etan/simple_llm/docs",
"block_size": 512,
"stride": 448,
"stride_overlap_tokens": 64
},
"data_dir": "/home/etan/simple_llm/datasets/MegaScience/data",
"data_glob": "*.parquet",
"subject_filter": null,
"subject_exclude": [
"math"
],
"question_regex_filter": null,
"batch_size": 10,
"world_size": 1,
"learning_rate": 3e-05,
"min_lr": 3e-06,
"warmup_steps": 200,
"weight_decay": 0.0,
"num_epochs": 1,
"max_steps": null,
"grad_clip": 1.0,
"seed": 42,
"block_size": 512,
"eval_fraction": 0.005,
"eval_every": 500,
"max_train_examples": null,
"freezing": {
"freeze_embeddings": false,
"freeze_lm_head": false,
"freeze_blocks_below": 0,
"tie_word_embeddings": true,
"trainable_params": 91138560,
"total_params": 91138560,
"frozen_params": 0,
"frozen_blocks": 0,
"total_blocks": 12
},
"prompt_template": "Question: {question}\nAnswer: ",
"completion_suffix": "</s>"
},
"data_dir": "/home/etan/simple_llm/datasets/MegaScience/data",
"data_glob": "*.parquet",
"subject_filter": null,
"subject_exclude": [
"math"
],
"question_regex_filter": null,
"batch_size": 10,
"world_size": 1,
"learning_rate": 1.5e-05,
"min_lr": 3e-06,
"warmup_steps": 200,
"weight_decay": 0.01,
"num_epochs": 2,
"max_steps": null,
"grad_clip": 1.0,
"seed": 42,
"block_size": 512,
"eval_fraction": 0.02,
"eval_every": 500,
"max_train_examples": 50000,
"freezing": {
"freeze_embeddings": false,
"freeze_lm_head": false,
"freeze_blocks_below": 2,
"tie_word_embeddings": true,
"trainable_params": 80110592,
"total_params": 91138560,
"frozen_params": 11027968,
"frozen_blocks": 2,
"total_blocks": 12
},
"prompt_template": "Question: {question}\nAnswer: ",
"completion_suffix": "</s>",
"rag_mode": true,
"rag": {
"prompt_template": "Context: {context}\n\nQuestion: {question}\nAnswer: ",
"passage_min_tokens": 60,
"passage_max_tokens": 250,
"answer_min_tokens": 15,
"answer_max_tokens": 180,
"answer_num_sentences": 3,
"distractor_fraction": 0.2,
"distractor_answer": "The passage does not provide enough information to answer this question."
}
}
```
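Under these settings, one RAG training example serializes as in the sketch below (hypothetical helper; only the template strings and the loss-on-answer behavior come from the settings and the slot table above):

```python
RAG_TEMPLATE = "Context: {context}\n\nQuestion: {question}\nAnswer: "
COMPLETION_SUFFIX = "</s>"

def serialize(context: str, question: str, answer: str) -> str:
    prompt = RAG_TEMPLATE.format(context=context, question=question)
    # Only the answer tokens plus the EOS suffix bear loss.
    return prompt + answer + COMPLETION_SUFFIX
```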