# mindi-backup/scripts/build_component4_model.py
# Mindigenous
# Initial full project backup with Git LFS
# 53f0cc2
"""
Build/inspect script for Component 4 model architecture.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict
import yaml
# Put the directory one level above this script's folder at the front of
# sys.path (unless already listed) so the `src.*` import below resolves.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)
from src.model_architecture.code_transformer import ( # noqa: E402
CodeTransformerLM,
ModelConfig,
get_model_presets,
)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script.

    Returns:
        Namespace with two string attributes: ``config`` (path to the model
        YAML config) and ``save_summary`` (destination of the summary JSON).
    """
    cli = argparse.ArgumentParser(description="Build and inspect Component 4 model.")
    # (flag, default value, help text) for every supported option.
    option_specs = (
        (
            "--config",
            "configs/component4_model_config.yaml",
            "Path to model YAML config.",
        ),
        (
            "--save_summary",
            "artifacts/model/component4_model_summary.json",
            "Where to save model summary JSON.",
        ),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)
    return cli.parse_args()
def load_yaml(path: Path) -> Dict[str, Any]:
    """Read a YAML file and return its top-level mapping.

    Args:
        path: Location of the model config file.

    Returns:
        The parsed document as a dict.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the parsed document is not a dict.
    """
    config_missing = not path.exists()
    if config_missing:
        raise FileNotFoundError(f"Model config not found: {path}")

    with path.open("r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)

    if isinstance(parsed, dict):
        return parsed
    raise ValueError("Invalid YAML format in model config.")
def build_config(cfg_data: Dict[str, Any]) -> ModelConfig:
    """Create a ModelConfig from parsed config data.

    The optional ``preset`` key selects a base configuration from
    ``get_model_presets()``; entries under the optional ``model`` key
    override the preset's values. Without a preset, the ``model`` entries
    are passed to ``ModelConfig`` directly.

    Args:
        cfg_data: Parsed config mapping (keys read: ``preset``, ``model``).

    Returns:
        The resulting ModelConfig.

    Raises:
        ValueError: If ``model`` is not a dict, or ``preset`` is not a known
            preset name.
    """
    overrides = cfg_data.get("model", {})
    if not isinstance(overrides, dict):
        raise ValueError("Config key 'model' must be an object.")

    preset = cfg_data.get("preset")
    template = None
    if preset:
        presets = get_model_presets()
        if preset not in presets:
            raise ValueError(f"Unknown preset '{preset}'. Available: {list(presets.keys())}")
        template = presets[preset]

    # No usable preset: the explicit model section is the whole configuration.
    if template is None:
        return ModelConfig(**overrides)

    # Fields copied from the preset before the overrides are applied.
    inherited_fields = (
        "vocab_size",
        "max_seq_len",
        "d_model",
        "n_layers",
        "n_heads",
        "d_ff",
        "dropout",
        "tie_embeddings",
        "gradient_checkpointing",
        "init_std",
        "rms_norm_eps",
    )
    settings = {name: getattr(template, name) for name in inherited_fields}
    settings.update(overrides)
    return ModelConfig(**settings)
def main() -> None:
    """Build the model from the YAML config and save its summary as JSON.

    Loads the config named by ``--config``, constructs the model, writes
    ``model.summary()`` to the ``--save_summary`` path (creating parent
    directories as needed) and prints a short report.

    Raises:
        SystemExit: With status 1 if any step raises an ``Exception``; a
            failure message is printed first.
    """
    cli_args = parse_args()
    try:
        cfg_data = load_yaml(Path(cli_args.config))
        summary = CodeTransformerLM(build_config(cfg_data)).summary()

        # Persist the summary, creating the output directory if necessary.
        save_path = Path(cli_args.save_summary)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with save_path.open("w", encoding="utf-8") as out_file:
            json.dump(summary, out_file, indent=2)

        print("Component 4 model build completed.")
        print(f"Preset: {cfg_data.get('preset')}")
        print(f"Parameters: {summary['num_parameters']:,}")
        print(f"Saved summary: {save_path}")
    except Exception as exc:
        # Top-level boundary: report the problem and exit non-zero.
        print("Component 4 model build failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: check config values (especially d_model and n_heads divisibility).")
        raise SystemExit(1)
# Run the build only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()