Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """ | |
| Módulo principal para treinamento de LoRA para personagens consistentes | |
| Implementação completa usando diffusers, transformers e PEFT | |
| """ | |
| import os | |
| import json | |
| import torch | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| from PIL import Image | |
| import numpy as np | |
| from datetime import datetime | |
| # Imports para treinamento LoRA | |
| from diffusers import ( | |
| StableDiffusionPipeline, | |
| UNet2DConditionModel, | |
| AutoencoderKL, | |
| DDPMScheduler, | |
| DiffusionPipeline | |
| ) | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from peft import LoraConfig, get_peft_model, TaskType | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from accelerate import Accelerator | |
| from tqdm import tqdm | |
| # Configuração de logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class LoRATrainer:
    """Main class for LoRA fine-tuning of Stable Diffusion 1.5 on a character.

    Loads the base pipeline components, freezes them, injects LoRA adapters
    into the UNet attention/feed-forward projections, and runs a standard
    denoising-MSE training loop driven by ``accelerate``.
    """

    def __init__(self, config: Dict):
        # Raw config is kept so it can be dumped next to the trained weights.
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Training hyperparameters; config values may arrive as strings
        # (see the __main__ example), so cast explicitly.
        self.model_name = "runwayml/stable-diffusion-v1-5"
        self.resolution = int(config.get('resolution', 512))
        self.learning_rate = float(config.get('learning_rate', 1e-4))
        self.rank = int(config.get('rank', 16))
        self.epochs = int(config.get('epochs', 20))
        self.batch_size = 1
        self.gradient_accumulation_steps = 4
        # BUGFIX: the accumulation step count must be handed to Accelerator,
        # otherwise accelerator.accumulate() in train() runs with its default
        # of 1 and gradient_accumulation_steps is dead configuration.
        self.accelerator = Accelerator(
            gradient_accumulation_steps=self.gradient_accumulation_steps
        )
        # Filesystem locations.
        self.output_dir = config.get('output_dir', '/tmp/lora_output')
        self.images_dir = config.get('images_dir', '/tmp/lora_images')
        # Prompt pieces identifying the character.
        self.trigger_word = config.get('trigger_word', 'ohwx person')
        self.character_name = config.get('character_name', 'character')
        # Model components, populated by load_models().
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.noise_scheduler = None
        # Human-readable training log entries (see log_message()).
        self.training_logs = []

    def log_message(self, message: str):
        """Append a timestamped entry to the in-memory log and emit it via logging."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        log_entry = f"[{timestamp}] {message}"
        self.training_logs.append(log_entry)
        logger.info(message)

    def load_models(self):
        """Load tokenizer, text encoder, VAE, UNet and noise scheduler.

        All base weights are frozen; only the LoRA adapters added later by
        setup_lora() are trained.

        Raises:
            Exception: re-raised after logging if any component fails to load.
        """
        self.log_message("Carregando modelos base...")
        try:
            # Tokenizer and text encoder.
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.model_name, subfolder="tokenizer"
            )
            self.text_encoder = CLIPTextModel.from_pretrained(
                self.model_name, subfolder="text_encoder"
            )
            # VAE.
            self.vae = AutoencoderKL.from_pretrained(
                self.model_name, subfolder="vae"
            )
            # UNet.
            self.unet = UNet2DConditionModel.from_pretrained(
                self.model_name, subfolder="unet"
            )
            # Noise scheduler used to corrupt latents during training.
            self.noise_scheduler = DDPMScheduler.from_pretrained(
                self.model_name, subfolder="scheduler"
            )
            # Move everything to the target device.
            self.text_encoder.to(self.device)
            self.vae.to(self.device)
            self.unet.to(self.device)
            # Freeze the base model; PEFT re-enables grads on LoRA params only.
            self.text_encoder.requires_grad_(False)
            self.vae.requires_grad_(False)
            self.unet.requires_grad_(False)
            self.log_message("Modelos carregados com sucesso!")
        except Exception as e:
            self.log_message(f"Erro ao carregar modelos: {str(e)}")
            raise

    def setup_lora(self):
        """Inject LoRA adapters into the UNet attention/feed-forward projections.

        Raises:
            Exception: re-raised after logging if the adapter injection fails.
        """
        self.log_message(f"Configurando LoRA com rank {self.rank}...")
        try:
            # BUGFIX: peft's TaskType enum has no DIFFUSION member, so the
            # previous task_type=TaskType.DIFFUSION raised AttributeError.
            # For diffusion UNets, task_type is simply omitted.
            lora_config = LoraConfig(
                r=self.rank,
                lora_alpha=self.rank,
                target_modules=[
                    "to_k", "to_q", "to_v", "to_out.0",
                    "proj_in", "proj_out",
                    "ff.net.0.proj", "ff.net.2"
                ],
                lora_dropout=0.1,
                bias="none",
            )
            # Wrap the UNet; only adapter weights become trainable.
            self.unet = get_peft_model(self.unet, lora_config)
            self.unet.print_trainable_parameters()
            self.log_message("LoRA configurado com sucesso!")
        except Exception as e:
            self.log_message(f"Erro ao configurar LoRA: {str(e)}")
            raise

    def prepare_dataset(self) -> DataLoader:
        """Build the image dataset and wrap it in a DataLoader.

        Returns:
            DataLoader yielding dicts with "latents", "encoder_hidden_states"
            and "text_input_ids" (see LoRADataset.__getitem__).
        """
        self.log_message("Preparando dataset...")
        try:
            dataset = LoRADataset(
                images_dir=self.images_dir,
                tokenizer=self.tokenizer,
                trigger_word=self.trigger_word,
                resolution=self.resolution
            )
            dataloader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=0
            )
            self.log_message(f"Dataset preparado com {len(dataset)} imagens")
            return dataloader
        except Exception as e:
            self.log_message(f"Erro ao preparar dataset: {str(e)}")
            raise

    def train(self, progress_callback=None):
        """Run the full LoRA training loop and save the resulting adapter.

        Args:
            progress_callback: optional callable(percent: int, message: str)
                invoked after every optimizer step.

        Raises:
            Exception: re-raised after logging if any stage fails.
        """
        self.log_message("Iniciando treinamento LoRA...")
        try:
            self.load_models()
            self.setup_lora()
            dataloader = self.prepare_dataset()
            # AdamW over the (LoRA-only) trainable parameters.
            optimizer = torch.optim.AdamW(
                self.unet.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                weight_decay=0.01,
                eps=1e-08
            )
            # Let accelerate handle device placement / mixed precision.
            self.unet, optimizer, dataloader = self.accelerator.prepare(
                self.unet, optimizer, dataloader
            )
            global_step = 0
            total_steps = len(dataloader) * self.epochs
            for epoch in range(self.epochs):
                self.log_message(f"Época {epoch + 1}/{self.epochs}")
                epoch_loss = 0.0
                for step, batch in enumerate(dataloader):
                    # accumulate() skips optimizer.step() until enough
                    # micro-batches have been seen (see __init__ BUGFIX).
                    with self.accelerator.accumulate(self.unet):
                        loss = self.compute_loss(batch)
                        self.accelerator.backward(loss)
                        if self.accelerator.sync_gradients:
                            self.accelerator.clip_grad_norm_(self.unet.parameters(), 1.0)
                        optimizer.step()
                        optimizer.zero_grad()
                    epoch_loss += loss.item()
                    global_step += 1
                    if progress_callback:
                        progress = int((global_step / total_steps) * 100)
                        progress_callback(progress, f"Época {epoch + 1}/{self.epochs} - Step {step + 1}/{len(dataloader)}")
                avg_loss = epoch_loss / len(dataloader)
                self.log_message(f"Época {epoch + 1} concluída - Loss média: {avg_loss:.4f}")
            self.save_model()
            self.log_message("Treinamento concluído com sucesso!")
        except Exception as e:
            self.log_message(f"Erro durante treinamento: {str(e)}")
            raise

    def compute_loss(self, batch):
        """Compute the denoising MSE loss for one batch.

        NOTE(review): the target is the raw noise, which assumes an
        epsilon-prediction scheduler (true for SD 1.5's default DDPMScheduler).
        If the scheduler's prediction_type were "v_prediction" the target
        would need to change — confirm if the base model is ever swapped.
        """
        latents = batch["latents"].to(self.device)
        encoder_hidden_states = batch["encoder_hidden_states"].to(self.device)
        # Sample noise and a random timestep per sample, then corrupt latents.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=latents.device
        ).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        # Predict the noise with the (LoRA-augmented) UNet.
        model_pred = self.unet(
            noisy_latents, timesteps, encoder_hidden_states
        ).sample
        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        return loss

    def save_model(self):
        """Persist the trained LoRA adapter and the training configuration.

        Raises:
            Exception: re-raised after logging if any write fails.
        """
        self.log_message("Salvando modelo LoRA...")
        try:
            os.makedirs(self.output_dir, exist_ok=True)
            # PEFT adapter format: adapter_config.json + adapter weights.
            self.unet.save_pretrained(self.output_dir)
            # Persist the configuration alongside the weights.
            config_path = os.path.join(self.output_dir, "training_config.json")
            with open(config_path, 'w') as f:
                json.dump(self.config, f, indent=2)
            # BUGFIX: torch.save writes a pickle, not the safetensors format;
            # shipping it under a .safetensors name breaks any loader that
            # actually parses safetensors. Save the raw state dict as .bin.
            weights_path = os.path.join(self.output_dir, "pytorch_lora_weights.bin")
            torch.save(self.unet.state_dict(), weights_path)
            self.log_message(f"Modelo salvo em: {self.output_dir}")
        except Exception as e:
            self.log_message(f"Erro ao salvar modelo: {str(e)}")
            raise
class LoRADataset(Dataset):
    """Dataset of character images paired with a tokenized trigger prompt.

    NOTE(review): the VAE latents and text-encoder hidden states returned by
    __getitem__ are random placeholders ("simulated" in the original code) —
    confirm whether real encoding happens elsewhere before relying on this
    for actual training quality.
    """

    # Accepted image extensions, compared case-insensitively.
    IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}

    def __init__(self, images_dir: str, tokenizer, trigger_word: str, resolution: int = 512):
        self.images_dir = Path(images_dir)
        self.tokenizer = tokenizer
        self.trigger_word = trigger_word
        self.resolution = resolution
        # BUGFIX: extension matching is now case-insensitive (e.g. IMG_01.JPG),
        # consistent with validate_training_config, counts each file once, and
        # is sorted for a deterministic ordering across runs.
        if self.images_dir.is_dir():
            self.image_paths = sorted(
                p for p in self.images_dir.iterdir()
                if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS
            )
        else:
            self.image_paths = []
        if len(self.image_paths) == 0:
            raise ValueError(f"Nenhuma imagem encontrada em {images_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return simulated latents/embeddings plus token ids for one image."""
        image_path = self.image_paths[idx]
        # Load, square-resize and scale to [-1, 1] as the VAE expects.
        image = Image.open(image_path).convert("RGB")
        image = image.resize((self.resolution, self.resolution), Image.LANCZOS)
        image_array = np.array(image).astype(np.float32) / 255.0
        image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
        image_tensor = (image_tensor - 0.5) / 0.5
        # VAE encode is simulated: random latents at the SD 1.5 latent shape
        # (4 channels, spatial size resolution/8); image_tensor is unused here.
        latents = torch.randn(4, self.resolution // 8, self.resolution // 8)
        # Tokenize the trigger prompt to CLIP's fixed max length.
        prompt = f"{self.trigger_word}, high quality, detailed"
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        # Text encoding is simulated too (77 tokens x 768 dims, CLIP ViT-L shape).
        encoder_hidden_states = torch.randn(1, 77, 768)
        return {
            "latents": latents,
            "encoder_hidden_states": encoder_hidden_states.squeeze(0),
            "text_input_ids": text_inputs.input_ids.squeeze(0)
        }
def create_lora_trainer(config: Dict) -> LoRATrainer:
    """Factory: build a LoRATrainer from a configuration dictionary."""
    trainer = LoRATrainer(config)
    return trainer
def validate_training_config(config: Dict) -> Tuple[bool, str]:
    """Validate a training configuration dictionary.

    Checks that all required fields are present and non-empty, that the image
    directory exists, and that it contains at least 5 images.

    Args:
        config: configuration dict (see LoRATrainer for the known keys).

    Returns:
        (True, "Configuração válida") when valid, otherwise (False, reason).
    """
    required_fields = ['character_name', 'trigger_word', 'images_dir', 'output_dir']
    for field in required_fields:
        if field not in config or not config[field]:
            return False, f"Campo obrigatório ausente: {field}"
    # The image directory must exist before we can count images in it.
    if not os.path.exists(config['images_dir']):
        return False, f"Diretório de imagens não encontrado: {config['images_dir']}"
    # BUGFIX: the old code globbed both "*.jpg" and "*.JPG"; on
    # case-insensitive filesystems (macOS/Windows) both patterns match the
    # same files, double-counting every image. Count each file exactly once
    # by comparing its lowercased suffix instead.
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
    image_count = sum(
        1 for p in Path(config['images_dir']).iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    if image_count < 5:
        return False, f"Mínimo de 5 imagens necessárias. Encontradas: {image_count}"
    return True, "Configuração válida"
if __name__ == "__main__":
    # Example usage. String-valued numeric settings are fine here because the
    # trainer casts them itself (key order is preserved: the trainer dumps
    # this dict verbatim to training_config.json).
    config = {
        'character_name': 'test_character',
        'trigger_word': 'ohwx person',
        'resolution': '512',
        'learning_rate': '1e-4',
        'rank': '16',
        'epochs': '5',
        'images_dir': '/tmp/test_images',
        'output_dir': '/tmp/test_output'
    }
    # Bail out early when the configuration is incomplete or lacks images.
    is_valid, message = validate_training_config(config)
    if not is_valid:
        print(f"Erro na configuração: {message}")
        exit(1)
    # Build the trainer and run the full training pipeline.
    create_lora_trainer(config).train()