"""
CUDA optimizations for the Vortex model on an Nvidia 4060 laptop GPU.

Provides Flash Attention 2 patching, torch.compile wrapping, and
INT8/INT4 quantization via bitsandbytes.
"""
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from typing import Optional, Dict, Any
|
| |
|
| |
|
def optimize_for_cuda(
    model: nn.Module,
    config: Dict,
    use_flash_attention: bool = True,
    use_torch_compile: bool = True,
    compile_mode: str = "reduce-overhead",
    quantization: Optional[str] = None,
) -> nn.Module:
    """
    Apply CUDA optimizations to model.

    Args:
        model: VortexModel
        config: Model config; ``config["dtype"]`` selects the compute dtype
            ("bfloat16" default, "float16", anything else -> float32)
        use_flash_attention: Enable Flash Attention 2
        use_torch_compile: Use torch.compile
        compile_mode: Compile mode ("reduce-overhead", "max-autotune")
        quantization: None, "int8", or "int4"

    Returns:
        Optimized model

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Fail fast with a clear message rather than a cryptic error from .to().
    if not torch.cuda.is_available():
        raise RuntimeError("optimize_for_cuda requires an available CUDA device")

    device = torch.device("cuda")
    model = model.to(device)

    # Cast to the configured compute dtype (default bfloat16).
    dtype_str = config.get("dtype", "bfloat16")
    if dtype_str == "bfloat16":
        dtype = torch.bfloat16
    elif dtype_str == "float16":
        dtype = torch.float16
    else:
        dtype = torch.float32

    model = model.to(dtype)

    if use_flash_attention:
        model = _apply_flash_attention(model)
        print("Applied Flash Attention 2")

    # Quantize BEFORE torch.compile: swapping Linear submodules inside an
    # already-compiled OptimizedModule would invalidate the compiled graph
    # and force recompilation (or silently miss the replaced modules).
    if quantization == "int8":
        model = _apply_int8_quantization(model)
        print("Applied INT8 quantization")
    elif quantization == "int4":
        model = _apply_int4_quantization(model)
        print("Applied INT4 quantization")

    if use_torch_compile:
        model = torch.compile(
            model,
            mode=compile_mode,
            fullgraph=True,
            dynamic=True,
        )
        print(f"Applied torch.compile with mode={compile_mode}")

    return model
|
| |
|
| |
|
| | def _apply_flash_attention(model: nn.Module) -> nn.Module:
|
| | """
|
| | Replace standard attention with Flash Attention 2.
|
| | Requires: pip install flash-attn
|
| | """
|
| | try:
|
| | from flash_attn import flash_attn_func
|
| |
|
| |
|
| | for name, module in model.named_modules():
|
| | if hasattr(module, 'use_flash_attention'):
|
| | module.use_flash_attention = True
|
| |
|
| | original_forward = module.forward
|
| |
|
| | def flash_forward(self, x, *args, **kwargs):
|
| | return self._flash_attention_forward(x, *args, **kwargs)
|
| |
|
| | module.forward = flash_forward.__get__(module, type(module))
|
| |
|
| | return model
|
| |
|
| | except ImportError:
|
| | print("Flash Attention not available. Install with: pip install flash-attn")
|
| | return model
|
| |
|
| |
|
| | def _apply_int8_quantization(model: nn.Module) -> nn.Module:
|
| | """
|
| | Apply INT8 quantization using bitsandbytes.
|
| | """
|
| | try:
|
| | import bitsandbytes as bnb
|
| |
|
| |
|
| | for name, module in model.named_modules():
|
| | if isinstance(module, nn.Linear):
|
| |
|
| | parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
|
| | child_name = name.rsplit('.', 1)[1] if '.' in name else name
|
| |
|
| |
|
| | parent = model
|
| | if parent_name:
|
| | for part in parent_name.split('.'):
|
| | parent = getattr(parent, part)
|
| |
|
| |
|
| | replacement = bnb.nn.Linear8bitLt(
|
| | module.in_features,
|
| | module.out_features,
|
| | bias=module.bias is not None,
|
| | has_fp16_weights=False,
|
| | )
|
| |
|
| | replacement.weight.data = module.weight.data
|
| | if module.bias is not None:
|
| | replacement.bias.data = module.bias.data
|
| |
|
| | setattr(parent, child_name, replacement)
|
| |
|
| | return model
|
| |
|
| | except ImportError:
|
| | print("bitsandbytes not available. Install with: pip install bitsandbytes")
|
| | return model
|
| |
|
| |
|
| | def _apply_int4_quantization(model: nn.Module) -> nn.Module:
|
| | """
|
| | Apply INT4 quantization using bitsandbytes.
|
| | More aggressive, for 13B on 8GB VRAM.
|
| | """
|
| | try:
|
| | import bitsandbytes as bnb
|
| |
|
| | for name, module in model.named_modules():
|
| | if isinstance(module, nn.Linear):
|
| | parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
|
| | child_name = name.rsplit('.', 1)[1] if '.' in name else name
|
| |
|
| | parent = model
|
| | if parent_name:
|
| | for part in parent_name.split('.'):
|
| | parent = getattr(parent, part)
|
| |
|
| |
|
| | replacement = bnb.nn.Linear4bit(
|
| | module.in_features,
|
| | module.out_features,
|
| | bias=module.bias is not None,
|
| | compute_dtype=torch.float16,
|
| | compress_statistics=True,
|
| | )
|
| | replacement.weight.data = module.weight.data
|
| | if module.bias is not None:
|
| | replacement.bias.data = module.bias.data
|
| |
|
| | setattr(parent, child_name, replacement)
|
| |
|
| | return model
|
| |
|
| | except ImportError:
|
| | print("bitsandbytes not available.")
|
| | return model
|
| |
|
| |
|
def get_cuda_memory_usage() -> Dict[str, float]:
    """Report current, reserved, and peak CUDA memory usage in GB."""
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    gb = 1e9
    return {
        "allocated_gb": torch.cuda.memory_allocated() / gb,
        "reserved_gb": torch.cuda.memory_reserved() / gb,
        "max_allocated_gb": torch.cuda.max_memory_allocated() / gb,
    }
|
| |
|
| |
|
def profile_model(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 100,
) -> Dict[str, float]:
    """
    Profile model forward-pass performance.

    Args:
        model: Model to profile
        input_ids: Example input; assumed shape (batch, seq_len) — TODO confirm
        num_warmup: Number of warmup runs (excluded from timing)
        num_runs: Number of timed runs

    Returns:
        Dictionary with timing statistics:
          avg_time_sec: mean wall-clock time per forward pass
          tokens_per_sec: seq_len / avg_time_sec
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    # Only synchronize on CUDA: calling torch.cuda.synchronize() when the
    # model lives on CPU (or CUDA is unavailable) raises.
    is_cuda = device.type == "cuda"

    # Warmup: trigger lazy init / compilation outside the timed region.
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)

    if is_cuda:
        torch.cuda.synchronize()

    import time
    # perf_counter is monotonic and higher-resolution than time.time().
    start = time.perf_counter()

    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)

    if is_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    avg_time = elapsed / num_runs
    # NOTE(review): counts seq_len tokens per run, ignoring batch size —
    # per-sequence throughput; confirm this is the intended metric.
    tokens_per_sec = input_ids.shape[1] / avg_time

    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
|
| |
|
| |
|
def test_cuda_optimize():
    """Smoke-test optimize_for_cuda on a tiny VortexModel (skips without CUDA)."""
    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config to a toy size so the test runs quickly.
    config = {
        **VORTEX_7B_CONFIG,
        "d_model": 512,
        "num_layers": 2,
        "num_heads": 8,
        "vocab_size": 1000,
    }

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    # All optimizations disabled — this exercises only the device/dtype move.
    model = optimize_for_cuda(
        model,
        config,
        use_flash_attention=False,
        use_torch_compile=False,
        quantization=None,
    )

    # Forward pass on a small random batch (batch=2, seq_len=128).
    input_ids = torch.randint(0, config["vocab_size"], (2, 128)).cuda()

    with torch.no_grad():
        logits = model(input_ids)["logits"]

    print(f"Output shape: {logits.shape}")
    print("CUDA optimize test passed!")
|
| |
|
| |
|
# Run the smoke test when executed as a script.
if __name__ == "__main__":
    test_cuda_optimize()
|
| |
|