| """ |
| Shared model loading utilities. |
| |
| TransformerLens's default from_pretrained loads HF weights in fp32 on CPU first |
| (~9 GB RAM peak for Gemma-2-2B), which OOMs machines with <16 GB free RAM. |
| The RAM-efficient path: load the HF model in bf16 on the CPU (low_cpu_mem_usage=True), wrap it, then move it to CUDA. |
| """ |
|
|
| import gc |
| from pathlib import Path |
|
|
| import torch |
| from transformer_lens import HookedTransformer |
|
|
| from sae_gemma.paths import MODEL_NAME |
|
|
|
|
def _resolve_torch_dtype(dtype: "str | torch.dtype") -> torch.dtype:
    """Translate a dtype name (or an actual ``torch.dtype``) into a ``torch.dtype``.

    Names are looked up as attributes of ``torch`` (``"bfloat16"``,
    ``"float16"``, ``"float32"``, ...). Anything that does not resolve to a
    dtype falls back to ``torch.float32`` -- the previous behaviour for every
    non-"bfloat16" value -- but a warning is printed so the fallback is visible.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    candidate = getattr(torch, dtype, None) if isinstance(dtype, str) else None
    if isinstance(candidate, torch.dtype):
        return candidate
    print(
        f"[model_utils] WARNING: unrecognised dtype {dtype!r}; falling back to float32",
        flush=True,
    )
    return torch.float32


def load_model(device: str = "cuda", dtype: str = "bfloat16") -> HookedTransformer:
    """
    Load the model named by MODEL_NAME into a HookedTransformer.

    The Hugging Face checkpoint is loaded on the CPU in the requested dtype
    with ``low_cpu_mem_usage=True``, wrapped in a HookedTransformer while still
    on the CPU (no LayerNorm folding, no weight/unembed centering), and only
    then moved to the target device. This avoids TransformerLens's default
    fp32 CPU load described in the module docstring.

    Args:
        device: Device the finished model is moved to. When CUDA is not
            available the model stays on the CPU regardless of this value.
        dtype: Name of a torch dtype (e.g. ``"bfloat16"``, ``"float16"``,
            ``"float32"``); a ``torch.dtype`` is accepted as well. Unknown
            names fall back to float32 with a printed warning.

    Returns:
        The HookedTransformer, in eval mode, on the target device.
    """
    torch_dtype = _resolve_torch_dtype(dtype)
    target_device = device if torch.cuda.is_available() else "cpu"

    print(f"[model_utils] Loading {MODEL_NAME} HF weights -> CPU ({dtype}) ...", flush=True)
    # Imported lazily, after the progress message, as in the original code.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    hf_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch_dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    print("[model_utils] Wrapping in HookedTransformer (fold_ln=False, CPU processing) ...", flush=True)

    # Hand the already-loaded HF model to TransformerLens so it does not load
    # the checkpoint itself; all weight post-processing options are disabled.
    model = HookedTransformer.from_pretrained(
        MODEL_NAME,
        hf_model=hf_model,
        tokenizer=tokenizer,
        dtype=torch_dtype,
        device="cpu",
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
    )

    # Drop our reference to the HF model and collect garbage before the device
    # move, so memory held only by the HF wrapper can be reclaimed first.
    del hf_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if target_device != "cpu":
        model = model.to(target_device)
    model.eval()

    if torch.cuda.is_available():
        used = torch.cuda.memory_allocated() / 1e9
        print(f"[model_utils] Model loaded. VRAM used: {used:.2f} GB", flush=True)

    return model
|
|