""" Build/inspect script for Component 4 model architecture. """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any, Dict import yaml # Ensure src imports work from project root. PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from src.model_architecture.code_transformer import ( # noqa: E402 CodeTransformerLM, ModelConfig, get_model_presets, ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Build and inspect Component 4 model.") parser.add_argument( "--config", default="configs/component4_model_config.yaml", help="Path to model YAML config.", ) parser.add_argument( "--save_summary", default="artifacts/model/component4_model_summary.json", help="Where to save model summary JSON.", ) return parser.parse_args() def load_yaml(path: Path) -> Dict[str, Any]: if not path.exists(): raise FileNotFoundError(f"Model config not found: {path}") with path.open("r", encoding="utf-8") as f: data = yaml.safe_load(f) if not isinstance(data, dict): raise ValueError("Invalid YAML format in model config.") return data def build_config(cfg_data: Dict[str, Any]) -> ModelConfig: preset = cfg_data.get("preset") model_cfg = cfg_data.get("model", {}) if not isinstance(model_cfg, dict): raise ValueError("Config key 'model' must be an object.") base = None if preset: presets = get_model_presets() if preset not in presets: raise ValueError(f"Unknown preset '{preset}'. Available: {list(presets.keys())}") base = presets[preset] if base is None: return ModelConfig(**model_cfg) merged = { "vocab_size": base.vocab_size, "max_seq_len": base.max_seq_len, "d_model": base.d_model, "n_layers": base.n_layers, "n_heads": base.n_heads, "d_ff": base.d_ff, "dropout": base.dropout, "tie_embeddings": base.tie_embeddings, "gradient_checkpointing": base.gradient_checkpointing, "init_std": base.init_std, "rms_norm_eps": base.rms_norm_eps, } merged.update(model_cfg) return ModelConfig(**merged) def main() -> None: args = parse_args() try: cfg_data = load_yaml(Path(args.config)) model_cfg = build_config(cfg_data) model = CodeTransformerLM(model_cfg) summary = model.summary() save_path = Path(args.save_summary) save_path.parent.mkdir(parents=True, exist_ok=True) with save_path.open("w", encoding="utf-8") as f: json.dump(summary, f, indent=2) print("Component 4 model build completed.") print(f"Preset: {cfg_data.get('preset')}") print(f"Parameters: {summary['num_parameters']:,}") print(f"Saved summary: {save_path}") except Exception as exc: print("Component 4 model build failed.") print(f"What went wrong: {exc}") print("Fix suggestion: check config values (especially d_model and n_heads divisibility).") raise SystemExit(1) if __name__ == "__main__": main()