"""Gradio app that trains LoRA adapters for Stable Diffusion on a few uploaded images."""

import contextlib
import os

import gradio as gr
import safetensors.torch
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import CLIPTextModel

# Settings
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "lora_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)


class ImageDataset(Dataset):
    """Dataset of image files that all share a single caption.

    Each item is a dict with ``"pixel_values"`` (the transformed image
    tensor) and ``"caption"`` (the shared caption string).
    """

    def __init__(self, image_paths, caption, size=512):
        self.image_paths = image_paths
        self.caption = caption
        self.size = size
        # Resize + center-crop to a square, then normalize with mean/std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Use a context manager so the underlying file handle is closed
        # as soon as the pixel data has been converted.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        image = self.transform(image)
        return {"pixel_values": image, "caption": self.caption}


def _uploaded_paths(images):
    """Return filesystem paths for uploaded files (objects with ``.name`` or plain strings)."""
    if not images:
        return []
    return [getattr(item, "name", item) for item in images]


def _lora_config(rank, target_modules):
    """Build the LoRA configuration shared by the UNet and the text encoder."""
    return LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
    )


def _autocast(enabled):
    """Return a CUDA fp16 autocast context when enabled, otherwise a no-op context."""
    if enabled:
        return torch.autocast(device_type="cuda", dtype=torch.float16)
    return contextlib.nullcontext()


def _collect_lora_weights(model, prefix):
    """Collect the "default" adapter's lora_A/lora_B weights of ``model`` as CPU tensors."""
    weights = {}
    for name, module in model.named_modules():
        if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
            for part in ("lora_A", "lora_B"):
                tensor = getattr(module, part)["default"].weight
                # Detach and copy to CPU so the file never holds live,
                # grad-tracking device tensors.
                weights[f"{prefix}_{name}.{part}.weight"] = (
                    tensor.detach().cpu().contiguous()
                )
    return weights


def train_lora(images, trigger_word, num_epochs=5, learning_rate=1e-4, lora_rank=4):
    """Train LoRA adapters on the uploaded images and save them to disk.

    Args:
        images: Uploaded files; each item is either an object exposing the
            path as ``.name`` or a plain path string.
        trigger_word: Token inserted into the caption ``"a photo of ..."``.
        num_epochs: Number of passes over the images (coerced to ``int``).
        learning_rate: AdamW learning rate (coerced to ``float``).
        lora_rank: LoRA rank, also used as ``lora_alpha`` (coerced to ``int``).

    Returns:
        Path of the written ``.safetensors`` file inside ``OUTPUT_DIR``.

    Raises:
        gr.Error: On invalid input or any failure during training/saving.
    """
    pipe = optimizer = dataloader = dataset = trainable_params = None
    try:
        # Validate the cheap inputs before loading the (large) model.
        image_paths = _uploaded_paths(images)
        if not image_paths:
            raise ValueError("Nenhuma imagem foi enviada.")
        trigger_word = (trigger_word or "").strip()
        if not trigger_word:
            raise ValueError("Nenhuma trigger word foi informada.")
        # UI widgets may deliver floats; range() and LoraConfig need ints.
        num_epochs = int(num_epochs)
        lora_rank = int(lora_rank)
        learning_rate = float(learning_rate)
        if num_epochs < 1 or lora_rank < 1 or learning_rate <= 0:
            raise ValueError("Epochs, rank e learning rate devem ser positivos.")

        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        print(f"Usando dispositivo: {device}")

        # Half precision saves memory on the GPU; on CPU stay in float32.
        weight_dtype = torch.float16 if use_cuda else torch.float32
        pipe = StableDiffusionPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=weight_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(device)

        # Freeze all base weights; only the LoRA matrices added below train.
        pipe.vae.requires_grad_(False)
        pipe.unet.requires_grad_(False)
        pipe.text_encoder.requires_grad_(False)

        # Enable LoRA on the UNet
        pipe.unet.add_adapter(
            _lora_config(lora_rank, ["to_q", "to_v", "to_k", "to_out.0"])
        )
        pipe.unet.enable_adapters()

        # Enable LoRA on the text encoder
        pipe.text_encoder.add_adapter(_lora_config(lora_rank, ["q_proj", "v_proj"]))
        pipe.text_encoder.enable_adapters()

        # Prepare dataset
        dataset = ImageDataset(image_paths, f"a photo of {trigger_word}")
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

        # Optimizer: only parameters that actually require gradients.
        trainable_params = [
            param
            for model in (pipe.unet, pipe.text_encoder)
            for param in model.parameters()
            if param.requires_grad
        ]
        if use_cuda:
            # Keep the trained weights in fp32 so AdamW updates stay
            # numerically stable; the frozen base model remains fp16 and
            # the forward pass runs under autocast.
            for param in trainable_params:
                param.data = param.data.to(torch.float32)
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
        scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

        # Read model constants from the configs, falling back to the
        # values that were previously hard-coded.
        scaling_factor = getattr(pipe.vae.config, "scaling_factor", 0.18215)
        num_train_timesteps = getattr(
            pipe.scheduler.config, "num_train_timesteps", 1000
        )

        # Simplified training loop
        pipe.unet.train()
        pipe.text_encoder.train()

        for epoch in range(num_epochs):
            total_loss = 0.0
            for step, batch in enumerate(dataloader):
                optimizer.zero_grad()

                # Text
                text_inputs = pipe.tokenizer(
                    batch["caption"],
                    padding="max_length",
                    max_length=pipe.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)

                # Image -> latents, then add noise. The VAE is not trained,
                # so no autograd graph is needed for this part.
                pixel_values = batch["pixel_values"].to(device, dtype=weight_dtype)
                with torch.no_grad():
                    latents = pipe.vae.encode(pixel_values).latent_dist.sample()
                    latents = latents * scaling_factor
                    noise = torch.randn_like(latents)
                    timesteps = torch.randint(
                        0,
                        num_train_timesteps,
                        (latents.shape[0],),
                        device=latents.device,
                    ).long()
                    noisy_latents = pipe.scheduler.add_noise(
                        latents, noise, timesteps
                    )

                # Predict the noise
                with _autocast(use_cuda):
                    encoder_hidden_states = pipe.text_encoder(text_input_ids)[0]
                    noise_pred = pipe.unet(
                        noisy_latents, timesteps, encoder_hidden_states
                    ).sample

                # Compute the loss in fp32 regardless of the model dtype.
                loss = torch.nn.functional.mse_loss(
                    noise_pred.float(), noise.float()
                )
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                total_loss += loss.item()
                print(f"Epoch {epoch+1}, Step {step+1}, Loss: {loss.item():.4f}")

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{num_epochs} finalizado. Loss média: {avg_loss:.4f}")

        # Save the LoRA weights (UNet first, then text encoder).
        # NOTE(review): the key naming is specific to this script; confirm
        # that whatever loads this file expects these key names.
        lora_weights = {
            **_collect_lora_weights(pipe.unet, "lora_unet"),
            **_collect_lora_weights(pipe.text_encoder, "lora_te"),
        }
        lora_path = os.path.join(OUTPUT_DIR, "lora_model.safetensors")
        safetensors.torch.save_file(lora_weights, lora_path)

        return lora_path

    except Exception as e:
        error_msg = f"Erro durante o treinamento: {str(e)}"
        print(error_msg)
        raise gr.Error(error_msg) from e

    finally:
        # Free memory on success and on failure alike.
        pipe = optimizer = dataloader = dataset = trainable_params = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Interface
with gr.Blocks(title="Treinador LoRA HF") as demo:
    gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion")
    gr.Markdown("Envie 3-8 imagens do mesmo objeto. Use um trigger word único (ex: `my_cat`).")

    with gr.Row():
        with gr.Column():
            image_input = gr.File(
                label="📁 Upload de Imagens (JPG/PNG)",
                file_count="multiple",
                file_types=["image"],
            )
            trigger_word = gr.Textbox(label="🔤 Trigger Word", placeholder="ex: my_dog")
            epochs = gr.Slider(1, 10, value=3, step=1, label="🔁 Epochs (recomendado: 3-5)")
            lr = gr.Number(value=1e-4, label="📈 Learning Rate", precision=6)
            rank = gr.Slider(2, 16, value=4, step=2, label="📊 LoRA Rank")
            train_btn = gr.Button("🚀 Treinar LoRA", variant="primary")

        with gr.Column():
            output_file = gr.File(label="💾 Download LoRA (.safetensors)")

    train_btn.click(
        fn=train_lora,
        inputs=[image_input, trigger_word, epochs, lr, rank],
        outputs=output_file,
    )

if __name__ == "__main__":
    demo.launch()