File size: 8,497 Bytes
5c43f61 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 | """
CUDA optimizations for the Vortex model on an NVIDIA 4060 laptop GPU.
Flash Attention 2, torch.compile, and INT8/INT4 quantization.
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
def optimize_for_cuda(
    model: nn.Module,
    config: Dict,
    use_flash_attention: bool = True,
    use_torch_compile: bool = True,
    compile_mode: str = "reduce-overhead",
    quantization: Optional[str] = None,
) -> nn.Module:
    """
    Apply CUDA optimizations to model.

    The steps run in this order:
      1. cast to the dtype named in ``config["dtype"]``,
      2. patch attention layers for Flash Attention 2 (optional),
      3. swap linear layers for quantized ones (optional),
      4. move the model to the GPU,
      5. wrap the model with torch.compile (optional).

    Casting and quantizing happen *before* the GPU transfer so the
    full-precision weights never have to fit in VRAM, and because the
    bitsandbytes documentation states that its layers quantize their weights
    when they are moved to the CUDA device. torch.compile runs last so that it
    traces the final module tree instead of one that is modified afterwards.

    Args:
        model: VortexModel
        config: Model config; ``config["dtype"]`` may be "bfloat16" (default)
            or "float16", any other value selects float32
        use_flash_attention: Enable Flash Attention 2
        use_torch_compile: Use torch.compile
        compile_mode: Compile mode ("reduce-overhead", "max-autotune")
        quantization: None, "int8", or "int4"; any other value is reported
            and treated as None

    Returns:
        Optimized model

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Local import keeps this block self-contained; used only to decide
    # whether the "Applied ..." messages below are truthful.
    from importlib.util import find_spec

    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available; optimize_for_cuda requires a CUDA device"
        )

    if quantization not in (None, "int8", "int4"):
        print(f"Unknown quantization '{quantization}', skipping quantization")
        quantization = None

    device = torch.device("cuda")

    # Resolve dtype (anything unrecognised falls back to float32)
    known_dtypes = {"bfloat16": torch.bfloat16, "float16": torch.float16}
    dtype = known_dtypes.get(config.get("dtype", "bfloat16"), torch.float32)

    # Cast on the current device first: moving full-precision weights to the
    # GPU and casting there would briefly need more VRAM than the final model.
    model = model.to(dtype)

    # Apply Flash Attention 2 to attention layers
    if use_flash_attention:
        model = _apply_flash_attention(model)
        # The helper prints its own notice and returns the model unchanged
        # when flash-attn is missing, so only claim success if it is present.
        if find_spec("flash_attn") is not None:
            print("Applied Flash Attention 2")

    # Apply quantization if requested (before the GPU move, see docstring)
    if quantization is not None:
        if quantization == "int8":
            model = _apply_int8_quantization(model)
        else:
            model = _apply_int4_quantization(model)
        if find_spec("bitsandbytes") is not None:
            print(f"Applied {quantization.upper()} quantization")

    # Move to CUDA
    model = model.to(device)

    # Apply torch.compile last so it sees the final (patched/quantized) modules
    if use_torch_compile:
        model = torch.compile(
            model,
            mode=compile_mode,
            fullgraph=True,
            dynamic=True,
        )
        print(f"Applied torch.compile with mode={compile_mode}")

    return model
def _apply_flash_attention(model: nn.Module) -> nn.Module:
    """
    Replace standard attention with Flash Attention 2.
    Requires: pip install flash-attn

    Every submodule that exposes a ``use_flash_attention`` attribute has that
    flag set to True. If the submodule also provides a callable
    ``_flash_attention_forward``, its ``forward`` is rebound to delegate to it;
    submodules without that method keep their original ``forward`` so they do
    not fail with AttributeError on the next call.

    Args:
        model: Model whose attention layers should be patched (modified in place)

    Returns:
        The same model instance; unchanged if flash-attn cannot be imported
    """
    try:
        # Imported purely as an availability probe; the attention modules are
        # expected to call into flash-attn themselves.
        from flash_attn import flash_attn_func  # noqa: F401
    except ImportError:
        print("Flash Attention not available. Install with: pip install flash-attn")
        return model

    def flash_forward(self, x, *args, **kwargs):
        return self._flash_attention_forward(x, *args, **kwargs)

    # Monkey-patch attention layers to use flash attention
    for _, module in model.named_modules():
        if not hasattr(module, "use_flash_attention"):
            continue
        module.use_flash_attention = True
        # Only redirect forward when there is something to redirect to
        if callable(getattr(module, "_flash_attention_forward", None)):
            module.forward = flash_forward.__get__(module, type(module))

    return model
def _apply_int8_quantization(model: nn.Module) -> nn.Module:
    """
    Apply INT8 quantization using bitsandbytes.

    Each ``nn.Linear`` in the model is replaced by a
    ``bitsandbytes.nn.Linear8bitLt`` carrying the same weights and bias.
    Per the bitsandbytes documentation, the actual 8-bit quantization happens
    when the layer is moved to the CUDA device, so the weights are staged on
    the CPU and the new layer is then moved to wherever the original layer
    lived. For a model still on the CPU, quantization is therefore deferred
    until the caller moves it to the GPU.

    Args:
        model: Model to quantize (submodules are replaced in place)

    Returns:
        The model with linear layers replaced. If ``model`` itself is an
        ``nn.Linear``, the replacement layer is returned instead. If
        bitsandbytes is not installed, the model is returned unchanged.
    """
    try:
        import bitsandbytes as bnb
    except ImportError:
        print("bitsandbytes not available. Install with: pip install bitsandbytes")
        return model

    # Snapshot the targets first so the module tree is not mutated while
    # named_modules() is still walking it. Linear8bitLt is skipped so calling
    # this twice does not re-wrap layers that are already quantized.
    linear_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and not isinstance(module, bnb.nn.Linear8bitLt)
    ]

    for name, module in linear_layers:
        # Create 8-bit linear replacement
        replacement = bnb.nn.Linear8bitLt(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            has_fp16_weights=False,
        )

        # Stage weights on the CPU, then move the layer to the original
        # device; the CPU -> CUDA transfer is what triggers quantization.
        target_device = module.weight.device
        replacement.weight.data = module.weight.data.to("cpu")
        if module.bias is not None:
            replacement.bias.data = module.bias.data.to("cpu")
        replacement = replacement.to(target_device)

        if not name:
            # The model itself is the linear layer: there is no parent to
            # attach to, so hand back the replacement directly.
            return replacement

        parent_name, _, child_name = name.rpartition(".")
        # get_submodule("") returns the model itself (top-level children)
        setattr(model.get_submodule(parent_name), child_name, replacement)

    return model
def _apply_int4_quantization(model: nn.Module) -> nn.Module:
    """
    Apply INT4 quantization using bitsandbytes.
    More aggressive, for 13B on 8GB VRAM.

    Each ``nn.Linear`` in the model is replaced by a
    ``bitsandbytes.nn.Linear4bit`` (float16 compute, compressed statistics)
    carrying the same weights and bias. Per the bitsandbytes documentation,
    the 4-bit quantization happens when the layer is moved to the CUDA device,
    so the weights are staged on the CPU and the new layer is then moved to
    wherever the original layer lived. For a model still on the CPU,
    quantization is therefore deferred until the caller moves it to the GPU.

    Args:
        model: Model to quantize (submodules are replaced in place)

    Returns:
        The model with linear layers replaced. If ``model`` itself is an
        ``nn.Linear``, the replacement layer is returned instead. If
        bitsandbytes is not installed, the model is returned unchanged.
    """
    try:
        import bitsandbytes as bnb
    except ImportError:
        print("bitsandbytes not available.")
        return model

    # Snapshot the targets first so the module tree is not mutated while
    # named_modules() is still walking it. Linear4bit is skipped so calling
    # this twice does not re-wrap layers that are already quantized.
    linear_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and not isinstance(module, bnb.nn.Linear4bit)
    ]

    for name, module in linear_layers:
        # 4-bit linear
        replacement = bnb.nn.Linear4bit(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            compute_dtype=torch.float16,
            compress_statistics=True,
        )

        # Stage weights on the CPU, then move the layer to the original
        # device; the CPU -> CUDA transfer is what triggers quantization.
        target_device = module.weight.device
        replacement.weight.data = module.weight.data.to("cpu")
        if module.bias is not None:
            replacement.bias.data = module.bias.data.to("cpu")
        replacement = replacement.to(target_device)

        if not name:
            # The model itself is the linear layer: there is no parent to
            # attach to, so hand back the replacement directly.
            return replacement

        parent_name, _, child_name = name.rpartition(".")
        # get_submodule("") returns the model itself (top-level children)
        setattr(model.get_submodule(parent_name), child_name, replacement)

    return model
def get_cuda_memory_usage() -> Dict[str, float]:
    """
    Report CUDA memory statistics in GB (bytes divided by 1e9).

    Returns:
        Dictionary with "allocated_gb", "reserved_gb" and "max_allocated_gb",
        or a single "error" entry when CUDA is not available
    """
    if torch.cuda.is_available():
        bytes_per_gb = 1e9
        return {
            "allocated_gb": torch.cuda.memory_allocated() / bytes_per_gb,
            "reserved_gb": torch.cuda.memory_reserved() / bytes_per_gb,
            "max_allocated_gb": torch.cuda.max_memory_allocated() / bytes_per_gb,
        }
    return {"error": "CUDA not available"}
def profile_model(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 100,
) -> Dict[str, float]:
    """
    Profile model performance.

    The model is switched to eval mode (and left there), the input is moved to
    the model's device, and forward passes are timed under ``torch.no_grad()``.
    CUDA synchronization is only performed when the model lives on a CUDA
    device, so CPU models can be profiled as well.

    Args:
        model: Model to profile
        input_ids: Example input
        num_warmup: Number of warmup runs
        num_runs: Number of profiling runs (must be positive)

    Returns:
        Dictionary with timing statistics:
        "avg_time_sec" is the mean wall-clock time of one forward pass and
        "tokens_per_sec" is the number of elements in ``input_ids`` (all
        batch rows counted) divided by that time

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    import time

    if num_runs <= 0:
        raise ValueError("num_runs must be a positive integer")

    model.eval()

    # Fall back to the input's device for models without parameters
    first_param = next(model.parameters(), None)
    device = input_ids.device if first_param is None else first_param.device
    input_ids = input_ids.to(device)

    def _wait_for_device() -> None:
        # CUDA kernels run asynchronously; wait so the timer sees finished work
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    with torch.no_grad():
        # Warmup
        for _ in range(num_warmup):
            model(input_ids)

        # Profile (make sure warmup work is not billed to the timed region)
        _wait_for_device()
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start = time.perf_counter()
        for _ in range(num_runs):
            model(input_ids)
        _wait_for_device()
        elapsed = time.perf_counter() - start

    avg_time = elapsed / num_runs
    # Every element of input_ids is one token processed per forward pass
    tokens_per_sec = input_ids.numel() / avg_time

    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
def test_cuda_optimize():
    """
    Smoke-test optimize_for_cuda on a shrunken Vortex configuration.

    Builds a small VortexModel, runs it through optimize_for_cuda with flash
    attention, torch.compile and quantization all disabled, and performs one
    forward pass on random token ids. Does nothing beyond printing a notice
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the model builds and runs quickly
    config = VORTEX_7B_CONFIG.copy()
    small_overrides = {
        "d_model": 512,
        "num_layers": 2,
        "num_heads": 8,
        "vocab_size": 1000,
    }
    for key, value in small_overrides.items():
        config[key] = value

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Optimize
    model = optimize_for_cuda(
        model,
        config,
        use_flash_attention=False,  # May not be available
        use_torch_compile=False,  # Skip compile for test
        quantization=None,
    )

    # Test forward
    batch_size, seq_len = 2, 128
    token_shape = (batch_size, seq_len)
    input_ids = torch.randint(0, config["vocab_size"], token_shape).cuda()

    with torch.no_grad():
        output = model(input_ids)

    logits = output["logits"]
    print(f"Output shape: {logits.shape}")
    print("CUDA optimize test passed!")
# Run the CUDA smoke test when this file is executed as a script.
if __name__ == "__main__":
    test_cuda_optimize()
|