""" CUDA optimizations for Vortex model on Nvidia 4060 laptop. Flash Attention 2, torch.compile, INT8 quantization. """ import torch import torch.nn as nn from typing import Optional, Dict, Any def optimize_for_cuda( model: nn.Module, config: Dict, use_flash_attention: bool = True, use_torch_compile: bool = True, compile_mode: str = "reduce-overhead", quantization: Optional[str] = None, ) -> nn.Module: """ Apply CUDA optimizations to model. Args: model: VortexModel config: Model config use_flash_attention: Enable Flash Attention 2 use_torch_compile: Use torch.compile compile_mode: Compile mode ("reduce-overhead", "max-autotune") quantization: None, "int8", or "int4" Returns: Optimized model """ device = torch.device("cuda") # Move to CUDA model = model.to(device) # Set dtype dtype_str = config.get("dtype", "bfloat16") if dtype_str == "bfloat16": dtype = torch.bfloat16 elif dtype_str == "float16": dtype = torch.float16 else: dtype = torch.float32 model = model.to(dtype) # Apply Flash Attention 2 to attention layers if use_flash_attention: model = _apply_flash_attention(model) print("Applied Flash Attention 2") # Apply torch.compile if use_torch_compile: model = torch.compile( model, mode=compile_mode, fullgraph=True, dynamic=True, ) print(f"Applied torch.compile with mode={compile_mode}") # Apply quantization if requested if quantization == "int8": model = _apply_int8_quantization(model) print("Applied INT8 quantization") elif quantization == "int4": model = _apply_int4_quantization(model) print("Applied INT4 quantization") return model def _apply_flash_attention(model: nn.Module) -> nn.Module: """ Replace standard attention with Flash Attention 2. Requires: pip install flash-attn """ try: from flash_attn import flash_attn_func # Monkey-patch attention layers to use flash attention for name, module in model.named_modules(): if hasattr(module, 'use_flash_attention'): module.use_flash_attention = True # Replace forward with flash attention version original_forward = module.forward def flash_forward(self, x, *args, **kwargs): return self._flash_attention_forward(x, *args, **kwargs) module.forward = flash_forward.__get__(module, type(module)) return model except ImportError: print("Flash Attention not available. Install with: pip install flash-attn") return model def _apply_int8_quantization(model: nn.Module) -> nn.Module: """ Apply INT8 quantization using bitsandbytes. """ try: import bitsandbytes as bnb # Replace linear layers with 8-bit variants for name, module in model.named_modules(): if isinstance(module, nn.Linear): # Create 8-bit linear replacement parent_name = name.rsplit('.', 1)[0] if '.' in name else '' child_name = name.rsplit('.', 1)[1] if '.' in name else name # Get parent module parent = model if parent_name: for part in parent_name.split('.'): parent = getattr(parent, part) # Replace with 8-bit linear replacement = bnb.nn.Linear8bitLt( module.in_features, module.out_features, bias=module.bias is not None, has_fp16_weights=False, ) # Copy weights (will be quantized) replacement.weight.data = module.weight.data if module.bias is not None: replacement.bias.data = module.bias.data setattr(parent, child_name, replacement) return model except ImportError: print("bitsandbytes not available. Install with: pip install bitsandbytes") return model def _apply_int4_quantization(model: nn.Module) -> nn.Module: """ Apply INT4 quantization using bitsandbytes. More aggressive, for 13B on 8GB VRAM. """ try: import bitsandbytes as bnb for name, module in model.named_modules(): if isinstance(module, nn.Linear): parent_name = name.rsplit('.', 1)[0] if '.' in name else '' child_name = name.rsplit('.', 1)[1] if '.' in name else name parent = model if parent_name: for part in parent_name.split('.'): parent = getattr(parent, part) # 4-bit linear replacement = bnb.nn.Linear4bit( module.in_features, module.out_features, bias=module.bias is not None, compute_dtype=torch.float16, compress_statistics=True, ) replacement.weight.data = module.weight.data if module.bias is not None: replacement.bias.data = module.bias.data setattr(parent, child_name, replacement) return model except ImportError: print("bitsandbytes not available.") return model def get_cuda_memory_usage() -> Dict[str, float]: """Get current CUDA memory usage in GB.""" if not torch.cuda.is_available(): return {"error": "CUDA not available"} allocated = torch.cuda.memory_allocated() / 1e9 reserved = torch.cuda.memory_reserved() / 1e9 max_allocated = torch.cuda.max_memory_allocated() / 1e9 return { "allocated_gb": allocated, "reserved_gb": reserved, "max_allocated_gb": max_allocated, } def profile_model( model: nn.Module, input_ids: torch.Tensor, num_warmup: int = 10, num_runs: int = 100, ) -> Dict[str, float]: """ Profile model performance. Args: model: Model to profile input_ids: Example input num_warmup: Number of warmup runs num_runs: Number of profiling runs Returns: Dictionary with timing statistics """ model.eval() device = next(model.parameters()).device input_ids = input_ids.to(device) # Warmup with torch.no_grad(): for _ in range(num_warmup): _ = model(input_ids) # Profile torch.cuda.synchronize() import time start = time.time() with torch.no_grad(): for _ in range(num_runs): _ = model(input_ids) torch.cuda.synchronize() elapsed = time.time() - start avg_time = elapsed / num_runs tokens_per_sec = input_ids.shape[1] / avg_time return { "avg_time_sec": avg_time, "tokens_per_sec": tokens_per_sec, } def test_cuda_optimize(): """Test CUDA optimizations.""" if not torch.cuda.is_available(): print("CUDA not available, skipping test") return from models.vortex_model import VortexModel from configs.vortex_7b_config import VORTEX_7B_CONFIG config = VORTEX_7B_CONFIG.copy() config["d_model"] = 512 config["num_layers"] = 2 config["num_heads"] = 8 config["vocab_size"] = 1000 model = VortexModel(config) print(f"Model parameters: {model.get_num_params():,}") # Optimize model = optimize_for_cuda( model, config, use_flash_attention=False, # May not be available use_torch_compile=False, # Skip compile for test quantization=None, ) # Test forward batch_size = 2 seq_len = 128 input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)).cuda() with torch.no_grad(): output = model(input_ids) logits = output["logits"] print(f"Output shape: {logits.shape}") print("CUDA optimize test passed!") if __name__ == "__main__": test_cuda_optimize()