"""
Component 4 verification script.
This script:
- Builds model from config.
- Runs a small forward pass.
- Prints live VRAM usage at each stage.
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any, Dict

import torch
import yaml

# Ensure src imports work from project root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Imported after the sys.path tweak above, hence the noqa for "import not at top".
from src.model_architecture.code_transformer import (  # noqa: E402
    CodeTransformerLM,
    ModelConfig,
    get_model_presets,
)
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the verification run."""
    cli = argparse.ArgumentParser(
        description="Verify Component 4 model load and VRAM usage."
    )
    cli.add_argument(
        "--config",
        default="configs/component4_model_config.yaml",
        help="Path to model YAML config.",
    )
    cli.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size for forward test.",
    )
    cli.add_argument(
        "--seq_len",
        type=int,
        default=256,
        help="Sequence length for forward test.",
    )
    return cli.parse_args()
def load_yaml(path: Path) -> Dict[str, Any]:
    """Load the YAML model config at *path* and return it as a dict.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the parsed document is not a mapping.
    """
    if not path.exists():
        raise FileNotFoundError(f"Model config not found: {path}")
    with path.open("r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    if not isinstance(parsed, dict):
        raise ValueError("Invalid YAML format in model config.")
    return parsed
def build_config(cfg_data: Dict[str, Any]) -> ModelConfig:
    """Build a ModelConfig from parsed YAML, layering overrides on a preset.

    If 'preset' names a known preset, the 'model' mapping overrides its
    fields; otherwise 'model' alone supplies the config kwargs.

    Raises:
        ValueError: if 'model' is not a mapping or 'preset' is unknown.
    """
    preset = cfg_data.get("preset")
    overrides = cfg_data.get("model", {})
    if not isinstance(overrides, dict):
        raise ValueError("Config key 'model' must be an object.")
    if not preset:
        # No preset requested: overrides alone define the config.
        return ModelConfig(**overrides)
    presets = get_model_presets()
    if preset not in presets:
        raise ValueError(f"Unknown preset '{preset}'.")
    # Preset fields first, then YAML overrides win on conflicts.
    merged = {**presets[preset].__dict__, **overrides}
    return ModelConfig(**merged)
def gpu_memory_report(stage: str) -> None:
    """Print current CUDA memory counters (in GiB), labelled with *stage*.

    Prints a placeholder line instead when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print(f"[{stage}] CUDA not available")
        return
    gib = 1024 ** 3
    alloc_gb = torch.cuda.memory_allocated() / gib
    reserved_gb = torch.cuda.memory_reserved() / gib
    peak_gb = torch.cuda.max_memory_allocated() / gib
    report = (
        f"[{stage}] VRAM allocated={alloc_gb:.2f} GB "
        f"reserved={reserved_gb:.2f} GB max_allocated={peak_gb:.2f} GB"
    )
    print(report)
def main() -> None:
    """Run the Component 4 verification: build the model, run one forward
    pass, and report VRAM usage at each stage.

    Exits with status 1 (after printing a diagnostic) on any failure.
    """
    args = parse_args()
    try:
        cfg_data = load_yaml(Path(args.config))
        model_cfg = build_config(cfg_data)
        # Fail fast before allocating anything on the GPU.
        if args.seq_len > model_cfg.max_seq_len:
            raise ValueError(
                f"seq_len={args.seq_len} exceeds max_seq_len={model_cfg.max_seq_len} in config."
            )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision is only applied on GPU (fp16 support on CPU is limited).
        use_fp16 = device.type == "cuda"
        if device.type == "cuda":
            # Reset counters so the per-stage reports reflect this run only.
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
            gpu_memory_report("start")
        else:
            print("CUDA not available. Running verification on CPU.")
        model = CodeTransformerLM(model_cfg)
        print(f"Model parameters: {model.estimate_num_parameters():,}")
        if use_fp16:
            model = model.half()
        model.to(device)
        model.eval()
        gpu_memory_report("after_model_load")
        # Random token ids are sufficient — we only verify shapes and memory.
        input_ids = torch.randint(
            low=0,
            high=model_cfg.vocab_size,
            size=(args.batch_size, args.seq_len),
            dtype=torch.long,
            device=device,
        )
        gpu_memory_report("after_input_alloc")
        with torch.no_grad():
            out = model(input_ids=input_ids)
            logits = out["logits"]
        gpu_memory_report("after_forward")
        print(f"Forward output shape: {tuple(logits.shape)}")
        print("Component 4 verification passed.")
    except Exception as exc:
        print("Component 4 verification failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: reduce seq_len or check CUDA/PyTorch installation.")
        # Chain the original exception explicitly so its traceback is
        # preserved (the bare `raise SystemExit(1)` left only implicit context).
        raise SystemExit(1) from exc
# Script entry point: run the verification when executed directly.
if __name__ == "__main__":
    main()