# app.py — Gradio app for training a Stable Diffusion LoRA.
# (Source: Hugging Face Space "Gf" by Allex21, revision 3d74922; the web-page
# chrome that preceded this header was scrape residue, not code.)
import gradio as gr
import os
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.optimization import get_scheduler
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
import zipfile
import shutil
from safetensors.torch import save_file
import torch.nn as nn
# Função para criar camadas LoRA
def create_lora_layers(module, rank=4):
    """Build a (down, up) LoRA projection pair for a linear layer.

    Returns ``(None, None)`` when *module* is not an ``nn.Linear`` so callers
    can skip unsupported layers. The up-projection starts at zero, which makes
    the adapter a no-op until training updates it.
    """
    if not isinstance(module, nn.Linear):
        return None, None
    down_proj = nn.Linear(module.in_features, rank, bias=False)
    up_proj = nn.Linear(rank, module.out_features, bias=False)
    nn.init.zeros_(up_proj.weight)  # zero init => adapter output is 0 at step 0
    return down_proj, up_proj
# Dataset simplificado
class DreamBoothDataset(Dataset):
    """Dataset yielding (image tensor, tokenized prompt) pairs for LoRA training.

    Every sample uses the same ``train_prompt``; images are resized and
    center-cropped to ``size`` x ``size`` and normalized to [-1, 1].
    """

    def __init__(self, instance_data_root, tokenizer, size=512, train_prompt="a photo of sks dog"):
        self.instance_data_root = instance_data_root
        self.tokenizer = tokenizer
        self.size = size
        self.train_prompt = train_prompt
        # Sorted for a deterministic sample order across runs; the extension
        # check is case-insensitive so ".JPG"/".PNG" files are not skipped.
        self.instance_images_path = sorted(
            os.path.join(instance_data_root, file_path)
            for file_path in os.listdir(instance_data_root)
            if file_path.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.transform = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),  # maps [0, 1] -> [-1, 1]
            ]
        )

    def __len__(self):
        return len(self.instance_images_path)

    def __getitem__(self, index):
        # Open inside a context manager: PIL loads lazily and the original
        # code never closed the file, leaking one handle per sample.
        with Image.open(self.instance_images_path[index]) as img:
            instance_image = img.convert("RGB")
        return {
            "instance_images": self.transform(instance_image),
            "instance_prompt_ids": self.tokenizer(
                self.train_prompt,
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids[0],
        }
# Função principal de treinamento
def train_lora(
    instance_data_dir: str,
    output_dir: str,
    resolution: int = 512,
    learning_rate: float = 1e-4,
    batch_size: int = 1,
    num_epochs: int = 1,
    train_prompt: str = "a photo of sks dog",
    pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5",
):
    """Fine-tune a Stable Diffusion UNet with hand-rolled rank-4 LoRA adapters.

    Injects LoRA layers into the UNet attention projections
    (``to_q``/``to_k``/``to_v``/``to_out.0``), trains only those layers on the
    images in ``instance_data_dir`` using the constant ``train_prompt``, and
    saves the adapter weights as safetensors.

    Returns:
        Path to the saved ``lora_model.safetensors`` file inside ``output_dir``.
    """
    import types  # stdlib; hoisted out of the injection loop below

    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
    )

    # Load the pipeline components individually.
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")

    # Freeze everything; only the injected LoRA layers will train.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # BUG FIX: the frozen VAE and text encoder are used inside the training
    # loop but were never moved off the CPU, which fails with a device
    # mismatch whenever the accelerator runs on GPU. They stay in fp32.
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    def _forward_with_lora(self, x):
        # Original projection plus the low-rank residual up(down(x)).
        return self._original_forward(x) + self.lora_up(self.lora_down(x))

    # Inject LoRA into the attention projections of the UNet.
    lora_layers = []
    for name, module in unet.named_modules():
        if name.endswith(("to_q", "to_k", "to_v", "to_out.0")):
            lora_down, lora_up = create_lora_layers(module, rank=4)
            if lora_down is None:
                continue  # module was not an nn.Linear
            module.lora_down = lora_down.to(module.weight.device)
            module.lora_up = lora_up.to(module.weight.device)
            lora_layers.extend([module.lora_down, module.lora_up])
            # Keep the pristine forward so the patch is idempotent.
            if not hasattr(module, "_original_forward"):
                module._original_forward = module.forward
            module.forward = types.MethodType(_forward_with_lora, module)

    # Only the LoRA parameters receive gradients.
    lora_parameters = []
    for layer in lora_layers:
        layer.requires_grad_(True)
        lora_parameters.extend(layer.parameters())

    optimizer = torch.optim.AdamW(lora_parameters, lr=learning_rate)
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

    # Dataset/dataloader are built before the LR scheduler so the true number
    # of optimizer steps can be used (the old code counted every directory
    # entry in instance_data_dir, whether or not it was an image).
    train_dataset = DreamBoothDataset(instance_data_dir, tokenizer, resolution, train_prompt)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_epochs * len(train_dataloader),
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # Training loop.
    global_step = 0
    for epoch in range(num_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                pixel_values = batch["instance_images"].to(accelerator.device)
                # Encode images into the VAE latent space.
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                # Sample noise and one random timestep per image.
                noise = torch.randn_like(latents)  # randn_like already matches latents' device
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (latents.shape[0],),
                    device=latents.device,
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                encoder_hidden_states = text_encoder(batch["instance_prompt_ids"].to(accelerator.device))[0]
                # Epsilon objective: the UNet learns to predict the added noise.
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            global_step += 1
            print(f"Epoch {epoch + 1}/{num_epochs}, Step {step + 1}, Loss: {loss.item():.6f}")

    # Collect the LoRA weights. Unwrap first: under DDP, accelerate wraps the
    # UNet and named_modules() would prefix every name with "module.".
    unwrapped_unet = accelerator.unwrap_model(unet)
    lora_state_dict = {}
    for name, module in unwrapped_unet.named_modules():
        if hasattr(module, "lora_down") and hasattr(module, "lora_up"):
            # detach + cpu + contiguous so safetensors can serialize them.
            lora_state_dict[f"{name}.lora_down.weight"] = module.lora_down.weight.detach().cpu().contiguous()
            lora_state_dict[f"{name}.lora_up.weight"] = module.lora_up.weight.detach().cpu().contiguous()

    os.makedirs(output_dir, exist_ok=True)  # robustness: caller may not have created it
    lora_path = os.path.join(output_dir, "lora_model.safetensors")
    save_file(lora_state_dict, lora_path)
    return lora_path
# Função para Gradio
def run_training(
    dataset_zip_file,
    resolution,
    learning_rate,
    batch_size,
    num_epochs,
    train_prompt,
):
    """Gradio callback: extract the uploaded dataset ZIP and run LoRA training.

    Returns:
        Tuple of (status message, path to trained model or None).
    """
    if dataset_zip_file is None:
        return "Por favor, faça o upload de um arquivo ZIP com seu dataset.", None

    dataset_dir = "./data/dataset"
    output_dir = "./outputs"
    # Start from clean dataset/output directories on every run.
    for path in (dataset_dir, output_dir):
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    # NOTE(security): extractall() on a user-supplied ZIP is vulnerable to
    # zip-slip (entries with ".." or absolute paths escaping dataset_dir).
    # Acceptable only if uploads are trusted — consider validating entry
    # names before extraction.
    zip_path = dataset_zip_file.name
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dataset_dir)

    try:
        lora_model_path = train_lora(
            instance_data_dir=dataset_dir,
            output_dir=output_dir,
            # Gradio Slider/Number widgets may deliver floats; coerce to the
            # types train_lora expects so e.g. DataLoader(batch_size=...) works.
            resolution=int(resolution),
            learning_rate=float(learning_rate),
            batch_size=int(batch_size),
            num_epochs=int(num_epochs),
            train_prompt=train_prompt,
        )
        return f"✅ Treinamento concluído! Modelo salvo em: {lora_model_path}", lora_model_path
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: surface the error to
        # the user instead of crashing the app.
        return f"❌ Erro durante o treinamento: {str(e)}", None
# Interface Gradio
# Gradio UI: inputs on the left, status/result on the right. Components are
# declared in render order, so the layout matches the declaration sequence.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion")
    with gr.Row():
        with gr.Column():
            # Left column: training inputs.
            zip_input = gr.File(label="📁 Upload do Dataset (ZIP)", file_types=[".zip"])
            resolution_slider = gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="📏 Resolução da Imagem")
            lr_input = gr.Number(value=1e-4, label="📈 Learning Rate")
            batch_slider = gr.Slider(minimum=1, maximum=8, value=1, step=1, label="📦 Batch Size")
            epochs_slider = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="🔁 Número de Epochs")
            prompt_input = gr.Textbox(label="📝 Prompt de Treinamento (ex: a photo of sks dog)", value="a photo of sks dog")
            start_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
        with gr.Column():
            # Right column: training status and downloadable result.
            status_output = gr.Textbox(label="📊 Status do Treinamento", lines=5)
            model_output = gr.File(label="💾 Modelo LoRA Treinado")
    # Wire the button to the training entry point.
    start_button.click(
        run_training,
        inputs=[zip_input, resolution_slider, lr_input, batch_slider, epochs_slider, prompt_input],
        outputs=[status_output, model_output],
    )

if __name__ == "__main__":
    demo.launch(debug=True)