import os
import logging
import torch
import time
from datetime import datetime
from typing import Optional, Dict, Any
from pathlib import Path

from src.services.gpu_optimizer import GPUOptimizer


class LoRATrainer:
    """LoRA training service with GPU optimizations.

    Orchestrates a (simulated) LoRA fine-tuning run for a ``LoRAProject``
    database row: per-run file logging, output directories, GPU memory
    suggestions from :class:`GPUOptimizer`, a training loop that writes
    progress back to the database, and model saving.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.gpu_optimizer = GPUOptimizer()
        self.device = self.gpu_optimizer.device
        self.logger.info(f"LoRA Trainer initialized with device: {self.device}")

    def train_project(self, project_id: int):
        """Train a LoRA project with optimizations.

        Loads the project row, prepares logging/output directories, applies
        GPU optimization suggestions, runs the training pipeline, and keeps
        the row's status/progress columns in sync throughout.

        Args:
            project_id: Primary key of the ``LoRAProject`` row to train.

        Raises:
            ValueError: If no project exists with ``project_id``.
            Exception: Any training failure is logged, recorded on the
                project row (when the row was loaded), and re-raised.
        """
        from src.models.lora_project import LoRAProject, TrainingStatus, db

        # BUG FIX: bind these before the try so the except/finally blocks can
        # reference them safely. Previously, a failed project lookup made the
        # except block raise UnboundLocalError/AttributeError on `project`,
        # masking the original error.
        project = None
        file_handler = None
        try:
            # Get project from database
            project = LoRAProject.query.get(project_id)
            if not project:
                raise ValueError(f"Project {project_id} not found")

            # Update status to running
            project.status = TrainingStatus.RUNNING
            project.started_at = datetime.utcnow()
            db.session.commit()

            # Setup per-run log file
            log_dir = Path(f"logs/project_{project_id}")
            log_dir.mkdir(parents=True, exist_ok=True)
            log_file = log_dir / f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
            project.log_file = str(log_file)

            # Setup output directory
            output_dir = Path(f"outputs/project_{project_id}")
            output_dir.mkdir(parents=True, exist_ok=True)
            project.output_path = str(output_dir)
            db.session.commit()

            # Configure file logging for this run only (detached in finally)
            file_handler = logging.FileHandler(log_file)
            file_handler.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

            self.logger.info(f"Starting LoRA training for project: {project.name}")

            # Get optimization suggestions from the GPU optimizer
            config = {
                'use_8bit_optimizer': project.use_8bit_optimizer,
                'use_gradient_checkpointing': project.use_gradient_checkpointing,
                'mixed_precision': project.mixed_precision,
                'batch_size': project.batch_size,
                'rank': project.rank
            }
            suggestions = self.gpu_optimizer.suggest_optimizations(config)
            self.logger.info(f"GPU Optimization suggestions: {suggestions}")

            # Apply optimizations if they differ from current config
            if suggestions.get('batch_size', project.batch_size) != project.batch_size:
                old_batch_size = project.batch_size
                project.batch_size = suggestions['batch_size']
                self.logger.info(f"Batch size optimized: {old_batch_size} -> {project.batch_size}")
                db.session.commit()

            # Log initial memory usage
            memory_stats = self.gpu_optimizer.get_memory_usage()
            self.logger.info(f"Initial memory usage: {memory_stats}")

            # Training pipeline: load model, prepare data, configure LoRA,
            # run the training loop, then persist the trained weights.
            self._load_base_model(project)
            self._prepare_dataset(project)
            self._setup_lora_optimized(project)
            self._train_model_optimized(project)
            self._save_model(project)

            # Mark the project completed
            project.status = TrainingStatus.COMPLETED
            project.completed_at = datetime.utcnow()
            project.progress = 1.0
            db.session.commit()

            # Log final memory usage
            final_memory_stats = self.gpu_optimizer.get_memory_usage()
            self.logger.info(f"Final memory usage: {final_memory_stats}")
            self.logger.info("Training completed successfully")

        except Exception as e:
            self.logger.error(f"Training failed: {str(e)}")
            # BUG FIX: only touch the DB row if we actually loaded it.
            if project is not None:
                project.status = TrainingStatus.FAILED
                project.error_message = str(e)
                project.completed_at = datetime.utcnow()
                db.session.commit()
            raise
        finally:
            # BUG FIX: detach and close the per-run file handler. The original
            # left it attached, so repeated train_project() calls accumulated
            # handlers, duplicating all subsequent log output into every
            # previous run's file and leaking file descriptors.
            if file_handler is not None:
                self.logger.removeHandler(file_handler)
                file_handler.close()
            # Clean up GPU memory
            self.gpu_optimizer.clear_memory_cache()

    def _load_base_model(self, project):
        """Load the base model for training (simulated).

        Logs an estimated memory footprint and warns when the estimate
        exceeds the GPU's reported free memory.
        """
        self.logger.info(f"Loading base model: {project.base_model}")

        # Estimate model memory requirements
        estimated_params = self._estimate_model_parameters(project.base_model)
        memory_estimate = self.gpu_optimizer.estimate_training_memory(
            estimated_params, project.batch_size
        )
        self.logger.info(f"Estimated memory usage: {memory_estimate['total_estimated_gb']:.2f} GB")

        # Check if we have enough memory; 'free_mb' is reported in MB
        current_memory = self.gpu_optimizer.get_memory_usage()
        if 'gpu_memory' in current_memory:
            available_gb = current_memory['gpu_memory']['free_mb'] / 1024
            if memory_estimate['total_estimated_gb'] > available_gb:
                self.logger.warning(f"Estimated memory usage ({memory_estimate['total_estimated_gb']:.2f} GB) "
                                    f"exceeds available memory ({available_gb:.2f} GB)")

        # Simulate model loading with memory optimization
        time.sleep(2)
        self.logger.info("Base model loaded successfully with optimizations")

    def _prepare_dataset(self, project):
        """Prepare the dataset for training (simulated).

        Raises:
            ValueError: If the project's dataset path is unset or missing.
        """
        self.logger.info(f"Preparing dataset from: {project.dataset_path}")

        if not project.dataset_path or not os.path.exists(project.dataset_path):
            raise ValueError("Dataset path not found")

        # Optimize batch size based on available memory
        optimized_batch_size = self.gpu_optimizer.optimize_batch_size(
            project.batch_size,
            model_size_mb=500  # Estimated model size
        )

        if optimized_batch_size != project.batch_size:
            from src.models.lora_project import db
            project.batch_size = optimized_batch_size
            db.session.commit()
            self.logger.info(f"Batch size auto-optimized to: {optimized_batch_size}")

        time.sleep(1)
        self.logger.info("Dataset prepared successfully with memory optimizations")

    def _setup_lora_optimized(self, project):
        """Setup LoRA configuration with optimizations (logging only)."""
        self.logger.info("Setting up LoRA configuration with optimizations")

        # Apply memory-efficient configurations
        optimizations = []

        if project.use_8bit_optimizer:
            optimizations.append("8-bit optimizer")
            self.logger.info("Using 8-bit optimizer for memory efficiency")

        if project.use_gradient_checkpointing:
            optimizations.append("gradient checkpointing")
            self.logger.info("Using gradient checkpointing to reduce memory usage")

        self.logger.info(f"Mixed precision: {project.mixed_precision}")
        optimizations.append(f"mixed precision ({project.mixed_precision})")

        # Log LoRA-specific optimizations
        self.logger.info(f"LoRA rank: {project.rank} (lower rank = less memory)")
        self.logger.info(f"LoRA alpha: {project.alpha}")

        # Calculate memory savings from LoRA
        full_model_params = self._estimate_model_parameters(project.base_model)
        lora_params = self._estimate_lora_parameters(project.rank, full_model_params)
        memory_savings = ((full_model_params - lora_params) / full_model_params) * 100

        self.logger.info(f"LoRA memory savings: {memory_savings:.1f}% "
                         f"({lora_params:,} trainable params vs {full_model_params:,} total)")
        self.logger.info(f"Applied optimizations: {', '.join(optimizations)}")

    def _train_model_optimized(self, project):
        """Execute the (simulated) training loop with memory monitoring.

        Writes progress, current epoch, and current loss back to the
        database after every step; honours external cancellation by
        checking the project's status each epoch/step.
        """
        # BUG FIX: TrainingStatus was referenced below but never imported
        # here (only `db` was), raising NameError on the first iteration.
        from src.models.lora_project import TrainingStatus, db

        self.logger.info("Starting optimized training loop")

        for epoch in range(project.num_epochs):
            # Allow external cancellation between epochs
            if project.status != TrainingStatus.RUNNING:
                self.logger.info("Training cancelled by user")
                break

            self.logger.info(f"Epoch {epoch + 1}/{project.num_epochs}")

            # Monitor memory at the start of each epoch
            memory_monitor = self.gpu_optimizer.monitor_memory_during_training()
            if memory_monitor['warnings']:
                for warning in memory_monitor['warnings']:
                    self.logger.warning(warning)

            # Simulate training steps; larger batches mean fewer steps
            num_steps = max(1, 10 // project.batch_size)

            for step in range(num_steps):
                # Allow external cancellation between steps
                if project.status != TrainingStatus.RUNNING:
                    break

                # Simulate training step with optimizations
                step_time = 0.3 if project.use_8bit_optimizer else 0.5
                step_time *= (1.2 if project.use_gradient_checkpointing else 1.0)  # Gradient checkpointing is slower
                time.sleep(step_time)

                # Simulate loss calculation with better convergence for optimized training
                base_loss = 1.0 - (epoch * num_steps + step) / (project.num_epochs * num_steps) * 0.8

                # Add optimization-specific improvements
                if project.use_8bit_optimizer:
                    base_loss *= 0.95  # 8-bit optimizer can be slightly less stable
                if project.mixed_precision == 'fp16':
                    base_loss *= 0.98  # Mixed precision can have small numerical differences

                loss = base_loss + (torch.rand(1).item() - 0.5) * 0.1

                # Update progress
                progress = (epoch * num_steps + step + 1) / (project.num_epochs * num_steps)
                project.progress = progress
                project.current_epoch = epoch + 1
                project.current_loss = loss

                # Commit progress to database
                db.session.commit()

                if step % 3 == 0:  # Log every 3 steps
                    self.logger.info(f"Step {step + 1}/{num_steps}, Loss: {loss:.4f}")

                # Periodic memory cleanup
                if step % 5 == 0:
                    self.gpu_optimizer.clear_memory_cache()

            # Log memory usage at end of epoch
            epoch_memory = self.gpu_optimizer.get_memory_usage()
            if 'gpu_memory' in epoch_memory:
                gpu_usage = epoch_memory['gpu_memory']['allocated_mb']
                self.logger.info(f"Epoch {epoch + 1} completed, Loss: {project.current_loss:.4f}, "
                                 f"GPU Memory: {gpu_usage:.1f} MB")
            else:
                self.logger.info(f"Epoch {epoch + 1} completed, Loss: {project.current_loss:.4f}")

        self.logger.info("Optimized training loop completed")

    def _save_model(self, project):
        """Save the trained LoRA model (simulated metadata file)."""
        self.logger.info("Saving trained LoRA model")

        model_path = os.path.join(project.output_path, "lora_model.safetensors")

        # Simulate saving with optimization info
        time.sleep(1)

        # Create a detailed model file with optimization metadata
        with open(model_path, 'w') as f:
            f.write("# LoRA model with optimizations\n")
            f.write(f"# Project: {project.name}\n")
            f.write(f"# Base model: {project.base_model}\n")
            f.write(f"# LoRA rank: {project.rank}\n")
            f.write(f"# LoRA alpha: {project.alpha}\n")
            f.write(f"# Optimizations applied:\n")
            f.write(f"# - 8-bit optimizer: {project.use_8bit_optimizer}\n")
            f.write(f"# - Gradient checkpointing: {project.use_gradient_checkpointing}\n")
            f.write(f"# - Mixed precision: {project.mixed_precision}\n")
            f.write(f"# - Final batch size: {project.batch_size}\n")
            f.write(f"# Training device: {self.device}\n")

        self.logger.info(f"Optimized LoRA model saved to: {model_path}")

    def _estimate_model_parameters(self, model_name: str) -> int:
        """Estimate the number of parameters in a model by name.

        Simplified lookup keyed on substrings of the (case-insensitive)
        model name; unknown models get a 500M default.
        """
        if "stable-diffusion-v1" in model_name.lower():
            return 860_000_000  # ~860M parameters
        elif "stable-diffusion-xl" in model_name.lower():
            return 3_500_000_000  # ~3.5B parameters
        elif "dialogpt-medium" in model_name.lower():
            return 345_000_000  # ~345M parameters
        else:
            return 500_000_000  # Default estimate

    def _estimate_lora_parameters(self, rank: int, base_model_params: int) -> int:
        """Estimate the number of trainable parameters with LoRA.

        Simplified estimation: LoRA typically affects ~10% of model layers,
        each LoRA layer adds 2 * rank * dim parameters. This is a rough
        approximation, capped at 1% of the base model size.
        """
        affected_layers = int(base_model_params * 0.1)
        avg_layer_size = 1024  # Average dimension
        lora_params_per_layer = 2 * rank * avg_layer_size
        total_lora_params = (affected_layers // avg_layer_size) * lora_params_per_layer
        return min(total_lora_params, base_model_params // 100)  # Cap at 1% of base model

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get current memory usage from the GPU optimizer."""
        return self.gpu_optimizer.get_memory_usage()