# Source: Lora-trainer-all / lora_trainer.py (uploaded by Allex21, commit 7c8a29e)
#!/usr/bin/env python3
"""
Módulo principal para treinamento de LoRA para personagens consistentes
Implementação completa usando diffusers, transformers e PEFT
"""
import os
import json
import torch
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from PIL import Image
import numpy as np
from datetime import datetime
# Imports para treinamento LoRA
from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from tqdm import tqdm
# Configuração de logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LoRATrainer:
    """Train a LoRA adapter for consistent-character generation on SD v1.5.

    Loads the base pipeline components (tokenizer, text encoder, VAE, UNet,
    DDPM noise scheduler), injects LoRA adapters into the UNet attention and
    feed-forward projections via PEFT, and runs an epsilon-prediction training
    loop with gradient accumulation managed by ``accelerate``.
    """

    def __init__(self, config: Dict):
        """Store configuration and prepare empty component slots.

        Args:
            config: Optional keys ``resolution``, ``learning_rate``, ``rank``,
                ``epochs``, ``output_dir``, ``images_dir``, ``trigger_word``,
                ``character_name``. Numeric values may arrive as strings
                (e.g. from a web form) and are coerced here.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Training hyperparameters
        self.model_name = "runwayml/stable-diffusion-v1-5"
        self.resolution = int(config.get('resolution', 512))
        self.learning_rate = float(config.get('learning_rate', 1e-4))
        self.rank = int(config.get('rank', 16))
        self.epochs = int(config.get('epochs', 20))
        self.batch_size = 1
        self.gradient_accumulation_steps = 4
        # FIX: the accumulation factor must be handed to Accelerator, otherwise
        # accelerator.accumulate() runs with its default of 1 and the
        # gradient_accumulation_steps setting above is silently ignored.
        self.accelerator = Accelerator(
            gradient_accumulation_steps=self.gradient_accumulation_steps
        )
        # Paths
        self.output_dir = config.get('output_dir', '/tmp/lora_output')
        self.images_dir = config.get('images_dir', '/tmp/lora_images')
        # Trigger word and character name used to build captions
        self.trigger_word = config.get('trigger_word', 'ohwx person')
        self.character_name = config.get('character_name', 'character')
        # Model components, populated by load_models()
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.noise_scheduler = None
        # Human-readable log entries collected during training
        self.training_logs = []

    def log_message(self, message: str):
        """Append a timestamped entry to the in-memory log and the logger."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        log_entry = f"[{timestamp}] {message}"
        self.training_logs.append(log_entry)
        logger.info(message)

    def load_models(self):
        """Load tokenizer, text encoder, VAE, UNet and noise scheduler.

        All base weights are frozen; the LoRA adapters added later by
        setup_lora() are the only trainable parameters.

        Raises:
            Exception: re-raises any download/loading failure after logging it.
        """
        self.log_message("Carregando modelos base...")
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(
                self.model_name, subfolder="tokenizer"
            )
            self.text_encoder = CLIPTextModel.from_pretrained(
                self.model_name, subfolder="text_encoder"
            )
            self.vae = AutoencoderKL.from_pretrained(
                self.model_name, subfolder="vae"
            )
            self.unet = UNet2DConditionModel.from_pretrained(
                self.model_name, subfolder="unet"
            )
            self.noise_scheduler = DDPMScheduler.from_pretrained(
                self.model_name, subfolder="scheduler"
            )
            # Move everything onto the training device.
            self.text_encoder.to(self.device)
            self.vae.to(self.device)
            self.unet.to(self.device)
            # Freeze all base weights; only LoRA weights will train.
            self.text_encoder.requires_grad_(False)
            self.vae.requires_grad_(False)
            self.unet.requires_grad_(False)
            self.log_message("Modelos carregados com sucesso!")
        except Exception as e:
            self.log_message(f"Erro ao carregar modelos: {str(e)}")
            raise

    def setup_lora(self):
        """Wrap the UNet with LoRA adapters on attention/FF projections.

        Raises:
            Exception: re-raises any PEFT configuration failure after logging.
        """
        self.log_message(f"Configurando LoRA com rank {self.rank}...")
        try:
            # FIX: peft.TaskType has no DIFFUSION member, so the previous
            # task_type=TaskType.DIFFUSION raised AttributeError before
            # training could start. For diffusion UNets the task type is
            # simply omitted.
            lora_config = LoraConfig(
                r=self.rank,
                lora_alpha=self.rank,  # alpha == rank -> LoRA scaling of 1.0
                target_modules=[
                    "to_k", "to_q", "to_v", "to_out.0",
                    "proj_in", "proj_out",
                    "ff.net.0.proj", "ff.net.2"
                ],
                lora_dropout=0.1,
                bias="none",
            )
            self.unet = get_peft_model(self.unet, lora_config)
            self.unet.print_trainable_parameters()
            self.log_message("LoRA configurado com sucesso!")
        except Exception as e:
            self.log_message(f"Erro ao configurar LoRA: {str(e)}")
            raise

    def prepare_dataset(self) -> DataLoader:
        """Build the training DataLoader over the character images."""
        self.log_message("Preparando dataset...")
        try:
            dataset = LoRADataset(
                images_dir=self.images_dir,
                tokenizer=self.tokenizer,
                trigger_word=self.trigger_word,
                resolution=self.resolution
            )
            dataloader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=0  # in-process loading keeps this portable
            )
            self.log_message(f"Dataset preparado com {len(dataset)} imagens")
            return dataloader
        except Exception as e:
            self.log_message(f"Erro ao preparar dataset: {str(e)}")
            raise

    def train(self, progress_callback=None):
        """Run the full LoRA training loop and save the result.

        Args:
            progress_callback: optional callable(percent: int, message: str)
                invoked after every step with overall progress.
        """
        self.log_message("Iniciando treinamento LoRA...")
        try:
            self.load_models()
            self.setup_lora()
            dataloader = self.prepare_dataset()
            optimizer = torch.optim.AdamW(
                self.unet.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                weight_decay=0.01,
                eps=1e-08
            )
            # Let accelerate handle device placement and accumulation wrapping.
            self.unet, optimizer, dataloader = self.accelerator.prepare(
                self.unet, optimizer, dataloader
            )
            global_step = 0
            total_steps = len(dataloader) * self.epochs
            for epoch in range(self.epochs):
                self.log_message(f"Época {epoch + 1}/{self.epochs}")
                epoch_loss = 0.0
                for step, batch in enumerate(dataloader):
                    with self.accelerator.accumulate(self.unet):
                        loss = self.compute_loss(batch)
                        self.accelerator.backward(loss)
                        # Clip only on real (synchronized) optimizer steps;
                        # accelerate makes step/zero_grad no-ops otherwise.
                        if self.accelerator.sync_gradients:
                            self.accelerator.clip_grad_norm_(self.unet.parameters(), 1.0)
                        optimizer.step()
                        optimizer.zero_grad()
                    epoch_loss += loss.item()
                    global_step += 1
                    if progress_callback:
                        progress = int((global_step / total_steps) * 100)
                        progress_callback(progress, f"Época {epoch + 1}/{self.epochs} - Step {step + 1}/{len(dataloader)}")
                avg_loss = epoch_loss / len(dataloader)
                self.log_message(f"Época {epoch + 1} concluída - Loss média: {avg_loss:.4f}")
            self.save_model()
            self.log_message("Treinamento concluído com sucesso!")
        except Exception as e:
            self.log_message(f"Erro durante treinamento: {str(e)}")
            raise

    def compute_loss(self, batch):
        """Epsilon-prediction MSE loss for one batch.

        Expects batch["latents"] of shape (B, 4, H/8, W/8) and
        batch["encoder_hidden_states"] of shape (B, 77, 768), as produced
        by LoRADataset.
        """
        latents = batch["latents"].to(self.device)
        encoder_hidden_states = batch["encoder_hidden_states"].to(self.device)
        # Sample noise and one random timestep per sample.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=latents.device
        ).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        # The UNet predicts the added noise (epsilon objective for SD v1.5).
        model_pred = self.unet(
            noisy_latents, timesteps, encoder_hidden_states
        ).sample
        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        return loss

    def save_model(self):
        """Persist the trained LoRA weights and the training configuration."""
        self.log_message("Salvando modelo LoRA...")
        try:
            os.makedirs(self.output_dir, exist_ok=True)
            # FIX: unwrap before saving so accelerate's wrapper is not
            # serialized; only the PEFT adapter weights go to disk.
            unet_to_save = self.accelerator.unwrap_model(self.unet)
            unet_to_save.save_pretrained(self.output_dir)
            # Save the raw training configuration alongside the weights.
            config_path = os.path.join(self.output_dir, "training_config.json")
            with open(config_path, 'w') as f:
                json.dump(self.config, f, indent=2)
            # NOTE(review): despite the .safetensors name this file is written
            # with torch.save (pickle format). Kept for compatibility with the
            # rest of this pipeline, but loaders expecting real safetensors
            # will reject it — confirm downstream consumers.
            safetensors_path = os.path.join(self.output_dir, "pytorch_lora_weights.safetensors")
            torch.save(unet_to_save.state_dict(), safetensors_path)
            self.log_message(f"Modelo salvo em: {self.output_dir}")
        except Exception as e:
            self.log_message(f"Erro ao salvar modelo: {str(e)}")
            raise
class LoRADataset(Dataset):
    """Image/caption dataset for LoRA training.

    Lists every supported image in ``images_dir`` (extension match is
    case-insensitive) and yields latents plus text embeddings per item.
    NOTE(review): the VAE and text-encoder outputs are currently simulated
    with random tensors, so the loaded image tensor is computed but not
    actually consumed downstream — presumably a placeholder to be replaced
    by real encoding.
    """

    def __init__(self, images_dir: str, tokenizer, trigger_word: str, resolution: int = 512):
        """Collect the image paths for the dataset.

        Args:
            images_dir: directory containing the training images.
            tokenizer: CLIP tokenizer used to encode the caption.
            trigger_word: token phrase identifying the character.
            resolution: square side length images are resized to.

        Raises:
            ValueError: if no supported image is found in images_dir.
        """
        self.images_dir = Path(images_dir)
        self.tokenizer = tokenizer
        self.trigger_word = trigger_word
        self.resolution = resolution
        # FIX: the previous lowercase-only globs ('*.jpg', ...) missed files
        # such as "photo.JPG" that validate_training_config() accepts, making
        # a validated directory crash here. Match the suffix
        # case-insensitively instead; sorted() makes ordering deterministic.
        allowed_suffixes = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
        if self.images_dir.is_dir():
            self.image_paths = sorted(
                p for p in self.images_dir.iterdir()
                if p.is_file() and p.suffix.lower() in allowed_suffixes
            )
        else:
            self.image_paths = []
        if len(self.image_paths) == 0:
            raise ValueError(f"Nenhuma imagem encontrada em {images_dir}")

    def __len__(self):
        """Number of training images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return latents, text embeddings and token ids for one image."""
        image_path = self.image_paths[idx]
        # Load and square-resize the image.
        image = Image.open(image_path).convert("RGB")
        image = image.resize((self.resolution, self.resolution), Image.LANCZOS)
        # To CHW float tensor in [-1, 1], the range the SD VAE expects.
        image_array = np.array(image).astype(np.float32) / 255.0
        image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
        image_tensor = (image_tensor - 0.5) / 0.5
        # Placeholder VAE encoding: 4 latent channels at 1/8 resolution.
        latents = torch.randn(4, self.resolution // 8, self.resolution // 8)
        # Tokenize the caption built around the trigger word.
        prompt = f"{self.trigger_word}, high quality, detailed"
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        # Placeholder CLIP text embedding (77 tokens x 768 dims).
        encoder_hidden_states = torch.randn(1, 77, 768)
        return {
            "latents": latents,
            "encoder_hidden_states": encoder_hidden_states.squeeze(0),
            "text_input_ids": text_inputs.input_ids.squeeze(0)
        }
def create_lora_trainer(config: Dict) -> LoRATrainer:
    """Factory helper: build a LoRATrainer from a configuration dict."""
    trainer = LoRATrainer(config)
    return trainer
def validate_training_config(config: Dict) -> Tuple[bool, str]:
    """Validate a training configuration dict.

    Checks required fields, that the image directory exists, and that it
    contains at least 5 images with a supported extension.

    Args:
        config: configuration dict as accepted by LoRATrainer.

    Returns:
        (True, "Configuração válida") on success, otherwise
        (False, <reason>).
    """
    required_fields = ['character_name', 'trigger_word', 'images_dir', 'output_dir']
    for field in required_fields:
        if field not in config or not config[field]:
            return False, f"Campo obrigatório ausente: {field}"
    # Image directory must exist before we can count images.
    if not os.path.exists(config['images_dir']):
        return False, f"Diretório de imagens não encontrado: {config['images_dir']}"
    # FIX: the previous lower+upper double glob counted each file twice on
    # case-insensitive filesystems and still missed mixed-case names like
    # "a.Jpg". One pass with a case-folded suffix check counts each file
    # exactly once.
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
    image_count = sum(
        1 for p in Path(config['images_dir']).iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    if image_count < 5:
        return False, f"Mínimo de 5 imagens necessárias. Encontradas: {image_count}"
    return True, "Configuração válida"
if __name__ == "__main__":
# Exemplo de uso
config = {
'character_name': 'test_character',
'trigger_word': 'ohwx person',
'resolution': '512',
'learning_rate': '1e-4',
'rank': '16',
'epochs': '5',
'images_dir': '/tmp/test_images',
'output_dir': '/tmp/test_output'
}
# Validar configuração
is_valid, message = validate_training_config(config)
if not is_valid:
print(f"Erro na configuração: {message}")
exit(1)
# Criar e executar trainer
trainer = create_lora_trainer(config)
trainer.train()