---
language:
- en
tags:
- llama
- decoder-only
- educational
- pretrained
license: apache-2.0
datasets:
- HuggingFaceFW/fineweb-edu
---
# LLM-1B-Lab
An educational implementation of a **1.1B-parameter LLaMA-style decoder-only Transformer**,
trained from scratch on [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu).
## Model Details
| Attribute | Value |
|-----------|-------|
| Parameters | ~1.1B |
| Architecture | LLaMA-style (RMSNorm, RoPE, GQA, SwiGLU, Weight Tying) |
| Hidden dim | 2048 |
| Layers | 22 |
| Attention heads | 16 (Q) / 4 (KV) |
| Max sequence length | 2048 |
| Vocab size | 32,000 |
| Training steps | 20,000 |
| Best val loss | 2.3653 (perplexity: 10.65) |
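
The "~1.1B" figure can be sanity-checked from the table. The SwiGLU intermediate dimension is not stated in this card, so the sketch below assumes 5632 (a common choice at this scale); everything else comes straight from the table.

```python
def count_params(vocab=32_000, d=2048, layers=22, n_q=16, n_kv=4,
                 d_ff=5632):  # d_ff is an assumption; not stated in the card
    head_dim = d // n_q                                # 128
    embed = vocab * d                                  # tied with the output head, so counted once
    attn = d * d + 2 * d * (n_kv * head_dim) + d * d   # Q, K (GQA), V (GQA), O projections
    mlp = 3 * d * d_ff                                 # SwiGLU: gate, up, down
    norms = 2 * d                                      # two RMSNorm weights per block
    return embed + layers * (attn + mlp + norms) + d   # + final RMSNorm

print(f"{count_params() / 1e9:.2f}B")  # → 1.06B
```

With weight tying, the embedding matrix is the single largest component (~65.5M parameters) but is counted only once.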
## Training
- **Dataset**: FineWeb-Edu (sample-10BT)
- **Tokenizer**: Pretrained LLaMA 2 (`NousResearch/Llama-2-7b-hf`, 32K vocab)
- **Hardware**: Google Colab Pro+ (A100 40GB)
- **Precision**: bfloat16 mixed precision
- **Optimizer**: AdamW (lr=3e-4, weight_decay=0.1, beta2=0.95)
- **Scheduler**: Cosine decay with 2,000 linear warmup steps
- **Effective batch size**: 128
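
The schedule above (linear warmup to the peak learning rate, then cosine decay) can be sketched as a pure function of the step count. The minimum learning rate is an assumption here (10% of peak, a common convention); the card does not state it.

```python
import math

def lr_at(step, max_lr=3e-4, warmup=2_000, total=20_000, min_lr=3e-5):
    """Learning rate at a given step: linear warmup, then cosine decay.
    min_lr is an assumption (10% of peak); not specified in this card."""
    if step < warmup:
        return max_lr * step / warmup              # linear ramp from 0 to max_lr
    t = (step - warmup) / (total - warmup)         # decay progress in [0, 1]
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))

print(lr_at(0), lr_at(2_000), lr_at(20_000))  # → 0.0 0.0003 3e-05
```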
## Usage
```python
import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer
# 1. Load config and rebuild model
from llm_lab.config import ModelConfig
from llm_lab.model import LLMModel
model = LLMModel(ModelConfig.base_1b())
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict, strict=False) # strict=False for weight tying
model.eval()
# 2. Load tokenizer (pretrained LLaMA 2)
tokenizer = AutoTokenizer.from_pretrained("Vjeong/LLM-1B-Lab")
# 3. Generate text (no gradients needed at inference time)
prompt = "The future of AI is"
input_ids = torch.tensor([tokenizer.encode(prompt)])
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=100, temperature=0.8, top_p=0.9)
print(tokenizer.decode(output[0].tolist()))
```
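The `top_p=0.9` argument above enables nucleus sampling: only the smallest set of tokens whose cumulative probability reaches `p` is kept, and the rest are discarded before sampling. A minimal pure-Python sketch of the standard technique (not this repo's exact implementation):

```python
import random

def top_p_filter(probs, p=0.9):
    """Keep the smallest set of tokens whose cumulative probability >= p,
    then renormalize. `probs` maps token id -> probability."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept, total = [], 0.0
    for tok, pr in ranked:
        kept.append((tok, pr))
        total += pr
        if total >= p:          # stop once the nucleus covers mass p
            break
    return {tok: pr / total for tok, pr in kept}

def sample_top_p(probs, p=0.9, rng=random):
    filtered = top_p_filter(probs, p)
    tokens, weights = zip(*filtered.items())
    return rng.choices(tokens, weights=weights, k=1)[0]
```

With a low `p`, only the few most likely tokens survive, making output more focused; `temperature` is applied to the logits before this step.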
## License
Apache 2.0