|
|
import os |
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from peft import LoraConfig, get_peft_model |
|
|
from transformers import CLIPTextModel |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import gradio as gr |
|
|
import safetensors.torch |
|
|
|
|
|
|
|
|
# Hugging Face Hub id of the base checkpoint to fine-tune (Stable Diffusion v1.5).
MODEL_NAME = "runwayml/stable-diffusion-v1-5"

# Local directory where the trained LoRA weights (.safetensors) are written.
OUTPUT_DIR = "lora_output"

# Create the output directory up front; no-op if it already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
class ImageDataset(Dataset):
    """Dataset that pairs every training image with one fixed caption.

    Each image is resized and center-cropped to a ``size``-pixel square,
    converted to a tensor, and normalized to the [-1, 1] range expected by
    the Stable Diffusion VAE.
    """

    def __init__(self, image_paths, caption, size=512):
        self.image_paths = image_paths
        self.caption = caption
        self.size = size
        # Preprocessing pipeline applied lazily, one image at a time.
        steps = [
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        pixels = self.transform(Image.open(path).convert("RGB"))
        # The caption is identical for every sample by design (single concept).
        return {"pixel_values": pixels, "caption": self.caption}
|
|
|
|
|
def _collect_lora_weights(model, prefix):
    """Gather LoRA A/B matrices from *model* into a flat state dict.

    Keys follow the ``{prefix}{module_name}.lora_A.weight`` convention.
    Tensors are detached and moved to CPU so safetensors can serialize them
    regardless of the training device.
    """
    weights = {}
    for name, module in model.named_modules():
        if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
            weights[f"{prefix}{name}.lora_A.weight"] = module.lora_A["default"].weight.detach().cpu()
            weights[f"{prefix}{name}.lora_B.weight"] = module.lora_B["default"].weight.detach().cpu()
    return weights


def train_lora(images, trigger_word, num_epochs=5, learning_rate=1e-4, lora_rank=4):
    """Fine-tune LoRA adapters (UNet + text encoder) on the uploaded images.

    Args:
        images: Gradio file objects (each exposing a ``.name`` path), or None.
        trigger_word: Unique token inserted into the caption "a photo of {word}".
        num_epochs: Number of passes over the dataset.
        learning_rate: AdamW learning rate for the LoRA parameters.
        lora_rank: Rank (and alpha) of the injected LoRA matrices.

    Returns:
        Path to the saved ``.safetensors`` file with the LoRA weights.

    Raises:
        gr.Error: On any failure (wraps the underlying exception message).
    """
    try:
        # Validate input BEFORE loading the (heavy) pipeline. Gradio passes
        # None when nothing was uploaded, which would otherwise crash the
        # list comprehension with a TypeError instead of this clear message.
        if not images:
            raise ValueError("Nenhuma imagem foi enviada.")
        image_paths = [img.name for img in images]

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 matmuls are unsupported/very slow on CPU; fall back to fp32
        # there. The CUDA path is unchanged (fp16, as before).
        dtype = torch.float16 if device == "cuda" else torch.float32
        print(f"Usando dispositivo: {device}")

        pipe = StableDiffusionPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(device)

        # Inject LoRA into the UNet cross/self-attention projections.
        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            target_modules=["to_q", "to_v", "to_k", "to_out.0"],
            lora_dropout=0.0,
            bias="none",
        )
        pipe.unet.add_adapter(unet_lora_config)
        pipe.unet.enable_adapters()

        # Inject LoRA into the CLIP text encoder attention projections.
        text_encoder_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.0,
            bias="none",
        )
        pipe.text_encoder.add_adapter(text_encoder_lora_config)
        pipe.text_encoder.enable_adapters()

        dataset = ImageDataset(image_paths, f"a photo of {trigger_word}")
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

        # Optimize ONLY the LoRA parameters. add_adapter freezes the base
        # weights (requires_grad=False); handing every parameter to AdamW
        # would needlessly allocate optimizer state for millions of weights.
        params_to_optimize = [
            p
            for p in list(pipe.unet.parameters()) + list(pipe.text_encoder.parameters())
            if p.requires_grad
        ]
        optimizer = torch.optim.AdamW(params_to_optimize, lr=learning_rate)
        # NOTE(review): pure-fp16 AdamW can be numerically unstable; for
        # production consider keeping LoRA params in fp32 or using GradScaler.

        pipe.unet.train()
        pipe.text_encoder.train()

        # Read schedule/scaling constants from the model config instead of
        # hard-coding 1000 and 0.18215.
        num_train_timesteps = pipe.scheduler.config.num_train_timesteps
        scaling_factor = pipe.vae.config.scaling_factor

        for epoch in range(num_epochs):
            total_loss = 0.0
            for step, batch in enumerate(dataloader):
                optimizer.zero_grad()

                # Encode the caption; gradients flow through the text
                # encoder's LoRA layers.
                text_inputs = pipe.tokenizer(
                    batch["caption"],
                    padding="max_length",
                    max_length=pipe.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)
                encoder_hidden_states = pipe.text_encoder(text_input_ids)[0]

                # The VAE is frozen: encode under no_grad to avoid building
                # a useless autograd graph through it.
                pixel_values = batch["pixel_values"].to(device, dtype=dtype)
                with torch.no_grad():
                    latents = pipe.vae.encode(pixel_values).latent_dist.sample()
                latents = latents * scaling_factor

                # Standard DDPM objective: predict the added noise.
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, num_train_timesteps, (latents.shape[0],), device=latents.device
                ).long()
                noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

                noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                loss.backward()
                optimizer.step()

                step_loss = loss.item()
                total_loss += step_loss
                print(f"Epoch {epoch+1}, Step {step+1}, Loss: {step_loss:.4f}")

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{num_epochs} finalizado. Loss média: {avg_loss:.4f}")

        # Export the LoRA weights in the lora_unet_* / lora_te_* layout.
        lora_weights = {}
        lora_weights.update(_collect_lora_weights(pipe.unet, "lora_unet_"))
        lora_weights.update(_collect_lora_weights(pipe.text_encoder, "lora_te_"))

        lora_path = os.path.join(OUTPUT_DIR, "lora_model.safetensors")
        safetensors.torch.save_file(lora_weights, lora_path)

        # Free GPU memory before returning to the UI.
        del pipe, optimizer, dataloader, dataset
        if device == "cuda":
            torch.cuda.empty_cache()

        return lora_path

    except Exception as e:
        error_msg = f"Erro durante o treinamento: {str(e)}"
        print(error_msg)
        raise gr.Error(error_msg)
|
|
|
|
|
|
|
|
# Build the Gradio UI. `demo` is referenced by the __main__ guard below.
with gr.Blocks(title="Treinador LoRA HF") as demo:
    # Header and usage instructions.
    gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion")
    gr.Markdown("Envie 3-8 imagens do mesmo objeto. Use um trigger word único (ex: `my_cat`).")

    with gr.Row():
        # Left column: training inputs and hyperparameters.
        with gr.Column():
            uploaded_images = gr.File(label="📁 Upload de Imagens (JPG/PNG)", file_count="multiple", file_types=["image"])
            trigger_box = gr.Textbox(label="🔤 Trigger Word", placeholder="ex: my_dog")
            epoch_slider = gr.Slider(1, 10, value=3, step=1, label="🔁 Epochs (recomendado: 3-5)")
            lr_number = gr.Number(value=1e-4, label="📈 Learning Rate", precision=6)
            rank_slider = gr.Slider(2, 16, value=4, step=2, label="📊 LoRA Rank")
            start_button = gr.Button("🚀 Treinar LoRA", variant="primary")

        # Right column: download link for the resulting weights file.
        with gr.Column():
            result_file = gr.File(label="💾 Download LoRA (.safetensors)")

    # Wire the button to the training routine; its return path feeds the
    # File component so the user can download the weights.
    start_button.click(
        fn=train_lora,
        inputs=[uploaded_images, trigger_box, epoch_slider, lr_number, rank_slider],
        outputs=result_file,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio server (blocking) when executed as a script.
    demo.launch()