# Vortex-7b-V1 / cuda_optimize.py
# Uploaded by Zandy-Wandy ("Upload Vortex model", commit bf64b03, verified)
"""
CUDA optimizations for Vortex model on Nvidia 4060 laptop.
Flash Attention 2, torch.compile, INT8 quantization.
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
def optimize_for_cuda(
    model: nn.Module,
    config: Dict,
    use_flash_attention: bool = True,
    use_torch_compile: bool = True,
    compile_mode: str = "reduce-overhead",
    quantization: Optional[str] = None,
) -> nn.Module:
    """
    Apply CUDA optimizations to model.

    Args:
        model: VortexModel
        config: Model config; reads "dtype" ("bfloat16" | "float16",
            anything else falls back to float32)
        use_flash_attention: Enable Flash Attention 2
        use_torch_compile: Use torch.compile
        compile_mode: Compile mode ("reduce-overhead", "max-autotune")
        quantization: None, "int8", or "int4"

    Returns:
        Optimized model
    """
    device = torch.device("cuda")
    # Move to CUDA
    model = model.to(device)

    # Map the configured dtype string to a torch dtype (default bfloat16,
    # unknown strings fall back to float32 — same behavior as before).
    dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }.get(config.get("dtype", "bfloat16"), torch.float32)
    model = model.to(dtype)

    # Apply Flash Attention 2 to attention layers
    if use_flash_attention:
        model = _apply_flash_attention(model)
        print("Applied Flash Attention 2")

    # Quantize BEFORE torch.compile (bug fix): replacing submodules after
    # compilation mutates the wrapped module behind the compiled graph, so
    # the quantized layers would not reliably take effect.
    if quantization == "int8":
        model = _apply_int8_quantization(model)
        print("Applied INT8 quantization")
    elif quantization == "int4":
        model = _apply_int4_quantization(model)
        print("Applied INT4 quantization")

    # Apply torch.compile last, once the module tree is final.
    if use_torch_compile:
        model = torch.compile(
            model,
            mode=compile_mode,
            fullgraph=True,
            dynamic=True,
        )
        print(f"Applied torch.compile with mode={compile_mode}")

    return model
def _apply_flash_attention(model: nn.Module) -> nn.Module:
"""
Replace standard attention with Flash Attention 2.
Requires: pip install flash-attn
"""
try:
from flash_attn import flash_attn_func
# Monkey-patch attention layers to use flash attention
for name, module in model.named_modules():
if hasattr(module, 'use_flash_attention'):
module.use_flash_attention = True
# Replace forward with flash attention version
original_forward = module.forward
def flash_forward(self, x, *args, **kwargs):
return self._flash_attention_forward(x, *args, **kwargs)
module.forward = flash_forward.__get__(module, type(module))
return model
except ImportError:
print("Flash Attention not available. Install with: pip install flash-attn")
return model
def _apply_int8_quantization(model: nn.Module) -> nn.Module:
"""
Apply INT8 quantization using bitsandbytes.
"""
try:
import bitsandbytes as bnb
# Replace linear layers with 8-bit variants
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# Create 8-bit linear replacement
parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
child_name = name.rsplit('.', 1)[1] if '.' in name else name
# Get parent module
parent = model
if parent_name:
for part in parent_name.split('.'):
parent = getattr(parent, part)
# Replace with 8-bit linear
replacement = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
bias=module.bias is not None,
has_fp16_weights=False,
)
# Copy weights (will be quantized)
replacement.weight.data = module.weight.data
if module.bias is not None:
replacement.bias.data = module.bias.data
setattr(parent, child_name, replacement)
return model
except ImportError:
print("bitsandbytes not available. Install with: pip install bitsandbytes")
return model
def _apply_int4_quantization(model: nn.Module) -> nn.Module:
"""
Apply INT4 quantization using bitsandbytes.
More aggressive, for 13B on 8GB VRAM.
"""
try:
import bitsandbytes as bnb
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
child_name = name.rsplit('.', 1)[1] if '.' in name else name
parent = model
if parent_name:
for part in parent_name.split('.'):
parent = getattr(parent, part)
# 4-bit linear
replacement = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
bias=module.bias is not None,
compute_dtype=torch.float16,
compress_statistics=True,
)
replacement.weight.data = module.weight.data
if module.bias is not None:
replacement.bias.data = module.bias.data
setattr(parent, child_name, replacement)
return model
except ImportError:
print("bitsandbytes not available.")
return model
def get_cuda_memory_usage() -> Dict[str, float]:
    """Get current CUDA memory usage in GB."""
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}
    # Decimal gigabytes (1e9 bytes), matching the _gb key suffix.
    bytes_per_gb = 1e9
    return {
        "allocated_gb": torch.cuda.memory_allocated() / bytes_per_gb,
        "reserved_gb": torch.cuda.memory_reserved() / bytes_per_gb,
        "max_allocated_gb": torch.cuda.max_memory_allocated() / bytes_per_gb,
    }
def profile_model(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 100,
) -> Dict[str, float]:
    """
    Profile model performance.

    Args:
        model: Model to profile
        input_ids: Example input, presumably (batch, seq_len) — moved to
            the model's device before timing
        num_warmup: Number of warmup runs (excluded from timing)
        num_runs: Number of profiling runs

    Returns:
        Dictionary with "avg_time_sec" and "tokens_per_sec"
    """
    import time

    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    # Bug fix: only synchronize on CUDA — torch.cuda.synchronize() raises
    # for CPU-resident models.
    sync = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)

    # Warmup (triggers lazy init / autotuning before anything is timed)
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)
    sync()

    # Timed runs; perf_counter is monotonic and higher-resolution than time().
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)
    sync()
    elapsed = time.perf_counter() - start

    avg_time = elapsed / num_runs
    # Bug fix: throughput counts every token in the batch; the original
    # divided only seq_len by time, ignoring the batch dimension.
    tokens_per_sec = input_ids.numel() / avg_time
    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
def test_cuda_optimize():
    """Test CUDA optimizations."""
    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the smoke test runs quickly and fits in VRAM.
    config = dict(VORTEX_7B_CONFIG)
    config["d_model"] = 512
    config["num_layers"] = 2
    config["num_heads"] = 8
    config["vocab_size"] = 1000

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    # Optimize with the heavyweight options disabled for the smoke test.
    model = optimize_for_cuda(
        model,
        config,
        use_flash_attention=False,  # May not be available
        use_torch_compile=False,  # Skip compile for test
        quantization=None,
    )

    # One forward pass to confirm the optimized model still runs.
    batch_size, seq_len = 2, 128
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)).cuda()
    with torch.no_grad():
        logits = model(input_ids)["logits"]
    print(f"Output shape: {logits.shape}")
    print("CUDA optimize test passed!")
# Run the CUDA optimization smoke test when executed as a script.
if __name__ == "__main__":
    test_cuda_optimize()