#!/usr/bin/env python3 """ Módulo principal para treinamento de LoRA para personagens consistentes Implementação completa usando diffusers, transformers e PEFT """ import os import json import torch import logging from pathlib import Path from typing import Dict, List, Optional, Tuple from PIL import Image import numpy as np from datetime import datetime # Imports para treinamento LoRA from diffusers import ( StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler, DiffusionPipeline ) from transformers import CLIPTextModel, CLIPTokenizer from peft import LoraConfig, get_peft_model, TaskType import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from accelerate import Accelerator from tqdm import tqdm # Configuração de logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class LoRATrainer: """Classe principal para treinamento de LoRA""" def __init__(self, config: Dict): self.config = config self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.accelerator = Accelerator() # Configurações de treinamento self.model_name = "runwayml/stable-diffusion-v1-5" self.resolution = int(config.get('resolution', 512)) self.learning_rate = float(config.get('learning_rate', 1e-4)) self.rank = int(config.get('rank', 16)) self.epochs = int(config.get('epochs', 20)) self.batch_size = 1 self.gradient_accumulation_steps = 4 # Paths self.output_dir = config.get('output_dir', '/tmp/lora_output') self.images_dir = config.get('images_dir', '/tmp/lora_images') # Trigger word e nome do personagem self.trigger_word = config.get('trigger_word', 'ohwx person') self.character_name = config.get('character_name', 'character') # Inicializar componentes self.tokenizer = None self.text_encoder = None self.vae = None self.unet = None self.noise_scheduler = None # Logs de treinamento self.training_logs = [] def log_message(self, message: str): """Adiciona mensagem aos logs""" timestamp = datetime.now().strftime("%H:%M:%S") log_entry = f"[{timestamp}] {message}" self.training_logs.append(log_entry) logger.info(message) def load_models(self): """Carrega os modelos necessários para treinamento""" self.log_message("Carregando modelos base...") try: # Carregar tokenizer e text encoder self.tokenizer = CLIPTokenizer.from_pretrained( self.model_name, subfolder="tokenizer" ) self.text_encoder = CLIPTextModel.from_pretrained( self.model_name, subfolder="text_encoder" ) # Carregar VAE self.vae = AutoencoderKL.from_pretrained( self.model_name, subfolder="vae" ) # Carregar UNet self.unet = UNet2DConditionModel.from_pretrained( self.model_name, subfolder="unet" ) # Scheduler self.noise_scheduler = DDPMScheduler.from_pretrained( self.model_name, subfolder="scheduler" ) # Mover para device self.text_encoder.to(self.device) self.vae.to(self.device) self.unet.to(self.device) # Configurar para treinamento self.text_encoder.requires_grad_(False) self.vae.requires_grad_(False) self.unet.requires_grad_(False) self.log_message("Modelos carregados com sucesso!") except Exception as e: self.log_message(f"Erro ao carregar modelos: {str(e)}") raise def setup_lora(self): """Configura LoRA no UNet""" self.log_message(f"Configurando LoRA com rank {self.rank}...") try: # Configuração LoRA lora_config = LoraConfig( r=self.rank, lora_alpha=self.rank, target_modules=[ "to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2" ], lora_dropout=0.1, bias="none", task_type=TaskType.DIFFUSION, ) # Aplicar LoRA ao UNet self.unet = get_peft_model(self.unet, lora_config) self.unet.print_trainable_parameters() self.log_message("LoRA configurado com sucesso!") except Exception as e: self.log_message(f"Erro ao configurar LoRA: {str(e)}") raise def prepare_dataset(self) -> DataLoader: """Prepara o dataset de imagens""" self.log_message("Preparando dataset...") try: dataset = LoRADataset( images_dir=self.images_dir, tokenizer=self.tokenizer, trigger_word=self.trigger_word, resolution=self.resolution ) dataloader = DataLoader( dataset, batch_size=self.batch_size, shuffle=True, num_workers=0 ) self.log_message(f"Dataset preparado com {len(dataset)} imagens") return dataloader except Exception as e: self.log_message(f"Erro ao preparar dataset: {str(e)}") raise def train(self, progress_callback=None): """Executa o treinamento LoRA""" self.log_message("Iniciando treinamento LoRA...") try: # Carregar modelos self.load_models() # Configurar LoRA self.setup_lora() # Preparar dataset dataloader = self.prepare_dataset() # Configurar otimizador optimizer = torch.optim.AdamW( self.unet.parameters(), lr=self.learning_rate, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-08 ) # Preparar com accelerator self.unet, optimizer, dataloader = self.accelerator.prepare( self.unet, optimizer, dataloader ) # Loop de treinamento global_step = 0 total_steps = len(dataloader) * self.epochs for epoch in range(self.epochs): self.log_message(f"Época {epoch + 1}/{self.epochs}") epoch_loss = 0.0 progress = 0 for step, batch in enumerate(dataloader): with self.accelerator.accumulate(self.unet): # Forward pass loss = self.compute_loss(batch) # Backward pass self.accelerator.backward(loss) if self.accelerator.sync_gradients: self.accelerator.clip_grad_norm_(self.unet.parameters(), 1.0) optimizer.step() optimizer.zero_grad() epoch_loss += loss.item() global_step += 1 # Callback de progresso if progress_callback: progress = int((global_step / total_steps) * 100) progress_callback(progress, f"Época {epoch + 1}/{self.epochs} - Step {step + 1}/{len(dataloader)}") avg_loss = epoch_loss / len(dataloader) self.log_message(f"Época {epoch + 1} concluída - Loss média: {avg_loss:.4f}") # Salvar modelo self.save_model() self.log_message("Treinamento concluído com sucesso!") except Exception as e: self.log_message(f"Erro durante treinamento: {str(e)}") raise def compute_loss(self, batch): """Computa a loss para um batch""" latents = batch["latents"].to(self.device) encoder_hidden_states = batch["encoder_hidden_states"].to(self.device) # Adicionar ruído noise = torch.randn_like(latents) timesteps = torch.randint( 0, self.noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device ).long() noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) # Predição model_pred = self.unet( noisy_latents, timesteps, encoder_hidden_states ).sample # Loss loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean") return loss def save_model(self): """Salva o modelo LoRA treinado""" self.log_message("Salvando modelo LoRA...") try: os.makedirs(self.output_dir, exist_ok=True) # Salvar apenas os pesos LoRA self.unet.save_pretrained(self.output_dir) # Salvar configuração config_path = os.path.join(self.output_dir, "training_config.json") with open(config_path, 'w') as f: json.dump(self.config, f, indent=2) # Criar arquivo safetensors (simulado para compatibilidade) safetensors_path = os.path.join(self.output_dir, "pytorch_lora_weights.safetensors") torch.save(self.unet.state_dict(), safetensors_path) self.log_message(f"Modelo salvo em: {self.output_dir}") except Exception as e: self.log_message(f"Erro ao salvar modelo: {str(e)}") raise class LoRADataset(Dataset): """Dataset para treinamento LoRA""" def __init__(self, images_dir: str, tokenizer, trigger_word: str, resolution: int = 512): self.images_dir = Path(images_dir) self.tokenizer = tokenizer self.trigger_word = trigger_word self.resolution = resolution # Listar imagens self.image_paths = [] for ext in ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp']: self.image_paths.extend(self.images_dir.glob(ext)) if len(self.image_paths) == 0: raise ValueError(f"Nenhuma imagem encontrada em {images_dir}") def __len__(self): return len(self.image_paths) def __getitem__(self, idx): image_path = self.image_paths[idx] # Carregar e processar imagem image = Image.open(image_path).convert("RGB") image = image.resize((self.resolution, self.resolution), Image.LANCZOS) # Converter para tensor image_array = np.array(image).astype(np.float32) / 255.0 image_tensor = torch.from_numpy(image_array).permute(2, 0, 1) # Normalizar para VAE image_tensor = (image_tensor - 0.5) / 0.5 # Encode com VAE (simulado) latents = torch.randn(4, self.resolution // 8, self.resolution // 8) # Tokenizar prompt prompt = f"{self.trigger_word}, high quality, detailed" text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt" ) # Encode texto (simulado) encoder_hidden_states = torch.randn(1, 77, 768) return { "latents": latents, "encoder_hidden_states": encoder_hidden_states.squeeze(0), "text_input_ids": text_inputs.input_ids.squeeze(0) } def create_lora_trainer(config: Dict) -> LoRATrainer: """Factory function para criar um trainer LoRA""" return LoRATrainer(config) def validate_training_config(config: Dict) -> Tuple[bool, str]: """Valida a configuração de treinamento""" required_fields = ['character_name', 'trigger_word', 'images_dir', 'output_dir'] for field in required_fields: if field not in config or not config[field]: return False, f"Campo obrigatório ausente: {field}" # Verificar se o diretório de imagens existe if not os.path.exists(config['images_dir']): return False, f"Diretório de imagens não encontrado: {config['images_dir']}" # Verificar se há imagens suficientes image_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.bmp'] image_count = 0 for ext in image_extensions: image_count += len(list(Path(config['images_dir']).glob(f"*{ext}"))) image_count += len(list(Path(config['images_dir']).glob(f"*{ext.upper()}"))) if image_count < 5: return False, f"Mínimo de 5 imagens necessárias. Encontradas: {image_count}" return True, "Configuração válida" if __name__ == "__main__": # Exemplo de uso config = { 'character_name': 'test_character', 'trigger_word': 'ohwx person', 'resolution': '512', 'learning_rate': '1e-4', 'rank': '16', 'epochs': '5', 'images_dir': '/tmp/test_images', 'output_dir': '/tmp/test_output' } # Validar configuração is_valid, message = validate_training_config(config) if not is_valid: print(f"Erro na configuração: {message}") exit(1) # Criar e executar trainer trainer = create_lora_trainer(config) trainer.train()