| import gradio as gr |
| import os |
| import torch |
| from accelerate import Accelerator |
| from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler |
| from diffusers.optimization import get_scheduler |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from torchvision import transforms |
| from transformers import CLIPTextModel, CLIPTokenizer |
| import zipfile |
| import shutil |
| from safetensors.torch import save_file |
| import torch.nn as nn |
|
|
| |
def create_lora_layers(module, rank=4):
    """Build a (down, up) LoRA adapter pair for a linear layer.

    Returns ``(None, None)`` when *module* is not an ``nn.Linear``. The
    up-projection weight starts at zero, so the adapter contributes nothing
    until it is trained.
    """
    if not isinstance(module, nn.Linear):
        return None, None
    down_proj = nn.Linear(module.in_features, rank, bias=False)
    up_proj = nn.Linear(rank, module.out_features, bias=False)
    # Zero-init the up side: original_output + up(down(x)) == original_output
    # at step 0, the standard LoRA initialization.
    nn.init.zeros_(up_proj.weight)
    return down_proj, up_proj
|
|
| |
class DreamBoothDataset(Dataset):
    """Image dataset for DreamBooth-style LoRA fine-tuning.

    Every image found directly under *instance_data_root* is paired with the
    same tokenized *train_prompt*; images are resized/center-cropped to
    ``size`` and normalized to [-1, 1].
    """

    # Accepted image file extensions (matched case-insensitively).
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, instance_data_root, tokenizer, size=512, train_prompt="a photo of sks dog"):
        self.instance_data_root = instance_data_root
        self.tokenizer = tokenizer
        self.size = size
        self.train_prompt = train_prompt
        # BUGFIX: os.listdir() order is filesystem-dependent, which made the
        # sample order nondeterministic across runs/OSes — sort for
        # reproducibility. Also match extensions case-insensitively so
        # .PNG/.JPG files are no longer silently skipped.
        self.instance_images_path = sorted(
            os.path.join(instance_data_root, file_path)
            for file_path in os.listdir(instance_data_root)
            if file_path.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),  # [0,1] -> [-1,1], VAE input range
            ]
        )

    def __len__(self):
        return len(self.instance_images_path)

    def __getitem__(self, index):
        """Return one training example: transformed image + prompt token ids."""
        instance_image = Image.open(self.instance_images_path[index])
        # Normalize palette / RGBA / grayscale images to 3-channel RGB.
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example = {}
        example["instance_images"] = self.transform(instance_image)
        # Single fixed prompt, padded to the tokenizer's max length so every
        # batch element has identical sequence length.
        example["instance_prompt_ids"] = self.tokenizer(
            self.train_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]
        return example
|
|
| |
def _inject_lora_adapters(unet, rank=4):
    """Attach trainable LoRA adapters to the UNet attention projections.

    Monkey-patches the forward of every ``to_q``/``to_k``/``to_v``/
    ``to_out.0`` linear so it adds the low-rank update on top of the frozen
    projection. Returns the list of injected LoRA modules — the only
    parameters that should be trained.
    """
    import types

    lora_layers = []
    target_suffixes = ("to_q", "to_k", "to_v", "to_out.0")
    for name, module in unet.named_modules():
        if not name.endswith(target_suffixes):
            continue
        lora_down, lora_up = create_lora_layers(module, rank=rank)
        if lora_down is None:
            continue
        # Registering them as attributes makes them submodules of the UNet,
        # so .to(device) / accelerator.prepare() moves them too.
        module.lora_down = lora_down.to(module.weight.device)
        module.lora_up = lora_up.to(module.weight.device)
        lora_layers.extend([module.lora_down, module.lora_up])

        # Keep a handle on the unpatched forward so patching is idempotent.
        if not hasattr(module, "_original_forward"):
            module._original_forward = module.forward

        def forward_with_lora(self, x):
            # Frozen projection + low-rank residual.
            return self._original_forward(x) + self.lora_up(self.lora_down(x))

        module.forward = types.MethodType(forward_with_lora, module)
    return lora_layers


def _export_lora_weights(unet, output_dir):
    """Collect all LoRA weights from *unet* and save them as safetensors.

    Returns the path of the written file.
    """
    lora_state_dict = {}
    for name, module in unet.named_modules():
        if hasattr(module, "lora_down") and hasattr(module, "lora_up"):
            # BUGFIX: detach to plain contiguous CPU tensors — safetensors
            # refuses tensors that still require grad or are non-contiguous.
            lora_state_dict[f"{name}.lora_down.weight"] = (
                module.lora_down.weight.detach().cpu().contiguous()
            )
            lora_state_dict[f"{name}.lora_up.weight"] = (
                module.lora_up.weight.detach().cpu().contiguous()
            )
    # BUGFIX: make sure the destination directory exists before writing.
    os.makedirs(output_dir, exist_ok=True)
    lora_path = os.path.join(output_dir, "lora_model.safetensors")
    save_file(lora_state_dict, lora_path)
    return lora_path


def train_lora(
    instance_data_dir: str,
    output_dir: str,
    resolution: int = 512,
    learning_rate: float = 1e-4,
    batch_size: int = 1,
    num_epochs: int = 1,
    train_prompt: str = "a photo of sks dog",
    pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5",
):
    """Fine-tune LoRA adapters on a Stable Diffusion UNet (DreamBooth-style).

    Only the injected low-rank adapters are trained; the VAE, text encoder
    and base UNet weights stay frozen.

    Returns:
        Path of the saved ``lora_model.safetensors`` inside *output_dir*.
    """
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
    )

    # Load the base pipeline components.
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")

    # Freeze everything; only the LoRA adapters train.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # BUGFIX: vae/text_encoder are not passed through accelerator.prepare(),
    # so they must be moved to the training device explicitly — otherwise
    # they stay on CPU and the forward pass crashes on GPU runs.
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    lora_layers = _inject_lora_adapters(unet, rank=4)
    for layer in lora_layers:
        layer.requires_grad_(True)
    lora_parameters = [p for layer in lora_layers for p in layer.parameters()]

    optimizer = torch.optim.AdamW(lora_parameters, lr=learning_rate)
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(instance_data_dir, tokenizer, resolution, train_prompt)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # BUGFIX: schedule over the actual number of optimizer steps
    # (batches per epoch x epochs), not the raw directory entry count,
    # which included non-image files and ignored the batch size.
    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_epochs * len(train_dataloader),
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0
    for epoch in range(num_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                pixel_values = batch["instance_images"].to(accelerator.device)

                # The VAE and text encoder are frozen — run them without
                # building an autograd graph to save memory.
                with torch.no_grad():
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor
                    encoder_hidden_states = text_encoder(
                        batch["instance_prompt_ids"].to(accelerator.device)
                    )[0]

                # Sample noise and a random timestep per example, then add
                # the noise according to the DDPM forward process.
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (latents.shape[0],),
                    device=latents.device,
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Epsilon-prediction objective (the SD v1.5 default).
                loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            global_step += 1
            print(f"Epoch {epoch + 1}/{num_epochs}, Step {step + 1}, Loss: {loss.item():.6f}")

    # BUGFIX: unwrap before export so state-dict keys are not prefixed by
    # accelerate's wrapper module in distributed runs.
    return _export_lora_weights(accelerator.unwrap_model(unet), output_dir)
|
|
| |
def _find_image_dir(root):
    """Return the first directory under *root* that contains image files.

    ZIP archives frequently wrap the images in a top-level folder; walking
    the tree finds them wherever they landed. Falls back to *root* itself
    when no images are found (the downstream error message then applies).
    """
    image_extensions = (".png", ".jpg", ".jpeg")
    for dirpath, _dirnames, filenames in os.walk(root):
        if any(f.endswith(image_extensions) for f in filenames):
            return dirpath
    return root


def run_training(
    dataset_zip_file,
    resolution,
    learning_rate,
    batch_size,
    num_epochs,
    train_prompt,
):
    """Gradio callback: unpack the uploaded dataset and run LoRA training.

    Returns a (status message, lora file path or None) tuple for the UI.
    """
    if dataset_zip_file is None:
        return "Por favor, faça o upload de um arquivo ZIP com seu dataset.", None

    # Start each run from a clean slate.
    for path in ("./data/dataset", "./outputs"):
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    # NOTE(review): extractall on user-supplied archives relies on zipfile's
    # built-in path sanitization — acceptable for a local tool.
    dataset_dir = "./data/dataset"
    zip_path = dataset_zip_file.name
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dataset_dir)

    # BUGFIX: if the ZIP nests the images inside a folder, the dataset root
    # used to be empty — locate the directory that actually holds them.
    instance_dir = _find_image_dir(dataset_dir)

    output_dir = "./outputs"
    try:
        lora_model_path = train_lora(
            instance_data_dir=instance_dir,
            output_dir=output_dir,
            # BUGFIX: Gradio sliders/number inputs deliver floats, but
            # DataLoader(batch_size=...) and range(num_epochs) require ints.
            resolution=int(resolution),
            learning_rate=float(learning_rate),
            batch_size=int(batch_size),
            num_epochs=int(num_epochs),
            train_prompt=train_prompt,
        )
        return f"✅ Treinamento concluído! Modelo salvo em: {lora_model_path}", lora_model_path
    except Exception as e:
        # Surface the failure in the UI instead of crashing the server.
        return f"❌ Erro durante o treinamento: {str(e)}", None
|
|
| |
# Build the Gradio UI: dataset/hyper-parameter controls on the left,
# training status and the resulting LoRA file on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion")

    with gr.Row():
        with gr.Column():
            # Inputs feeding run_training(), in positional order.
            zip_input = gr.File(label="📁 Upload do Dataset (ZIP)", file_types=[".zip"])
            resolution_slider = gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="📏 Resolução da Imagem")
            lr_input = gr.Number(value=1e-4, label="📈 Learning Rate")
            batch_slider = gr.Slider(minimum=1, maximum=8, value=1, step=1, label="📦 Batch Size")
            epochs_slider = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="🔁 Número de Epochs")
            prompt_box = gr.Textbox(label="📝 Prompt de Treinamento (ex: a photo of sks dog)", value="a photo of sks dog")
            start_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")

        with gr.Column():
            # Outputs written by run_training().
            status_box = gr.Textbox(label="📊 Status do Treinamento", lines=5)
            model_file = gr.File(label="💾 Modelo LoRA Treinado")

    start_button.click(
        run_training,
        inputs=[zip_input, resolution_slider, lr_input, batch_slider, epochs_slider, prompt_box],
        outputs=[status_box, model_file],
    )


if __name__ == "__main__":
    demo.launch(debug=True)