"""Load HuggingFace SmolLM2-135M weights into our custom SmolLM2 implementation
and compare the two models' logits on a test prompt."""

import sys

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from model import SmolLM2, SmolConfig

PRETRAINED_NAME = "HuggingFaceTB/SmolLM2-135M"
|
|
def build_custom_model():
    """Create our SmolLM2 using the HF config to ensure identical hyperparams."""
    hf_cfg = AutoConfig.from_pretrained(PRETRAINED_NAME)
    cfg = SmolConfig.from_hf(hf_cfg)
    model = SmolLM2(cfg)
    return model, cfg
|
|
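# A small helper sketch (not in the original script): comparing total parameter
# counts between the custom and HF models is a cheap early check that
# `SmolConfig.from_hf` picked up the right hyperparameters.
def count_params(model: torch.nn.Module) -> int:
    """Total number of parameters in a model."""
    return sum(p.numel() for p in model.parameters())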
|
|
def build_hf_model():
    """Load reference HF model."""
    hf_model = AutoModelForCausalLM.from_pretrained(
        PRETRAINED_NAME,
        torch_dtype=torch.float32,
    )
    hf_model.eval()
    return hf_model
|
|
def load_weights_from_hf(custom_model: SmolLM2, hf_model: AutoModelForCausalLM):
    """
    Map HF LlamaForCausalLM weights into our SmolLM2 model.

    - HF model structure: hf_model.model (LlamaModel) + hf_model.lm_head
    - Our model: embed_tokens, layers, norm, lm_head
    """
    hf_state = hf_model.state_dict()
    custom_state = custom_model.state_dict()

    # Token embeddings.
    custom_state["embed_tokens.weight"] = hf_state["model.embed_tokens.weight"]

    num_layers = custom_model.config.num_hidden_layers

    for i in range(num_layers):
        # RMSNorms: HF's input_layernorm feeds attention, post_attention_layernorm feeds the MLP.
        custom_state[f"layers.{i}.attn_norm.weight"] = hf_state[
            f"model.layers.{i}.input_layernorm.weight"
        ]
        custom_state[f"layers.{i}.mlp_norm.weight"] = hf_state[
            f"model.layers.{i}.post_attention_layernorm.weight"
        ]

        # Attention projections map one-to-one.
        custom_state[f"layers.{i}.attn.q_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.q_proj.weight"
        ]
        custom_state[f"layers.{i}.attn.k_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.k_proj.weight"
        ]
        custom_state[f"layers.{i}.attn.v_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.v_proj.weight"
        ]
        custom_state[f"layers.{i}.attn.o_proj.weight"] = hf_state[
            f"model.layers.{i}.self_attn.o_proj.weight"
        ]

        # SwiGLU MLP: HF keeps gate_proj and up_proj separate; our fc1 fuses them,
        # so concatenate along the output dimension (dim=0 for [out, in] weights).
        gate = hf_state[f"model.layers.{i}.mlp.gate_proj.weight"]
        up = hf_state[f"model.layers.{i}.mlp.up_proj.weight"]
        down = hf_state[f"model.layers.{i}.mlp.down_proj.weight"]

        custom_state[f"layers.{i}.mlp.fc1.weight"] = torch.cat([gate, up], dim=0)
        custom_state[f"layers.{i}.mlp.fc2.weight"] = down

    # Final RMSNorm before the LM head.
    custom_state["norm.weight"] = hf_state["model.norm.weight"]

    # LM head (SmolLM2 ties it to the embeddings; state_dict() exposes both names).
    custom_state["lm_head.weight"] = hf_state["lm_head.weight"]

    missing, unexpected = custom_model.load_state_dict(custom_state, strict=False)
    return missing, unexpected
|
|
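# A hedged sanity-check sketch (not in the original script): before trusting
# `strict=False`, verify that every mapped tensor has the shape the custom
# model expects. The helper name `audit_shapes` is an assumption.
def audit_shapes(custom_model: SmolLM2, mapped_state: dict):
    """Return a list of 'name: got vs expected' strings for shape mismatches."""
    expected = {name: p.shape for name, p in custom_model.named_parameters()}
    mismatches = []
    for name, tensor in mapped_state.items():
        if name in expected and tuple(tensor.shape) != tuple(expected[name]):
            mismatches.append(f"{name}: {tuple(tensor.shape)} vs {tuple(expected[name])}")
    return mismatches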
|
|
def test_weight_loading():
    """
    1. Build custom SmolLM2 model (our implementation).
    2. Build HF reference model.
    3. Load HF weights into our model via the mapping.
    4. Run a small test prompt and compare logits.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print("📦 Building custom model...")
    custom_model, cfg = build_custom_model()
    custom_model.to(device)
    custom_model.eval()

    print("📦 Building HF reference model...")
    hf_model = build_hf_model()
    hf_model.to(device)

    print("📦 Mapping HF weights into custom model...")
    missing, unexpected = load_weights_from_hf(custom_model, hf_model)

    print(f"Missing keys    : {len(missing)}")
    print(f"Unexpected keys : {len(unexpected)}")
    if missing:
        print("  Missing examples:", missing[:5])
    if unexpected:
        print("  Unexpected examples:", unexpected[:5])

    if missing:
        print("⚠️  There are missing keys; the mapping may be incomplete.")
    else:
        print("✅ All expected parameters were assigned from HF weights.")

    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME)
    prompt = "Hello, how are you?"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    print("📦 Running HF model forward...")
    with torch.no_grad():
        hf_logits = hf_model(**inputs).logits

    print("📦 Running custom model forward...")
    with torch.no_grad():
        custom_logits, _ = custom_model(inputs["input_ids"])

    # Compare in float32 so dtype casting cannot mask real differences.
    hf_logits = hf_logits.to(torch.float32)
    custom_logits = custom_logits.to(torch.float32)

    diff = torch.abs(hf_logits - custom_logits).max().item()
    print(f"🔍 Max absolute difference between logits: {diff:.6f}")
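
    # An alternative check (a sketch): torch.testing.assert_close raises with a
    # detailed mismatch report instead of printing, which is handy under pytest:
    # torch.testing.assert_close(custom_logits, hf_logits, rtol=1e-4, atol=1e-4)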
|
    if diff < 1e-4:
        print("✅ SUCCESS: Outputs match very closely. Implementation is correct.")
    elif diff < 1e-2:
        print("🟡 Outputs are close but not identical; check for small implementation differences (e.g., RoPE details).")
    else:
        print("❌ Outputs differ significantly. Some part of the implementation is likely off.")
|
    print("\n🔍 Predictions:")
    print(f"Prompt: '{prompt}'")

    # Greedy next-token prediction at every position.
    hf_predicted_ids = hf_logits.argmax(dim=-1)
    custom_predicted_ids = custom_logits.argmax(dim=-1)

    # The prediction at the last position is the next token for the prompt.
    hf_next_token_id = hf_predicted_ids[0, -1].item()
    custom_next_token_id = custom_predicted_ids[0, -1].item()

    hf_next_token = tokenizer.decode([hf_next_token_id])
    custom_next_token = tokenizer.decode([custom_next_token_id])

    print(f"HF Model prediction (next token): '{hf_next_token}' (token_id: {hf_next_token_id})")
    print(f"Custom Model prediction (next token): '{custom_next_token}' (token_id: {custom_next_token_id})")

    # Decode the argmax at every position: each token is the model's
    # one-step-ahead prediction for the corresponding prefix.
    hf_full_prediction = tokenizer.decode(hf_predicted_ids[0])
    custom_full_prediction = tokenizer.decode(custom_predicted_ids[0])
    print(f"\nHF Model full sequence prediction: '{hf_full_prediction}'")
    print(f"Custom Model full sequence prediction: '{custom_full_prediction}'")
|
|
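# Optional extension (a sketch, not part of the original test): greedy-decode a
# few tokens with both models and check that the continuations agree
# token-for-token. The function name and `steps` parameter are assumptions; it
# reuses the forward signatures exercised above (HF returns .logits, ours
# returns (logits, _)).
def compare_greedy_decode(custom_model, hf_model, tokenizer, prompt, device, steps=8):
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    hf_ids, custom_ids = ids.clone(), ids.clone()
    with torch.no_grad():
        for _ in range(steps):
            # Append each model's greedy next token to its own running sequence.
            hf_next = hf_model(input_ids=hf_ids).logits[:, -1, :].argmax(dim=-1, keepdim=True)
            hf_ids = torch.cat([hf_ids, hf_next], dim=1)
            custom_logits, _ = custom_model(custom_ids)
            custom_next = custom_logits[:, -1, :].argmax(dim=-1, keepdim=True)
            custom_ids = torch.cat([custom_ids, custom_next], dim=1)
    return tokenizer.decode(hf_ids[0]), tokenizer.decode(custom_ids[0])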
|
|
if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python test_model_implementation.py test_weight_loading")
        sys.exit(1)

    mode = sys.argv[1]

    if mode == "test_weight_loading":
        test_weight_loading()
    else:
        print(f"Unknown mode: {mode}")
        sys.exit(1)