|
|
|
|
|
import os |
|
|
import torch |
|
|
import argparse |
|
|
from diffusers import StableDiffusionPipeline, DDPMScheduler |
|
|
from peft import LoraConfig, get_peft_model |
|
|
from accelerate import Accelerator |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import glob |
|
|
|
|
|
|
|
|
def main(args):
    """Fine-tune a LoRA adapter on a Stable Diffusion UNet from an image/caption folder.

    Expects ``args`` with: model_name, dataset_dir, output_dir, lora_rank,
    lora_alpha, learning_rate, num_epochs, batch_size, mixed_precision.
    Images are read from ``dataset_dir``; an optional ``<name>.txt`` sidecar
    next to each image supplies its caption. The trained adapter is written
    to ``output_dir`` via ``save_pretrained``.
    """
    accelerator = Accelerator(
        mixed_precision="fp16" if args.mixed_precision else None
    )

    # Single source of truth for the compute dtype of the frozen models.
    weight_dtype = torch.float16 if args.mixed_precision else torch.float32

    print(f"🚀 Carregando modelo: {args.model_name}")
    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            args.model_name,
            torch_dtype=weight_dtype,
        )
    except Exception as e:
        print(f"❌ Falha ao carregar modelo: {e}")
        return

    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

    # Only the LoRA adapter weights are trained; freeze everything else
    # explicitly so no gradients are ever tracked for the frozen models.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()

    # SD 1.x convention: 512x512 inputs normalized to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    image_paths = []
    for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.webp"]:
        image_paths.extend(glob.glob(os.path.join(args.dataset_dir, ext)))

    if len(image_paths) == 0:
        print("❌ Nenhuma imagem encontrada no diretório!")
        return

    print(f"✅ {len(image_paths)} imagens encontradas. Carregando legendas...")

    captions = []
    valid_images = []
    for img_path in image_paths:
        txt_path = os.path.splitext(img_path)[0] + ".txt"
        if os.path.exists(txt_path):
            with open(txt_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()
        else:
            caption = "person"  # fallback caption when no .txt sidecar exists
        captions.append(caption)
        valid_images.append(img_path)

    class SimpleDataset(torch.utils.data.Dataset):
        # Pairs each transformed image tensor with its raw caption string;
        # captions are tokenized per-batch inside the training loop.
        def __init__(self, image_list, caption_list, transform):
            self.images = image_list
            self.captions = caption_list
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            image = Image.open(self.images[idx]).convert("RGB")
            image = self.transform(image)
            return {"pixel_values": image, "input_ids": self.captions[idx]}

    dataset = SimpleDataset(valid_images, captions, transform)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # BUG FIX: optimize only the trainable (LoRA) parameters rather than the
    # full UNet parameter list, so AdamW does not allocate optimizer state
    # for millions of frozen weights.
    optimizer = torch.optim.AdamW(
        [p for p in unet.parameters() if p.requires_grad],
        lr=args.learning_rate,
    )

    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)

    # BUG FIX: the frozen models were never moved to the training device;
    # once batches land on GPU, vae.encode / text_encoder would fail with a
    # device mismatch. accelerator.prepare only handled the UNet.
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.eval()
    text_encoder.eval()

    unet.train()
    step = 0
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            with accelerator.accumulate(unet):

                # Encode images to latents with the frozen VAE. BUG FIX: cast
                # pixel values to the VAE dtype (fp16 runs would otherwise
                # crash on a float32/float16 mismatch) and skip autograd for
                # the frozen forward pass. 0.18215 is the SD VAE scaling factor.
                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
                with torch.no_grad():
                    latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215

                # Standard DDPM noising at a uniformly random timestep.
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                )
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Tokenize the batch captions and embed them with the frozen
                # text encoder (model_max_length is 77 for SD 1.x CLIP).
                inputs = tokenizer(
                    batch["input_ids"],
                    max_length=tokenizer.model_max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                ).to(latents.device)
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(**inputs)[0]

                # Predict the added noise and regress against it. Cast to
                # float32 for a numerically stable loss under fp16.
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
                step += 1

        print(f"Epoch {epoch+1}/{args.num_epochs} - Loss: {loss.item():.4f}")

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        output_dir = args.output_dir
        unwrapped_unet = accelerator.unwrap_model(unet)
        # save_pretrained on a PEFT model persists only the adapter weights.
        unwrapped_unet.save_pretrained(output_dir)
        print(f"✅ Modelo LoRA salvo em: {output_dir}")
        # BUG FIX: removed `from peft import save_model; save_model(...)` —
        # peft exposes no such function, so the original raised ImportError
        # right after training; save_pretrained above already saves the LoRA.
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: collect hyperparameters and hand them to main().
    cli = argparse.ArgumentParser(description="Treina um modelo LoRA para Stable Diffusion")

    # Paths and base model.
    cli.add_argument("--model_name", type=str, default="runwayml/stable-diffusion-v1-5", help="Modelo base do HF")
    cli.add_argument("--dataset_dir", type=str, required=True, help="Pasta com imagens e .txt")
    cli.add_argument("--output_dir", type=str, default="lora-output", help="Onde salvar o LoRA")

    # LoRA hyperparameters.
    cli.add_argument("--lora_rank", type=int, default=4, help="Rank LoRA (4-64)")
    cli.add_argument("--lora_alpha", type=int, default=32, help="Alpha LoRA")

    # Optimization settings.
    cli.add_argument("--learning_rate", type=float, default=1e-4, help="Taxa de aprendizado")
    cli.add_argument("--num_epochs", type=int, default=10, help="Número de épocas")
    cli.add_argument("--batch_size", type=int, default=1, help="Batch size")
    cli.add_argument("--mixed_precision", action="store_true", help="Usa FP16")

    main(cli.parse_args())