| |
| |
# Put the current working directory first on the import path, presumably so
# project packages (e.g. ``gr00t``) resolve when run from the repo root.
import os
import sys

sys.path.insert(0, os.path.abspath("."))
|
|
| |
| import types, torch, torch.nn.functional as F, importlib.machinery as im |
# --- flash_attn stub ---------------------------------------------------------
# Register a minimal fake ``flash_attn`` package in ``sys.modules`` so modules
# that import it can be imported without the CUDA extension installed.
# Attention calls are served by torch's scaled_dot_product_attention instead.
flash_pkg = types.ModuleType("flash_attn")
# A real ModuleSpec (rather than __spec__ = None) lets importlib.util.find_spec
# return a spec for the already-registered module.
flash_pkg.__spec__ = im.ModuleSpec("flash_attn", loader=None, is_package=True)
flash_pkg.__path__ = []
sys.modules["flash_attn"] = flash_pkg
fa = types.ModuleType("flash_attn.flash_attn_interface")
fa.__spec__ = im.ModuleSpec("flash_attn.flash_attn_interface", loader=None)


def _seq_bounds(cu_seqlens, total):
    """Return ``(start, end)`` pairs from cumulative sequence lengths.

    ``cu_seqlens`` is ``[0, len_0, len_0 + len_1, ...]`` (tensor or sequence);
    ``None`` treats all ``total`` packed tokens as one sequence.
    """
    if cu_seqlens is None:
        return [(0, total)]
    raw = cu_seqlens.tolist() if hasattr(cu_seqlens, "tolist") else list(cu_seqlens)
    bounds = [int(b) for b in raw]
    return list(zip(bounds[:-1], bounds[1:]))


def _varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                      dropout_p=0.0, softmax_scale=None, causal=False):
    """Attention over packed variable-length sequences using torch SDPA.

    ``q`` is (total_q, nheads, headdim) and ``k``/``v`` are
    (total_k, nheads_k, headdim), i.e. flash_attn's varlen layout. Each
    sequence delimited by the cu_seqlens arrays attends only to itself.
    Returns (total_q, nheads, headdim_v).
    """
    if k.shape[1] != q.shape[1]:
        # MQA/GQA: flash_attn broadcasts kv heads; SDPA needs them expanded.
        rep = q.shape[1] // k.shape[1]
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
    # Only pass `scale` when requested so older torch versions keep working.
    extra = {} if softmax_scale is None else {"scale": softmax_scale}
    outs = []
    # NOTE(review): SDPA's is_causal mask is top-left aligned; flash_attn >= 2.1
    # aligns bottom-right when seqlen_q != seqlen_k. Identical for self-attention.
    for (qs, qe), (ks, ke) in zip(_seq_bounds(cu_seqlens_q, q.shape[0]),
                                  _seq_bounds(cu_seqlens_k, k.shape[0])):
        # SDPA expects (..., seq, dim): move heads in front of the token axis.
        o = F.scaled_dot_product_attention(
            q[qs:qe].transpose(0, 1),
            k[ks:ke].transpose(0, 1),
            v[ks:ke].transpose(0, 1),
            dropout_p=dropout_p,
            is_causal=causal,
            **extra,
        )
        outs.append(o.transpose(0, 1))
    if not outs:
        return q.new_empty((0, q.shape[1], v.shape[-1]))
    return torch.cat(outs, dim=0)


def _sdpa(qkv, cu_seqlens=None, max_seqlen=None, dropout_p=0.0,
          softmax_scale=None, causal=False, *_, **__):
    """Stand-in for ``flash_attn_varlen_qkvpacked_func`` / ``..._unpadded_...``.

    ``qkv`` is (total_tokens, 3, nheads, headdim); returns
    (total_tokens, nheads, headdim). ``max_seqlen`` and any further
    flash_attn options are accepted and ignored.
    """
    q, k, v = qkv.unbind(1)
    return _varlen_attention(q, k, v, cu_seqlens, cu_seqlens,
                             dropout_p, softmax_scale, causal)


def _sdpa_kvpacked(q, kv, cu_seqlens_q=None, cu_seqlens_k=None,
                   max_seqlen_q=None, max_seqlen_k=None, dropout_p=0.0,
                   softmax_scale=None, causal=False, *_, **__):
    """Stand-in for ``flash_attn_varlen_kvpacked_func`` / ``..._unpadded_...``.

    ``q`` is (total_q, nheads, headdim) and ``kv`` is
    (total_k, 2, nheads_k, headdim); returns (total_q, nheads, headdim).
    """
    k, v = kv.unbind(1)
    return _varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                             dropout_p, softmax_scale, causal)


for _name in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_varlen_qkvpacked_func"):
    setattr(fa, _name, _sdpa)
for _name in ("flash_attn_unpadded_kvpacked_func", "flash_attn_varlen_kvpacked_func"):
    setattr(fa, _name, _sdpa_kvpacked)
sys.modules["flash_attn.flash_attn_interface"] = fa
flash_pkg.flash_attn_interface = fa

pad = types.ModuleType("flash_attn.bert_padding")
pad.__spec__ = im.ModuleSpec("flash_attn.bert_padding", loader=None)
# NOTE(review): import-only placeholders. Their return values do NOT follow
# flash_attn.bert_padding's contracts; replace them before running any
# forward pass that actually calls pad_input / unpad_input.
pad.pad_input = lambda x, *a, **k: (x, None)
pad.unpad_input = lambda x, *a, **k: x
sys.modules["flash_attn.bert_padding"] = pad
flash_pkg.bert_padding = pad
|
|
# Without a GPU, replace a few torch.cuda capability queries with constant
# stubs so code that probes the device does not fail on a CPU-only host.
if not torch.cuda.is_available():
    _cuda_stubs = {
        "is_available": lambda: False,
        "get_device_capability": lambda dev=None: (0, 0),
        "current_device": lambda: 0,
        "get_device_properties": lambda dev=None: types.SimpleNamespace(major=0, minor=0),
    }
    for _attr, _stub in _cuda_stubs.items():
        setattr(torch.cuda, _attr, _stub)
|
|
| import importlib.metadata as _im |
| if "flash_attn" not in _im.packages_distributions(): |
| rv, rd = _im.version, _im.distribution |
| _im.version = lambda p:"0.0.0" if p=="flash_attn" else rv(p) |
| _im.distribution = lambda p:types.SimpleNamespace(version="0.0.0") if p=="flash_attn" else rd(p) |
|
|
| |
| from pathlib import Path |
| import argparse, json, shutil |
| from huggingface_hub import hf_hub_download |
| from transformers import AutoConfig |
| from gr00t.model.gr00t_n1 import GR00T_N1_5 |
|
|
| |
def patched_cfg(repo_id="nvidia/GR00T-N1.5-3B"):
    """Return a path to the hub ``config.json`` with ``model_type == "gr00t_n1_5"``.

    Downloads ``config.json`` from ``repo_id`` via ``hf_hub_download``. If its
    ``model_type`` already equals ``"gr00t_n1_5"`` the downloaded path is
    returned; otherwise a corrected copy is written next to it as
    ``config_patched.json`` and that path is returned.

    Args:
        repo_id: Hugging Face Hub repository to read the config from.

    Returns:
        Filesystem path (str) of the config to load.
    """
    path = Path(hf_hub_download(repo_id, "config.json"))
    # read_text closes the file immediately (the old json.load(open(p)) leaked it).
    cfg = json.loads(path.read_text(encoding="utf-8"))
    if cfg.get("model_type") == "gr00t_n1_5":
        return str(path)
    cfg["model_type"] = "gr00t_n1_5"
    patched = path.with_name("config_patched.json")
    patched.write_text(json.dumps(cfg), encoding="utf-8")
    return str(patched)
|
|
def build_blank():
    """Construct a randomly initialised GR00T_N1_5 from the patched hub config.

    The backbone sub-config gets ``tune_llm=True`` and loses its
    ``checkpoint_path`` / ``use_pretrained`` keys; the action-head sub-config
    loses ``checkpoint_path`` (presumably so no pretrained weights are
    pulled in — confirm against GR00T_N1_5). The torch RNG is seeded with 0
    right before construction.
    """
    config = AutoConfig.from_pretrained(
        patched_cfg(),
        trust_remote_code=True,
        local_files_only=True,
    )
    backbone_cfg = config.backbone_cfg
    backbone_cfg.update({"tune_llm": True})
    for key in ("checkpoint_path", "use_pretrained"):
        backbone_cfg.pop(key, None)
    config.action_head_cfg.pop("checkpoint_path", None)
    torch.manual_seed(0)
    return GR00T_N1_5(config, local_model_path="")
|
|
def maybe_add_lm_head(model, expected_vocab=151680, expected_hidden=2048):
    """Attach a freshly initialised ``lm_head`` to the language model.

    The head is always (re)created, replacing any existing ``lm_head``. It is
    a bias-free ``Linear(hidden_size, vocab_size)`` sized from
    ``lm.model.embed_tokens``, initialised from N(0, 0.02), and stored in
    bfloat16 on the embedding's device. Its weight is a new Parameter, not
    shared with ``embed_tokens``.

    Args:
        model: object exposing ``backbone.eagle_model.language_model``, whose
            ``model.embed_tokens`` is an ``nn.Embedding``.
        expected_vocab: vocabulary size to sanity-check against (warning only).
        expected_hidden: hidden size to sanity-check against (warning only).
    """
    lm = model.backbone.eagle_model.language_model

    embed_tokens = lm.model.embed_tokens
    vocab_size = embed_tokens.num_embeddings
    hidden_size = embed_tokens.embedding_dim
    print(f"Embedding dimensions: vocab_size={vocab_size}, hidden_size={hidden_size}")

    if (vocab_size, hidden_size) != (expected_vocab, expected_hidden):
        print(f"⚠️ Warning: Unexpected dimensions. "
              f"Expected vocab={expected_vocab}, hidden={expected_hidden}")

    if hasattr(lm, "lm_head"):
        print(f"lm_head attribute exists: {lm.lm_head is not None}")
        print("Creating new lm_head with proper initialization...")
    else:
        print("lm_head attribute missing, creating...")

    new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    # Initialise in the default dtype first, then cast, so the sampled values
    # match the original fp32-init-then-cast behaviour.
    torch.nn.init.normal_(new_lm_head.weight, mean=0.0, std=0.02)
    # Module.to keeps the Parameter registered and places the head next to
    # the embeddings (avoids a device mismatch if the model was already moved).
    lm.lm_head = new_lm_head.to(device=embed_tokens.weight.device, dtype=torch.bfloat16)

    weight = lm.lm_head.weight
    print(f"✓ Created lm_head: Linear({hidden_size}, {vocab_size}, bias=False)")
    print(f" Weight shape: {weight.shape}")
    print(f" Weight dtype: {weight.dtype}")
    print(f" Parameters: {weight.numel() / 1e6:.1f}M")
|
|
def set_mixed(model):
    """Cast parameters in place for mixed precision.

    Parameters under ``backbone.`` (or whose name contains ``lm_head``) become
    bfloat16; every other parameter becomes float32. Buffers are untouched.
    """
    for param_name, tensor in model.named_parameters():
        wants_bf16 = param_name.startswith("backbone.") or "lm_head" in param_name
        target_dtype = torch.bfloat16 if wants_bf16 else torch.float32
        tensor.data = tensor.data.to(target_dtype)
|
|
def copy_tokenizer(out, repo_id="nvidia/GR00T-N1.5-3B",
                   files=("tokenizer.json", "tokenizer_config.json",
                          "vocab.txt", "special_tokens_map.json")):
    """Best-effort copy of tokenizer files from the hub into ``out``.

    Each file is fetched with ``hf_hub_download`` and copied to ``out / name``.
    A file that cannot be fetched or copied is skipped with a printed reason
    rather than aborting the run (previously such failures were silently
    swallowed).

    Args:
        out: destination directory (str or Path); must already exist.
        repo_id: hub repository to download from.
        files: file names to try.

    Returns:
        List of file names that were copied.
    """
    # NOTE(review): vocab.txt is a WordPiece-style file; BPE tokenizers usually
    # ship vocab.json + merges.txt instead — confirm the intended file list.
    out = Path(out)
    copied = []
    for fname in files:
        try:
            shutil.copy(hf_hub_download(repo_id, fname), out / fname)
        except Exception as err:  # best-effort: missing hub entry, network, or copy errors
            print(f" Skipped tokenizer file {fname}: {type(err).__name__}: {err}")
        else:
            copied.append(fname)
    return copied
|
|
def diagnose_model(model):
    """Print parameter totals and lm_head information for ``model``.

    Reports:
      * the total parameter count (``model.parameters()``);
      * whether any named parameter contains ``lm_head``, their combined size,
        and the last such name;
      * size, shape and dtype of
        ``model.backbone.eagle_model.language_model.lm_head.weight`` when that
        head exists and has a ``weight`` attribute.
    """
    print("\nModel diagnostics:")
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Total params: {total_params/1e6:,.0f}M")

    lm_head_named = {name: p for name, p in model.named_parameters() if "lm_head" in name}
    has_lm_head = bool(lm_head_named)
    # Distinct glyphs for the two states (the old literal printed the same one for both).
    print(f" Has lm_head: {'✓' if has_lm_head else '✗'}")
    if has_lm_head:
        lm_head_params = sum(p.numel() for p in lm_head_named.values())
        print(f" lm_head params: {lm_head_params/1e6:,.0f}M")
        print(f" lm_head location: {list(lm_head_named)[-1]}")

    lm = model.backbone.eagle_model.language_model
    # lm_head may be absent, None, or a weight-less module (e.g. nn.Identity);
    # only report when a weight tensor is actually present.
    weight = getattr(getattr(lm, "lm_head", None), "weight", None)
    if weight is not None:
        print(f" lm_head actual params: {weight.numel()/1e6:,.0f}M")
        print(f" lm_head weight shape: {weight.shape}")
        print(f" lm_head weight dtype: {weight.dtype}")
|
|
def _check_param_shapes(param_dict, expected_shapes):
    """Return error strings for missing / mis-shaped params; print ✓ for matches."""
    errors = []
    for name, expected_shape in expected_shapes.items():
        param = param_dict.get(name)
        if param is None:
            errors.append(f"Missing parameter: {name}")
            continue
        actual_shape = tuple(param.shape)
        if actual_shape != tuple(expected_shape):
            errors.append(f"Shape mismatch for {name}: expected {expected_shape}, got {actual_shape}")
        else:
            print(f"✓ {name}: {actual_shape}")
    return errors


def _check_param_dtypes(param_dict):
    """Return one message per backbone param not in bf16 / action_head param not in fp32."""
    issues = []
    for name, param in param_dict.items():
        if name.startswith("backbone."):
            want, label = torch.bfloat16, "bfloat16"
        elif name.startswith("action_head."):
            want, label = torch.float32, "float32"
        else:
            continue
        if param.dtype != want:
            issues.append(f"{name}: expected {label}, got {param.dtype}")
    return issues


def _count_params_by_component(param_dict):
    """Sum parameter counts into 'backbone', 'action_head' and 'other' buckets."""
    counts = {"backbone": 0, "action_head": 0, "other": 0}
    for name, param in param_dict.items():
        if name.startswith("backbone."):
            counts["backbone"] += param.numel()
        elif name.startswith("action_head."):
            counts["action_head"] += param.numel()
        else:
            counts["other"] += param.numel()
    return counts


def validate_model_architecture(model, expected_shapes=None, expected_total_m=2724, tolerance_m=1.0):
    """Check ``model`` against the GR00T-N1.5-3B reference and print a report.

    Checks:
      * listed parameters exist with the expected shapes (errors);
      * ``backbone.*`` params are bfloat16 and ``action_head.*`` are float32
        (warnings; all are counted, the first 5 are shown);
      * total parameter count is within ``tolerance_m`` million of
        ``expected_total_m``.
    The "matches specification" verdict requires both no errors and a total
    within tolerance.

    Args:
        model: a ``torch.nn.Module``.
        expected_shapes: mapping parameter name -> shape; defaults to the
            GR00T-N1.5-3B reference shapes below.
        expected_total_m: expected total parameter count, in millions.
        tolerance_m: allowed absolute deviation from the total, in millions.

    Returns:
        True when no shape errors were found (dtype warnings and the total
        count do not affect the result).
    """
    if expected_shapes is None:
        expected_shapes = {
            "backbone.eagle_model.language_model.lm_head.weight": (151680, 2048),
            "backbone.eagle_model.language_model.model.embed_tokens.weight": (151680, 2048),
            "backbone.eagle_model.language_model.model.norm.weight": (2048,),
            "backbone.eagle_model.mlp1.0.weight": (2048, 1152),
            "backbone.eagle_model.mlp1.0.bias": (2048,),
            "action_head.position_embedding.weight": (1024, 1536),
            "action_head.vlln.weight": (2048,),
            "action_head.vlln.bias": (2048,),
        }

    print("\n" + "=" * 60)
    print("ARCHITECTURE VALIDATION")
    print("=" * 60)

    param_dict = dict(model.named_parameters())

    position_params = [name for name in param_dict if name.startswith("action_head.position")]
    if position_params:
        print("\nFound position embedding parameters:")
        for name in position_params[:5]:
            print(f" {name}: {param_dict[name].shape}")

    errors = _check_param_shapes(param_dict, expected_shapes)
    # Keep every dtype issue: the count below must be accurate and the
    # "... and N more" line reachable (display is truncated separately).
    dtype_warnings = _check_param_dtypes(param_dict)
    component_params = _count_params_by_component(param_dict)

    lm_head_found = any("lm_head" in name for name in param_dict)
    lm_head_params = sum(p.numel() for name, p in param_dict.items() if "lm_head" in name)
    # Derive the expected head size from the shape table (151680 x 2048 = 310.6M)
    # instead of a hard-coded figure.
    expected_lm_head = next(
        (torch.Size(shape).numel() for name, shape in expected_shapes.items() if "lm_head" in name),
        None,
    )

    print("\nValidation Results:")
    print(f" Errors: {len(errors)}")
    print(f" Warnings: {len(dtype_warnings)}")

    if errors:
        print("\n❌ ERRORS:")
        for error in errors:
            print(f" - {error}")

    if dtype_warnings:
        print("\n⚠️ WARNINGS (showing first 5):")
        for warning in dtype_warnings[:5]:
            print(f" - {warning}")
        if len(dtype_warnings) > 5:
            print(f" ... and {len(dtype_warnings) - 5} more")

    print("\n📊 Parameter Summary:")
    total = sum(component_params.values())
    print(f" Total: {total/1e6:,.1f}M")
    print(f" Backbone: {component_params['backbone']/1e6:,.1f}M")
    print(f" Action Head: {component_params['action_head']/1e6:,.1f}M")
    if component_params["other"] > 0:
        print(f" Other: {component_params['other']/1e6:,.1f}M")

    print(f"\n lm_head found: {'✓' if lm_head_found else '✗'}")
    if lm_head_found:
        suffix = "" if expected_lm_head is None else f" (expected: {expected_lm_head/1e6:.1f}M)"
        print(f" lm_head params: {lm_head_params/1e6:.1f}M{suffix}")

    actual_total = total / 1e6
    diff = actual_total - expected_total_m
    print(f"\n Expected total: {expected_total_m}M")
    print(f" Actual total: {actual_total:.1f}M")
    print(f" Difference: {diff:+.1f}M")

    if not errors and abs(diff) < tolerance_m:
        print("\n✅ Model architecture matches expected specification!")
    else:
        print("\n❌ Model architecture does NOT match specification!")

    return not errors
|
|
| |
def main(device: str, out_dir: str):
    """Build, inspect, validate and save a blank GR00T-N1.5-3B model.

    Steps: build the random model, print diagnostics before/after attaching a
    new lm_head, cast to mixed precision (backbone bf16, rest fp32), move to
    ``device``, run the architecture validation, then write weights (2GB
    shards), any tokenizer files that can be fetched, and a README into
    ``out_dir``.

    Args:
        device: torch device string the model is moved to before saving.
        out_dir: output directory (``~`` is expanded; created if missing).
    """
    banner = "=" * 60
    print(banner)
    print("Creating blank GR00T-N1.5-3B model")
    print(banner)

    model = build_blank()

    print("\nBefore adding lm_head:")
    diagnose_model(model)

    maybe_add_lm_head(model)

    print("\nAfter adding lm_head:")
    diagnose_model(model)

    set_mixed(model)
    model = model.to(device)

    # The validation result used to be discarded; surface failures explicitly.
    if not validate_model_architecture(model):
        print("\n⚠️ Architecture validation reported errors; saving anyway.")

    out = Path(out_dir).expanduser()
    out.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving model to {out}...")
    model.save_pretrained(out, max_shard_size="2GB")
    copy_tokenizer(out)
    (out / "README.md").write_text(
        "Random GR00T-N1.5-3B | backbone bf16 | action_head fp32 | Apache-2.0\n",
        encoding="utf-8",
    )

    total_m = sum(p.numel() for p in model.parameters()) / 1e6
    lm_head_m = model.backbone.eagle_model.language_model.lm_head.weight.numel() / 1e6
    print("\n" + banner)
    print("FINAL SUMMARY")
    print(banner)
    print(f"✅ Saved blank model ({total_m:,.0f}M params) → {out}")
    print(f"✅ Model has lm_head with {lm_head_m:.1f}M params")
    print("✅ Ready for training with Apache-2.0 license")
|
|
| |
if __name__ == "__main__":
    # CLI entry point: build the blank model on --device and save it to --out_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--out_dir", default="DolphinGR00T-N1.5-3B-Zero")
    cli_args = parser.parse_args()
    main(cli_args.device, cli_args.out_dir)