|
|
import argparse |
|
|
import os |
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from peft import LoraConfig, get_peft_model |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
class ImageDataset(Dataset):
    """Image files from a single directory, returned as RGB tensors.

    Every regular file in ``folder`` whose name ends in ``.png``, ``.jpg`` or
    ``.jpeg`` (case-insensitive) is included, in sorted path order so runs are
    reproducible. Each item is resized to ``(size, size)`` (aspect ratio is not
    preserved) and converted with ``ToTensor``, i.e. a float tensor of shape
    ``(3, size, size)`` with values in ``[0, 1]``. Any rescaling the model
    needs (e.g. to ``[-1, 1]`` for a VAE) is left to the caller.
    """

    # Recognised image suffixes; compared against the lower-cased file name.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, folder, size=512):
        self.files = self._list_images(folder)
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ])

    @classmethod
    def _list_images(cls, folder):
        """Return sorted paths of image files directly inside ``folder``."""
        paths = (os.path.join(folder, name) for name in os.listdir(folder))
        return sorted(
            path
            for path in paths
            # Lower-case so "IMG.JPG" is not skipped; isfile rejects
            # directories that merely happen to be named "*.png".
            if path.lower().endswith(cls.IMAGE_EXTENSIONS) and os.path.isfile(path)
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Close the underlying file handle; convert() returns an independent,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)
|
|
|
|
|
def _encode_prompt(pipe, prompt, device):
    """Return the pipeline text encoder's hidden states for a single ``prompt``."""
    tokens = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(tokens.input_ids.to(device))[0]


def _diffusion_loss(pipe, noise_scheduler, images, encoder_hidden_states):
    """Compute the denoising MSE loss for a batch of images with values in [0, 1].

    Images are encoded to VAE latents, noised at random timesteps, and the
    UNet's prediction is compared against the target implied by the
    scheduler's ``prediction_type`` (noise for "epsilon", velocity for
    "v_prediction").
    """
    # The VAE expects pixels in [-1, 1]; ImageDataset yields ToTensor output in [0, 1].
    pixel_values = (images * 2.0 - 1.0).to(dtype=pipe.vae.dtype)
    with torch.no_grad():
        latents = pipe.vae.encode(pixel_values).latent_dist.sample()
        latents = latents * pipe.vae.config.scaling_factor

    noise = torch.randn_like(latents)
    batch_size = latents.shape[0]
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (batch_size,),
        device=latents.device,
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unsupported prediction type: {prediction_type}")

    # The prompt embedding is computed once for batch size 1; broadcast it.
    context = encoder_hidden_states.expand(batch_size, -1, -1)
    model_pred = pipe.unet(noisy_latents, timesteps, context).sample
    return torch.nn.functional.mse_loss(model_pred.float(), target.float())


def main(args):
    """Train a LoRA adapter on the Stable Diffusion v1.5 UNet and save it.

    Args:
        args: namespace with ``images_dir`` (training images), ``output_dir``
            (destination directory), ``learning_rate``, ``num_epochs`` and
            ``rank`` (LoRA rank).

    The frozen base model runs in fp16 on CUDA. The trainable LoRA weights
    are kept in fp32 and trained under autocast with a GradScaler. Training
    is unconditional (empty prompt). Only the LoRA weights are written, as a
    safetensors file ``<output_dir>/lora.safetensors`` in the diffusers LoRA
    format.

    Raises:
        ValueError: if ``images_dir`` contains no usable images.
    """
    # Imported here so the script's top-level imports stay unchanged.
    from diffusers import DDPMScheduler
    from diffusers.utils import convert_state_dict_to_diffusers
    from peft.utils import get_peft_model_state_dict

    device = "cuda"
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo may no longer
    # be available; "stable-diffusion-v1-5/stable-diffusion-v1-5" is the
    # commonly used mirror — confirm which one your environment can reach.
    model_id = "runwayml/stable-diffusion-v1-5"

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
    # A DDPM scheduler supplies the forward (noising) process used in training.
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

    # Only the LoRA weights are trained; everything else stays frozen.
    pipe.vae.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    pipe.unet.requires_grad_(False)

    lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=16,
        target_modules=["to_q", "to_v"],
        lora_dropout=0.1,
        bias="none",
    )
    # Inject LoRA layers in place; pipe.unet stays a UNet (not a causal-LM wrapper).
    pipe.unet.add_adapter(lora_config)

    # Optimizer updates on fp16 weights underflow, and GradScaler refuses to
    # unscale fp16 gradients, so the trainable parameters are kept in fp32.
    trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
    for param in trainable_params:
        param.data = param.data.float()

    dataset = ImageDataset(args.images_dir)
    if len(dataset) == 0:
        raise ValueError(f"No .png/.jpg/.jpeg images found in {args.images_dir!r}")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate)
    scaler = torch.cuda.amp.GradScaler()

    # Unconditional fine-tuning: every image is paired with the empty prompt.
    encoder_hidden_states = _encode_prompt(pipe, "", device)

    pipe.unet.train()
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            images = batch.to(device)
            optimizer.zero_grad()
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                loss = _diffusion_loss(pipe, noise_scheduler, images, encoder_hidden_states)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        print(f"✅ Epoch {epoch+1}/{args.num_epochs} finalizado.")

    pipe.unet.eval()
    os.makedirs(args.output_dir, exist_ok=True)
    # Export only the adapter weights (not the whole UNet) as real safetensors.
    unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(pipe.unet))
    StableDiffusionPipeline.save_lora_weights(
        save_directory=args.output_dir,
        unet_lora_layers=unet_lora_state_dict,
        weight_name="lora.safetensors",
        safe_serialization=True,
    )

    print("✅ Treinamento concluído. Arquivo salvo em lora.safetensors")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: collect training options and hand them to main().
    cli = argparse.ArgumentParser()

    # Mandatory input/output locations, registered in their original order.
    for flag in ("--images_dir", "--output_dir"):
        cli.add_argument(flag, type=str, required=True)

    # Optional hyperparameters with their defaults.
    cli.add_argument("--learning_rate", type=float, default=1e-4)
    cli.add_argument("--num_epochs", type=int, default=10)
    cli.add_argument("--rank", type=int, default=4)

    main(cli.parse_args())