Lora-trainer / gpu_optimizer.py
Allex21's picture
Upload 24 files
5bb2330 verified
import torch
import gc
import logging
from typing import Dict, Any, Optional
import psutil
import os
class GPUOptimizer:
"""Serviço para otimizações de GPU e gerenciamento de memória"""
def __init__(self):
self.logger = logging.getLogger(__name__)
self.device = self._get_optimal_device()
self.memory_stats = {}
def _get_optimal_device(self) -> str:
"""Determina o melhor dispositivo disponível"""
if torch.cuda.is_available():
# Seleciona a GPU com mais memória livre
gpu_count = torch.cuda.device_count()
if gpu_count > 0:
best_gpu = 0
max_free_memory = 0
for i in range(gpu_count):
torch.cuda.set_device(i)
free_memory = torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i)
if free_memory > max_free_memory:
max_free_memory = free_memory
best_gpu = i
return f"cuda:{best_gpu}"
self.logger.warning("CUDA não disponível, usando CPU")
return "cpu"
def setup_memory_efficient_training(self, model, config: Dict[str, Any]):
"""Configura o modelo para treinamento eficiente em memória"""
optimizations_applied = []
# 1. Gradient Checkpointing
if config.get('use_gradient_checkpointing', True):
if hasattr(model, 'gradient_checkpointing_enable'):
model.gradient_checkpointing_enable()
optimizations_applied.append("gradient_checkpointing")
self.logger.info("Gradient checkpointing habilitado")
# 2. Mixed Precision
mixed_precision = config.get('mixed_precision', 'fp16')
if mixed_precision in ['fp16', 'bf16']:
optimizations_applied.append(f"mixed_precision_{mixed_precision}")
self.logger.info(f"Mixed precision configurado: {mixed_precision}")
# 3. Model Quantization (se suportado)
if config.get('use_8bit_quantization', False):
optimizations_applied.append("8bit_quantization")
self.logger.info("Quantização 8-bit habilitada")
return optimizations_applied
def get_8bit_optimizer(self, optimizer_class, model_parameters, **kwargs):
"""Retorna um otimizador 8-bit para economia de memória"""
try:
import bitsandbytes as bnb
# Mapeia otimizadores padrão para versões 8-bit
optimizer_mapping = {
'AdamW': bnb.optim.AdamW8bit,
'Adam': bnb.optim.Adam8bit,
'SGD': bnb.optim.SGD8bit,
}
optimizer_name = optimizer_class.__name__
if optimizer_name in optimizer_mapping:
optimizer_8bit = optimizer_mapping[optimizer_name]
self.logger.info(f"Usando otimizador 8-bit: {optimizer_name}")
return optimizer_8bit(model_parameters, **kwargs)
else:
self.logger.warning(f"Otimizador 8-bit não disponível para {optimizer_name}, usando versão padrão")
return optimizer_class(model_parameters, **kwargs)
except ImportError:
self.logger.warning("bitsandbytes não disponível, usando otimizador padrão")
return optimizer_class(model_parameters, **kwargs)
def optimize_batch_size(self, base_batch_size: int, model_size_mb: float) -> int:
"""Otimiza o tamanho do batch baseado na memória disponível"""
if self.device == "cpu":
# Para CPU, limita baseado na RAM disponível
available_ram_gb = psutil.virtual_memory().available / (1024**3)
# Heurística: 1GB de RAM por item do batch para modelos grandes
max_batch_size = max(1, int(available_ram_gb / 2))
return min(base_batch_size, max_batch_size)
# Para GPU
if torch.cuda.is_available():
device_idx = int(self.device.split(':')[1]) if ':' in self.device else 0
total_memory = torch.cuda.get_device_properties(device_idx).total_memory
allocated_memory = torch.cuda.memory_allocated(device_idx)
free_memory = total_memory - allocated_memory
# Heurística: reserva 20% da memória livre para overhead
usable_memory = free_memory * 0.8
# Estima quantos itens do batch cabem na memória
# Assume que cada item do batch usa aproximadamente model_size_mb * 3 (forward + backward + gradients)
memory_per_item = model_size_mb * 3 * 1024 * 1024 # Converte para bytes
max_batch_size = max(1, int(usable_memory / memory_per_item))
optimized_batch_size = min(base_batch_size, max_batch_size)
self.logger.info(f"Batch size otimizado: {base_batch_size} -> {optimized_batch_size}")
return optimized_batch_size
return base_batch_size
def clear_memory_cache(self):
"""Limpa cache de memória GPU e força garbage collection"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Force garbage collection
gc.collect()
self.logger.info("Cache de memória limpo")
def get_memory_usage(self) -> Dict[str, Any]:
"""Retorna estatísticas de uso de memória"""
stats = {
'device': self.device,
'cpu_memory': {
'total_gb': psutil.virtual_memory().total / (1024**3),
'available_gb': psutil.virtual_memory().available / (1024**3),
'used_percent': psutil.virtual_memory().percent
}
}
if torch.cuda.is_available() and self.device.startswith('cuda'):
device_idx = int(self.device.split(':')[1]) if ':' in self.device else 0
stats['gpu_memory'] = {
'device_name': torch.cuda.get_device_name(device_idx),
'total_mb': torch.cuda.get_device_properties(device_idx).total_memory / (1024**2),
'allocated_mb': torch.cuda.memory_allocated(device_idx) / (1024**2),
'reserved_mb': torch.cuda.memory_reserved(device_idx) / (1024**2),
'free_mb': (torch.cuda.get_device_properties(device_idx).total_memory -
torch.cuda.memory_reserved(device_idx)) / (1024**2)
}
else:
stats['gpu_memory'] = {
'device_name': 'CPU',
'total_mb': 0,
'allocated_mb': 0,
'reserved_mb': 0,
'free_mb': 0
}
return stats
def estimate_training_memory(self, model_params: int, batch_size: int, sequence_length: int = 512) -> Dict[str, float]:
"""Estima o uso de memória para treinamento"""
# Estimativas baseadas em heurísticas comuns
# Memória do modelo (parâmetros + gradientes + estados do otimizador)
param_memory_mb = model_params * 4 / (1024**2) # 4 bytes por parâmetro (float32)
gradient_memory_mb = param_memory_mb # Gradientes têm o mesmo tamanho dos parâmetros
optimizer_memory_mb = param_memory_mb * 2 # Adam mantém momentum e variance
# Memória de ativações (depende do batch size e sequence length)
activation_memory_mb = batch_size * sequence_length * 768 * 4 / (1024**2) # Estimativa para transformer
total_memory_mb = param_memory_mb + gradient_memory_mb + optimizer_memory_mb + activation_memory_mb
return {
'model_params_mb': param_memory_mb,
'gradients_mb': gradient_memory_mb,
'optimizer_states_mb': optimizer_memory_mb,
'activations_mb': activation_memory_mb,
'total_estimated_mb': total_memory_mb,
'total_estimated_gb': total_memory_mb / 1024
}
def suggest_optimizations(self, current_config: Dict[str, Any]) -> Dict[str, Any]:
"""Sugere otimizações baseadas no hardware disponível"""
suggestions = {}
memory_stats = self.get_memory_usage()
# Sugestões baseadas na memória GPU disponível
if 'gpu_memory' in memory_stats:
gpu_memory_gb = memory_stats['gpu_memory']['total_mb'] / 1024
if gpu_memory_gb < 4: # GPU com pouca memória
suggestions.update({
'use_8bit_optimizer': True,
'use_gradient_checkpointing': True,
'mixed_precision': 'fp16',
'batch_size': 1,
'suggested_rank': 4, # Rank baixo para LoRA
'reason': 'GPU com pouca memória detectada'
})
elif gpu_memory_gb < 8: # GPU média
suggestions.update({
'use_8bit_optimizer': True,
'use_gradient_checkpointing': True,
'mixed_precision': 'fp16',
'batch_size': 2,
'suggested_rank': 8,
'reason': 'GPU com memória média detectada'
})
else: # GPU com bastante memória
suggestions.update({
'use_8bit_optimizer': False,
'use_gradient_checkpointing': False,
'mixed_precision': 'fp16',
'batch_size': 4,
'suggested_rank': 16,
'reason': 'GPU com memória suficiente detectada'
})
else: # CPU only
suggestions.update({
'use_8bit_optimizer': True,
'use_gradient_checkpointing': True,
'mixed_precision': 'fp32', # CPU não suporta fp16 eficientemente
'batch_size': 1,
'suggested_rank': 4,
'reason': 'Treinamento em CPU detectado'
})
return suggestions
def monitor_memory_during_training(self) -> Dict[str, Any]:
"""Monitora o uso de memória durante o treinamento"""
current_stats = self.get_memory_usage()
# Detecta vazamentos de memória ou uso excessivo
warnings = []
if 'gpu_memory' in current_stats:
gpu_usage_percent = (current_stats['gpu_memory']['allocated_mb'] /
current_stats['gpu_memory']['total_mb']) * 100
if gpu_usage_percent > 90:
warnings.append("Uso de GPU muito alto (>90%), considere reduzir batch size")
elif gpu_usage_percent > 80:
warnings.append("Uso de GPU alto (>80%), monitore para possível OOM")
cpu_usage_percent = current_stats['cpu_memory']['used_percent']
if cpu_usage_percent > 90:
warnings.append("Uso de RAM muito alto (>90%)")
return {
'memory_stats': current_stats,
'warnings': warnings,
'timestamp': torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
}