# File size: 3,317 Bytes
# 53f0cc2
"""
Build/inspect script for Component 4 model architecture.
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict

import yaml

# Ensure `src.*` imports resolve when this script is run directly: the directory
# one level above this file's own directory is treated as the project root and
# is placed at the front of sys.path (only if it is not already listed there).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.model_architecture.code_transformer import (  # noqa: E402
    CodeTransformerLM,
    ModelConfig,
    get_model_presets,
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build and inspect Component 4 model.")
    parser.add_argument(
        "--config",
        default="configs/component4_model_config.yaml",
        help="Path to model YAML config.",
    )
    parser.add_argument(
        "--save_summary",
        default="artifacts/model/component4_model_summary.json",
        help="Where to save model summary JSON.",
    )
    return parser.parse_args()


def load_yaml(path: Path) -> Dict[str, Any]:
    """Read the model config file at ``path`` and return its top-level mapping.

    Args:
        path: Location of the YAML config file.

    Returns:
        The parsed document, which must be a dict.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the parsed document is anything other than a dict
            (this includes an empty file, which parses to ``None``).
    """
    if path.exists():
        with path.open("r", encoding="utf-8") as stream:
            parsed = yaml.safe_load(stream)
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("Invalid YAML format in model config.")
    raise FileNotFoundError(f"Model config not found: {path}")


def build_config(cfg_data: Dict[str, Any]) -> ModelConfig:
    """Resolve parsed config data into a ``ModelConfig``.

    Args:
        cfg_data: Parsed config mapping. The optional ``"preset"`` key names
            an entry of ``get_model_presets()``; the optional ``"model"`` key
            holds explicit field values.

    Returns:
        A ``ModelConfig`` built from the ``"model"`` values alone when no
        preset is in effect, otherwise from the preset's fields with the
        ``"model"`` values layered on top.

    Raises:
        ValueError: If ``"model"`` is not a dict, or if the preset name is
            not among the available presets.
    """
    preset = cfg_data.get("preset")
    overrides = cfg_data.get("model", {})
    if not isinstance(overrides, dict):
        raise ValueError("Config key 'model' must be an object.")

    template = None
    if preset:
        presets = get_model_presets()
        if preset not in presets:
            raise ValueError(f"Unknown preset '{preset}'. Available: {list(presets.keys())}")
        template = presets[preset]

    # No preset in effect: the explicit values are the whole configuration.
    if template is None:
        return ModelConfig(**overrides)

    # Fields copied from the preset before explicit overrides are applied.
    inherited_fields = (
        "vocab_size",
        "max_seq_len",
        "d_model",
        "n_layers",
        "n_heads",
        "d_ff",
        "dropout",
        "tie_embeddings",
        "gradient_checkpointing",
        "init_std",
        "rms_norm_eps",
    )
    kwargs = {field: getattr(template, field) for field in inherited_fields}
    kwargs.update(overrides)  # explicit "model" values win over the preset
    return ModelConfig(**kwargs)


def main() -> None:
    """Build the model described by the CLI config and save its summary.

    Loads the YAML config given by ``--config``, resolves it to a
    ``ModelConfig``, instantiates ``CodeTransformerLM``, writes the result of
    ``model.summary()`` as JSON to the ``--save_summary`` path (creating
    parent directories as needed), and prints a short report to stdout.

    Raises:
        SystemExit: With exit code 1 if any step fails. The failure details
            are printed to stderr first, and the original exception is
            attached as the cause.
    """
    args = parse_args()
    try:
        cfg_data = load_yaml(Path(args.config))
        model_cfg = build_config(cfg_data)
        model = CodeTransformerLM(model_cfg)
        summary = model.summary()

        save_path = Path(args.save_summary)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with save_path.open("w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

        print("Component 4 model build completed.")
        print(f"Preset: {cfg_data.get('preset')}")
        print(f"Parameters: {summary['num_parameters']:,}")
        print(f"Saved summary: {save_path}")
    except Exception as exc:
        # Top-level CLI boundary: turn any failure into a readable report and
        # a non-zero exit status. Diagnostics go to stderr so they do not mix
        # with the normal report on stdout.
        print("Component 4 model build failed.", file=sys.stderr)
        print(f"What went wrong: {exc}", file=sys.stderr)
        print(
            "Fix suggestion: check config values (especially d_model and n_heads divisibility).",
            file=sys.stderr,
        )
        # Chain explicitly so the root cause stays attached to the exit.
        raise SystemExit(1) from exc


# Run the build only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()