| """ |
| Component 4 verification script. |
| |
| This script: |
| - Builds model from config. |
| - Runs a small forward pass. |
| - Prints live VRAM usage at each stage. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
| from typing import Any, Dict |
|
|
| import torch |
| import yaml |
|
|
| |
# Make the repository root importable so `src...` absolute imports resolve
# when this script is run directly (parents[1] assumes this file sits one
# directory below the project root — e.g. in a `scripts/` folder).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from src.model_architecture.code_transformer import ( |
| CodeTransformerLM, |
| ModelConfig, |
| get_model_presets, |
| ) |
|
|
|
|
def _positive_int(raw: str) -> int:
    """argparse type callback: parse *raw* as an int and reject values < 1."""
    value = int(raw)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {raw!r}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the verification run.

    Returns:
        Namespace with ``config`` (path to the model YAML config) and
        ``batch_size`` / ``seq_len`` (validated positive integers).
    """
    parser = argparse.ArgumentParser(description="Verify Component 4 model load and VRAM usage.")
    parser.add_argument(
        "--config",
        default="configs/component4_model_config.yaml",
        help="Path to model YAML config.",
    )
    # Zero or negative sizes would only fail deep inside torch.randint or the
    # forward pass; validate at the CLI boundary for a clear error message.
    parser.add_argument(
        "--batch_size", type=_positive_int, default=1, help="Batch size for forward test."
    )
    parser.add_argument(
        "--seq_len", type=_positive_int, default=256, help="Sequence length for forward test."
    )
    return parser.parse_args()
|
|
|
|
def load_yaml(path: Path) -> Dict[str, Any]:
    """Load a YAML mapping from *path*.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the parsed document is not a mapping.
    """
    if not path.exists():
        raise FileNotFoundError(f"Model config not found: {path}")
    parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
    if isinstance(parsed, dict):
        return parsed
    raise ValueError("Invalid YAML format in model config.")
|
|
|
|
def build_config(cfg_data: Dict[str, Any]) -> ModelConfig:
    """Build a ModelConfig from parsed YAML data.

    An optional ``preset`` key selects a named base config; keys under
    ``model`` override the preset's fields (or stand alone if no preset).
    """
    preset_name = cfg_data.get("preset")
    overrides = cfg_data.get("model", {})
    if not isinstance(overrides, dict):
        raise ValueError("Config key 'model' must be an object.")

    if not preset_name:
        return ModelConfig(**overrides)

    presets = get_model_presets()
    if preset_name not in presets:
        raise ValueError(f"Unknown preset '{preset_name}'.")
    # Start from the preset's fields, then layer user overrides on top.
    settings = dict(presets[preset_name].__dict__)
    settings.update(overrides)
    return ModelConfig(**settings)
|
|
|
|
def gpu_memory_report(stage: str) -> None:
    """Print current CUDA memory statistics, labelled with *stage*."""
    if not torch.cuda.is_available():
        print(f"[{stage}] CUDA not available")
        return
    gib = 1024 ** 3
    stats = (
        ("allocated", torch.cuda.memory_allocated()),
        ("reserved", torch.cuda.memory_reserved()),
        ("max_allocated", torch.cuda.max_memory_allocated()),
    )
    detail = " ".join(f"{name}={nbytes / gib:.2f} GB" for name, nbytes in stats)
    print(f"[{stage}] VRAM {detail}")
|
|
|
|
def main() -> None:
    """Run the Component 4 verification: build the model from config, run one
    forward pass, and print VRAM usage at each stage."""
    args = parse_args()
    try:
        # Validate config before touching the GPU so config errors fail fast.
        cfg_data = load_yaml(Path(args.config))
        model_cfg = build_config(cfg_data)
        if args.seq_len > model_cfg.max_seq_len:
            raise ValueError(
                f"seq_len={args.seq_len} exceeds max_seq_len={model_cfg.max_seq_len} in config."
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision only on CUDA; fp16 on CPU is slow and poorly supported.
        use_fp16 = device.type == "cuda"
        if device.type == "cuda":
            # Reset allocator state so the per-stage reports reflect only
            # this run's allocations.
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
            gpu_memory_report("start")
        else:
            print("CUDA not available. Running verification on CPU.")

        model = CodeTransformerLM(model_cfg)
        print(f"Model parameters: {model.estimate_num_parameters():,}")

        if use_fp16:
            model = model.half()
        model.to(device)
        model.eval()
        gpu_memory_report("after_model_load")

        # Random token ids suffice for a smoke test; they only need to be
        # valid indices in [0, vocab_size).
        input_ids = torch.randint(
            low=0,
            high=model_cfg.vocab_size,
            size=(args.batch_size, args.seq_len),
            dtype=torch.long,
            device=device,
        )
        gpu_memory_report("after_input_alloc")

        with torch.no_grad():
            out = model(input_ids=input_ids)
            # NOTE(review): assumes the model returns a dict-like with a
            # "logits" entry — confirm against CodeTransformerLM.forward.
            logits = out["logits"]
        gpu_memory_report("after_forward")

        print(f"Forward output shape: {tuple(logits.shape)}")
        print("Component 4 verification passed.")
    except Exception as exc:
        # Top-level script boundary: summarize the failure and exit non-zero
        # instead of dumping a raw traceback at the user.
        print("Component 4 verification failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: reduce seq_len or check CUDA/PyTorch installation.")
        raise SystemExit(1)
|
|
|
|
# Script entry point; exits with status 1 on verification failure.
if __name__ == "__main__":
    main()
|
|
|
|