File size: 4,324 Bytes
53f0cc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Component 4 verification script.

This script:
- Builds model from config.
- Runs a small forward pass.
- Prints live VRAM usage at each stage.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any, Dict

import torch
import yaml

# Ensure src imports work from project root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.model_architecture.code_transformer import (  # noqa: E402
    CodeTransformerLM,
    ModelConfig,
    get_model_presets,
)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the verification run.

    Returns:
        Namespace with ``config`` (YAML path), ``batch_size``, and ``seq_len``.
    """
    cli = argparse.ArgumentParser(description="Verify Component 4 model load and VRAM usage.")
    cli.add_argument(
        "--config",
        default="configs/component4_model_config.yaml",
        help="Path to model YAML config.",
    )
    cli.add_argument("--batch_size", type=int, default=1, help="Batch size for forward test.")
    cli.add_argument("--seq_len", type=int, default=256, help="Sequence length for forward test.")
    return cli.parse_args()


def load_yaml(path: Path) -> Dict[str, Any]:
    """Read a YAML mapping from *path*.

    Args:
        path: Location of the model config file.

    Returns:
        The parsed top-level mapping.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the YAML root is not a mapping.
    """
    if not path.exists():
        raise FileNotFoundError(f"Model config not found: {path}")
    parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(parsed, dict):
        raise ValueError("Invalid YAML format in model config.")
    return parsed


def build_config(cfg_data: Dict[str, Any]) -> ModelConfig:
    """Build a ModelConfig from raw YAML data.

    The optional ``preset`` key selects a named baseline from
    ``get_model_presets()``; any keys under ``model`` then override the
    preset's fields. Without a preset, ``model`` alone defines the config.

    Raises:
        ValueError: If ``model`` is not a mapping or ``preset`` is unknown.
    """
    overrides = cfg_data.get("model", {})
    if not isinstance(overrides, dict):
        raise ValueError("Config key 'model' must be an object.")

    preset = cfg_data.get("preset")
    if not preset:
        return ModelConfig(**overrides)

    presets = get_model_presets()
    if preset not in presets:
        raise ValueError(f"Unknown preset '{preset}'.")
    # Overrides win over preset defaults on key collision.
    return ModelConfig(**{**presets[preset].__dict__, **overrides})


def gpu_memory_report(stage: str) -> None:
    """Print current, reserved, and peak CUDA memory (in GB) for *stage*.

    Prints a fallback message and returns early when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print(f"[{stage}] CUDA not available")
        return
    gib = 1024**3
    current = torch.cuda.memory_allocated() / gib
    cached = torch.cuda.memory_reserved() / gib
    peak = torch.cuda.max_memory_allocated() / gib
    print(
        f"[{stage}] VRAM allocated={current:.2f} GB "
        f"reserved={cached:.2f} GB max_allocated={peak:.2f} GB"
    )


def main() -> None:
    """Run the end-to-end verification: config -> model -> forward -> report.

    Exits with status 1 (after printing a diagnostic) on any failure.
    """
    args = parse_args()
    try:
        config = build_config(load_yaml(Path(args.config)))
        if args.seq_len > config.max_seq_len:
            raise ValueError(
                f"seq_len={args.seq_len} exceeds max_seq_len={config.max_seq_len} in config."
            )

        on_gpu = torch.cuda.is_available()
        device = torch.device("cuda" if on_gpu else "cpu")
        if on_gpu:
            # Start from a clean slate so the stage reports reflect this run only.
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
            gpu_memory_report("start")
        else:
            print("CUDA not available. Running verification on CPU.")

        model = CodeTransformerLM(config)
        print(f"Model parameters: {model.estimate_num_parameters():,}")

        # Cast to fp16 only on GPU; CPU half-precision kernels are unreliable.
        if on_gpu:
            model = model.half()
        model.to(device)
        model.eval()
        gpu_memory_report("after_model_load")

        input_ids = torch.randint(
            low=0,
            high=config.vocab_size,
            size=(args.batch_size, args.seq_len),
            dtype=torch.long,
            device=device,
        )
        gpu_memory_report("after_input_alloc")

        with torch.no_grad():
            logits = model(input_ids=input_ids)["logits"]
        gpu_memory_report("after_forward")

        print(f"Forward output shape: {tuple(logits.shape)}")
        print("Component 4 verification passed.")
    except Exception as exc:
        print("Component 4 verification failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: reduce seq_len or check CUDA/PyTorch installation.")
        raise SystemExit(1)


if __name__ == "__main__":
    main()