# Gf / app.py
# Last change: "Update app.py" by Allex21 (commit b458dab, verified).
import os
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import safetensors.torch
# Configuration.
# The original "runwayml/stable-diffusion-v1-5" repository is no longer
# available on the Hugging Face Hub; the mirror below hosts the same
# Stable Diffusion 1.5 weights. Both values can be overridden through
# environment variables without editing the code.
MODEL_NAME = os.environ.get(
    "LORA_BASE_MODEL", "stable-diffusion-v1-5/stable-diffusion-v1-5"
)
OUTPUT_DIR = os.environ.get("LORA_OUTPUT_DIR", "lora_output")
# Created at import time so every training run can write into it directly.
os.makedirs(OUTPUT_DIR, exist_ok=True)
class ImageDataset(Dataset):
    """Dataset that pairs every training image with one shared caption.

    Each item is a dict with two keys:

    * ``"pixel_values"`` - the image converted to RGB, resized so its shorter
      side equals ``size``, center-cropped to ``size`` x ``size``, turned into
      a tensor by ``ToTensor`` and normalized with mean/std 0.5 (which maps
      ``ToTensor``'s [0, 1] range to [-1, 1]);
    * ``"caption"`` - the caption string given to the constructor.

    Args:
        image_paths: Paths of the image files to load. Any iterable is
            accepted; it is copied into a list so ``len()`` and indexing work.
        caption: Caption returned together with every image.
        size: Side length, in pixels, of the square output crop.
    """

    def __init__(self, image_paths, caption, size=512):
        self.image_paths = list(image_paths)
        self.caption = caption
        self.size = size
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it with the shared caption."""
        # Open through a context manager so the file handle is closed right
        # away; convert() returns a new, fully loaded image that stays valid
        # after the handle is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        return {"pixel_values": self.transform(image), "caption": self.caption}
def _resolve_image_paths(images):
    """Return filesystem paths for the files delivered by a Gradio ``gr.File``.

    Each uploaded file may arrive either as an object exposing ``.name`` or as
    a plain path string; both are accepted. ``None`` (nothing uploaded)
    yields an empty list.
    """
    return [getattr(img, "name", img) for img in (images or [])]


def _load_pipeline_with_lora(device, dtype, lora_rank):
    """Load the base pipeline on ``device`` and attach fresh LoRA adapters.

    The base UNet, text encoder and VAE are frozen first, so after
    ``add_adapter`` only the newly injected LoRA weights require gradients.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)
    # The VAE is only used to encode images; the base weights must not move.
    pipe.vae.requires_grad_(False)
    pipe.unet.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)

    # LoRA on the UNet attention projections.
    unet_lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
    )
    pipe.unet.add_adapter(unet_lora_config)
    pipe.unet.enable_adapters()

    # LoRA on the text encoder's query/value projections.
    text_encoder_lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
    )
    pipe.text_encoder.add_adapter(text_encoder_lora_config)
    pipe.text_encoder.enable_adapters()
    return pipe


def _trainable_parameters(pipe, use_fp16):
    """Collect the parameters that require gradients (the LoRA weights).

    When the pipeline runs in fp16, those parameters are upcast to fp32:
    AdamW updates applied directly to fp16 weights are numerically unstable.
    The frozen base weights stay in fp16 to save memory.
    """
    params = [
        p
        for model in (pipe.unet, pipe.text_encoder)
        for p in model.parameters()
        if p.requires_grad
    ]
    if use_fp16:
        for p in params:
            p.data = p.data.float()
    return params


def _run_training(pipe, dataloader, params, *, device, dtype, use_fp16,
                  num_epochs, learning_rate):
    """Optimize ``params`` with the noise-prediction MSE objective.

    Prints the loss of every step and the mean loss of every epoch.
    """
    optimizer = torch.optim.AdamW(params, lr=learning_rate)
    # Loss scaling keeps small fp16 gradients from underflowing to zero; with
    # enabled=False (CPU / fp32 runs) it is a transparent pass-through.
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    # Read from the model configs instead of hard-coding 0.18215 / 1000.
    scaling_factor = pipe.vae.config.scaling_factor
    num_train_timesteps = pipe.scheduler.config.num_train_timesteps

    pipe.unet.train()
    pipe.text_encoder.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()

            # Text -> conditioning embeddings (text-encoder LoRA gets gradients).
            text_inputs = pipe.tokenizer(
                batch["caption"],
                padding="max_length",
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            encoder_hidden_states = pipe.text_encoder(
                text_inputs.input_ids.to(device)
            )[0]

            # Image -> latents; the frozen VAE needs no autograd graph.
            pixel_values = batch["pixel_values"].to(device, dtype=dtype)
            with torch.no_grad():
                latents = pipe.vae.encode(pixel_values).latent_dist.sample()
                latents = latents * scaling_factor

            # Forward diffusion: add noise at a random timestep per sample.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],), device=latents.device
            ).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            # Predict the added noise; compute the loss in fp32 for stability.
            noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            print(f"Epoch {epoch+1}, Step {step+1}, Loss: {loss.item():.4f}")
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} finalizado. Loss média: {avg_loss:.4f}")


def _save_lora_weights(pipe, lora_path):
    """Write the trained UNet and text-encoder LoRA weights to ``lora_path``.

    Uses diffusers' own LoRA serialization (safetensors) so the file can be
    loaded back with ``StableDiffusionPipeline.load_lora_weights``.
    """
    from diffusers.utils import convert_state_dict_to_diffusers
    from peft.utils import get_peft_model_state_dict

    unet_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(pipe.unet))
    text_encoder_layers = convert_state_dict_to_diffusers(
        get_peft_model_state_dict(pipe.text_encoder)
    )
    StableDiffusionPipeline.save_lora_weights(
        save_directory=os.path.dirname(lora_path) or ".",
        unet_lora_layers=unet_layers,
        text_encoder_lora_layers=text_encoder_layers,
        weight_name=os.path.basename(lora_path),
        safe_serialization=True,
    )


def train_lora(images, trigger_word, num_epochs=5, learning_rate=1e-4, lora_rank=4):
    """Fine-tune LoRA adapters on Stable Diffusion's UNet and text encoder.

    Every uploaded image is captioned ``"a photo of <trigger_word>"`` and
    trained with the standard noise-prediction MSE objective.

    Args:
        images: Files from the Gradio ``gr.File`` component (objects exposing
            ``.name`` or plain path strings); ``None`` means nothing uploaded.
        trigger_word: Token that will activate the learned concept.
        num_epochs: Passes over the images; coerced to ``int`` because Gradio
            sliders may deliver floats.
        learning_rate: AdamW learning rate.
        lora_rank: LoRA rank, also used as ``lora_alpha``; coerced to ``int``.

    Returns:
        Path of the ``.safetensors`` file written inside ``OUTPUT_DIR``.

    Raises:
        gr.Error: On any failure (including invalid input), with the message
            prefixed by ``"Erro durante o treinamento: "``.
    """
    pipe = params = None
    try:
        # Validate inputs before the (slow) model download / load.
        image_paths = _resolve_image_paths(images)
        if not image_paths:
            raise ValueError("Nenhuma imagem foi enviada.")
        trigger = (trigger_word or "").strip()
        if not trigger:
            raise ValueError("Informe uma trigger word.")
        num_epochs = int(num_epochs)
        lora_rank = int(lora_rank)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Usando dispositivo: {device}")
        # Half precision saves GPU memory; fp16 is poorly supported on CPU
        # (some ops have no Half kernel), so CPU runs use fp32.
        use_fp16 = device == "cuda"
        dtype = torch.float16 if use_fp16 else torch.float32

        pipe = _load_pipeline_with_lora(device, dtype, lora_rank)
        params = _trainable_parameters(pipe, use_fp16)

        dataset = ImageDataset(image_paths, f"a photo of {trigger}")
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
        _run_training(
            pipe, dataloader, params,
            device=device, dtype=dtype, use_fp16=use_fp16,
            num_epochs=num_epochs, learning_rate=learning_rate,
        )

        lora_path = os.path.join(OUTPUT_DIR, "lora_model.safetensors")
        _save_lora_weights(pipe, lora_path)
        return lora_path
    except Exception as e:
        error_msg = f"Erro durante o treinamento: {str(e)}"
        print(error_msg)
        raise gr.Error(error_msg) from e
    finally:
        # Drop the model references on success AND failure so cached GPU
        # memory can actually be released.
        pipe = params = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Gradio UI: the left column collects the training inputs, the right column
# exposes the resulting LoRA file for download.
with gr.Blocks(title="Treinador LoRA HF") as demo:
    gr.Markdown(value="# 🧠 Treinador LoRA para Stable Diffusion")
    gr.Markdown(value="Envie 3-8 imagens do mesmo objeto. Use um trigger word único (ex: `my_cat`).")

    with gr.Row():
        with gr.Column():
            image_input = gr.File(
                file_count="multiple",
                file_types=["image"],
                label="📁 Upload de Imagens (JPG/PNG)",
            )
            trigger_word = gr.Textbox(
                placeholder="ex: my_dog",
                label="🔤 Trigger Word",
            )
            epochs = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="🔁 Epochs (recomendado: 3-5)",
            )
            lr = gr.Number(value=1e-4, precision=6, label="📈 Learning Rate")
            rank = gr.Slider(
                minimum=2,
                maximum=16,
                value=4,
                step=2,
                label="📊 LoRA Rank",
            )
            train_btn = gr.Button(value="🚀 Treinar LoRA", variant="primary")

        with gr.Column():
            output_file = gr.File(label="💾 Download LoRA (.safetensors)")

    # The inputs list must follow train_lora's positional parameter order.
    train_btn.click(
        train_lora,
        inputs=[image_input, trigger_word, epochs, lr, rank],
        outputs=output_file,
    )

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()