Allex21 commited on
Commit
c13c01c
·
verified ·
1 Parent(s): 6c328ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -238
app.py CHANGED
@@ -1,260 +1,158 @@
1
- import gradio as gr
2
  import os
3
  import torch
4
- from accelerate import Accelerator
5
- from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
6
- from diffusers.optimization import get_scheduler
7
  from PIL import Image
8
- from torch.utils.data import Dataset
9
  from torchvision import transforms
10
- from transformers import CLIPTextModel, CLIPTokenizer
11
- import zipfile
12
- import shutil
13
- from safetensors.torch import save_file
14
- import torch.nn as nn
15
 
16
- # Função para criar camadas LoRA
17
- def create_lora_layers(module, rank=4):
18
- if isinstance(module, nn.Linear):
19
- lora_down = nn.Linear(module.in_features, rank, bias=False)
20
- lora_up = nn.Linear(rank, module.out_features, bias=False)
21
- nn.init.zeros_(lora_up.weight) # Inicialização zero para começar neutro
22
- return lora_down, lora_up
23
- return None, None
24
 
25
- # Dataset simplificado
26
- class DreamBoothDataset(Dataset):
27
- def __init__(self, instance_data_root, tokenizer, size=512, train_prompt="a photo of sks dog"):
28
- self.instance_data_root = instance_data_root
29
- self.tokenizer = tokenizer
30
  self.size = size
31
- self.train_prompt = train_prompt
32
- self.instance_images_path = [
33
- os.path.join(instance_data_root, file_path)
34
- for file_path in os.listdir(instance_data_root)
35
- if file_path.endswith((".png", ".jpg", ".jpeg"))
36
- ]
37
- self.transform = transforms.Compose(
38
- [
39
- transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
40
- transforms.CenterCrop(size),
41
- transforms.ToTensor(),
42
- transforms.Normalize([0.5], [0.5]),
43
- ]
44
- )
45
 
46
  def __len__(self):
47
- return len(self.instance_images_path)
48
-
49
- def __getitem__(self, index):
50
- instance_image = Image.open(self.instance_images_path[index])
51
- if not instance_image.mode == "RGB":
52
- instance_image = instance_image.convert("RGB")
53
- example = {}
54
- example["instance_images"] = self.transform(instance_image)
55
- example["instance_prompt_ids"] = self.tokenizer(
56
- self.train_prompt,
57
- truncation=True,
58
- padding="max_length",
59
- max_length=self.tokenizer.model_max_length,
60
- return_tensors="pt",
61
- ).input_ids[0]
62
- return example
63
-
64
- # Função principal de treinamento
65
- def train_lora(
66
- instance_data_dir: str,
67
- output_dir: str,
68
- resolution: int = 512,
69
- learning_rate: float = 1e-4,
70
- batch_size: int = 1,
71
- num_epochs: int = 1,
72
- train_prompt: str = "a photo of sks dog",
73
- pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5",
74
- ):
75
- # Configurações básicas
76
- accelerator = Accelerator(
77
- gradient_accumulation_steps=1,
78
- mixed_precision="fp16",
79
  )
80
-
81
- # Carregar modelos
82
- tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
83
- text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
84
- vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
85
- unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
86
-
87
- # Congelar VAE e Text Encoder
88
- vae.requires_grad_(False)
89
- text_encoder.requires_grad_(False)
90
- unet.requires_grad_(False)
91
-
92
- # Injetar LoRA no UNet
93
- lora_layers = []
94
- for name, module in unet.named_modules():
95
- if name.endswith("to_q") or name.endswith("to_k") or name.endswith("to_v") or name.endswith("to_out.0"):
96
- lora_down, lora_up = create_lora_layers(module, rank=4)
97
- if lora_down is not None:
98
- module.lora_down = lora_down.to(module.weight.device)
99
- module.lora_up = lora_up.to(module.weight.device)
100
- lora_layers.extend([module.lora_down, module.lora_up])
101
-
102
- # Guardar forward original
103
- if not hasattr(module, "_original_forward"):
104
- module._original_forward = module.forward
105
-
106
- # Criar novo forward com LoRA
107
- def forward_with_lora(self, x):
108
- original_output = self._original_forward(x)
109
- lora_output = self.lora_up(self.lora_down(x))
110
- return original_output + lora_output
111
-
112
- # Associar o novo forward ao módulo
113
- import types
114
- module.forward = types.MethodType(forward_with_lora, module)
115
-
116
- # Liberar apenas parâmetros LoRA
117
- for layer in lora_layers:
118
- layer.requires_grad_(True)
119
-
120
- # Coletar parâmetros treináveis
121
- lora_parameters = []
122
- for layer in lora_layers:
123
- lora_parameters.extend(layer.parameters())
124
-
125
- # Otimizador
126
- optimizer = torch.optim.AdamW(lora_parameters, lr=learning_rate)
127
-
128
- # Scheduler de ruído
129
- noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
130
-
131
- # Scheduler de learning rate
132
- lr_scheduler = get_scheduler(
133
- "constant",
134
- optimizer=optimizer,
135
- num_warmup_steps=0,
136
- num_training_steps=num_epochs * len(os.listdir(instance_data_dir)),
137
  )
138
-
139
- # Dataset e DataLoader
140
- train_dataset = DreamBoothDataset(instance_data_dir, tokenizer, resolution, train_prompt)
141
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
142
-
143
- # Preparar com Accelerator
144
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
145
- unet, optimizer, train_dataloader, lr_scheduler
 
 
146
  )
147
-
 
148
  # Treinamento
149
- global_step = 0
 
 
150
  for epoch in range(num_epochs):
151
- unet.train()
152
- for step, batch in enumerate(train_dataloader):
153
- with accelerator.accumulate(unet):
154
- # Preparar dados
155
- pixel_values = batch["instance_images"].to(accelerator.device)
156
- latents = vae.encode(pixel_values).latent_dist.sample()
157
- latents = latents * vae.config.scaling_factor
158
-
159
- noise = torch.randn_like(latents).to(accelerator.device)
160
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
161
-
162
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
163
-
164
- encoder_hidden_states = text_encoder(batch["instance_prompt_ids"].to(accelerator.device))[0]
165
-
166
- # Predição
167
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
168
-
169
- # Perda
170
- loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
171
-
172
- # Backprop
173
- accelerator.backward(loss)
174
- optimizer.step()
175
- lr_scheduler.step()
176
- optimizer.zero_grad()
177
-
178
- global_step += 1
179
- print(f"Epoch {epoch + 1}/{num_epochs}, Step {step + 1}, Loss: {loss.item():.6f}")
180
-
181
- # Salvar LoRA
182
- lora_state_dict = {}
183
- for name, module in unet.named_modules():
184
- if hasattr(module, "lora_down") and hasattr(module, "lora_up"):
185
- lora_state_dict[f"{name}.lora_down.weight"] = module.lora_down.weight
186
- lora_state_dict[f"{name}.lora_up.weight"] = module.lora_up.weight
187
-
188
- lora_path = os.path.join(output_dir, "lora_model.safetensors")
189
- save_file(lora_state_dict, lora_path)
190
-
 
 
 
 
 
 
 
 
 
 
 
191
  return lora_path
192
 
193
- # Função para Gradio
194
- def run_training(
195
- dataset_zip_file,
196
- resolution,
197
- learning_rate,
198
- batch_size,
199
- num_epochs,
200
- train_prompt,
201
- ):
202
- if dataset_zip_file is None:
203
- return "Por favor, faça o upload de um arquivo ZIP com seu dataset.", None
204
-
205
- # Limpar diretórios anteriores
206
- if os.path.exists("./data/dataset"):
207
- shutil.rmtree("./data/dataset")
208
- if os.path.exists("./outputs"):
209
- shutil.rmtree("./outputs")
210
- os.makedirs("./data/dataset", exist_ok=True)
211
- os.makedirs("./outputs", exist_ok=True)
212
-
213
- # Extrair dataset
214
- dataset_dir = "./data/dataset"
215
- zip_path = dataset_zip_file.name
216
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
217
- zip_ref.extractall(dataset_dir)
218
-
219
- # Treinar
220
- output_dir = "./outputs"
221
- try:
222
- lora_model_path = train_lora(
223
- instance_data_dir=dataset_dir,
224
- output_dir=output_dir,
225
- resolution=resolution,
226
- learning_rate=learning_rate,
227
- batch_size=batch_size,
228
- num_epochs=num_epochs,
229
- train_prompt=train_prompt,
230
- )
231
- return f"✅ Treinamento concluído! Modelo salvo em: {lora_model_path}", lora_model_path
232
- except Exception as e:
233
- return f"❌ Erro durante o treinamento: {str(e)}", None
234
-
235
  # Interface Gradio
236
- with gr.Blocks() as demo:
237
- gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion")
238
-
 
239
  with gr.Row():
240
  with gr.Column():
241
- dataset_zip = gr.File(label="📁 Upload do Dataset (ZIP)", file_types=[".zip"])
242
- resolution = gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="📏 Resolução da Imagem")
243
- learning_rate = gr.Number(value=1e-4, label="📈 Learning Rate")
244
- batch_size = gr.Slider(minimum=1, maximum=8, value=1, step=1, label="📦 Batch Size")
245
- num_epochs = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="🔁 Número de Epochs")
246
- train_prompt = gr.Textbox(label="📝 Prompt de Treinamento (ex: a photo of sks dog)", value="a photo of sks dog")
247
- train_button = gr.Button("🚀 Iniciar Treinamento", variant="primary")
248
-
249
  with gr.Column():
250
- output_text = gr.Textbox(label="📊 Status do Treinamento", lines=5)
251
- output_file = gr.File(label="💾 Modelo LoRA Treinado")
252
-
253
- train_button.click(
254
- run_training,
255
- inputs=[dataset_zip, resolution, learning_rate, batch_size, num_epochs, train_prompt],
256
- outputs=[output_text, output_file],
257
  )
258
 
259
- if __name__ == "__main__":
260
- demo.launch(debug=True)
 
 
1
  import os
2
  import torch
3
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
4
+ from peft import LoraConfig, get_peft_model
5
+ from transformers import CLIPTextModel
6
  from PIL import Image
 
7
  from torchvision import transforms
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import gradio as gr
10
+ import safetensors.torch
 
 
11
 
12
+ # Configurações básicas
13
+ MODEL_NAME = "runwayml/stable-diffusion-v1-5"
14
+ OUTPUT_DIR = "lora_output"
15
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
 
 
 
 
16
 
17
+ class ImageDataset(Dataset):
18
+ def __init__(self, image_paths, caption, size=512):
19
+ self.image_paths = image_paths
20
+ self.caption = caption
 
21
  self.size = size
22
+ self.transform = transforms.Compose([
23
+ transforms.Resize(size),
24
+ transforms.CenterCrop(size),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize([0.5], [0.5]),
27
+ ])
 
 
 
 
 
 
 
 
28
 
29
  def __len__(self):
30
+ return len(self.image_paths)
31
+
32
+ def __getitem__(self, idx):
33
+ image = Image.open(self.image_paths[idx]).convert("RGB")
34
+ image = self.transform(image)
35
+ return {"pixel_values": image, "caption": self.caption}
36
+
37
+ def train_lora(images, trigger_word, num_epochs=10, learning_rate=1e-4, lora_rank=4, batch_size=1):
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
+
40
+ # Carrega o modelo
41
+ pipe = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
42
+ pipe.to(device)
43
+
44
+ # Configura LoRA no UNet
45
+ unet_lora_config = LoraConfig(
46
+ r=lora_rank,
47
+ lora_alpha=lora_rank,
48
+ target_modules=["to_q", "to_v", "to_k", "to_out.0"],
49
+ lora_dropout=0.0,
50
+ bias="none",
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
+ pipe.unet = get_peft_model(pipe.unet, unet_lora_config)
53
+
54
+ # Configura LoRA no Text Encoder (opcional, mas recomendado)
55
+ text_encoder_lora_config = LoraConfig(
56
+ r=lora_rank,
57
+ lora_alpha=lora_rank,
58
+ target_modules=["q_proj", "v_proj"],
59
+ lora_dropout=0.0,
60
+ bias="none",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  )
62
+ pipe.text_encoder = get_peft_model(pipe.text_encoder, text_encoder_lora_config)
63
+
64
+ # Prepara dataset
65
+ image_paths = [img.name for img in images]
66
+ dataset = ImageDataset(image_paths, f"a photo of {trigger_word}")
67
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
68
+
69
+ # Otimizador
70
+ params_to_optimize = (
71
+ list(pipe.unet.parameters()) + list(pipe.text_encoder.parameters())
72
  )
73
+ optimizer = torch.optim.AdamW(params_to_optimize, lr=learning_rate)
74
+
75
  # Treinamento
76
+ pipe.unet.train()
77
+ pipe.text_encoder.train()
78
+
79
  for epoch in range(num_epochs):
80
+ for batch in dataloader:
81
+ optimizer.zero_grad()
82
+
83
+ # Encode texto
84
+ text_inputs = pipe.tokenizer(
85
+ batch["caption"],
86
+ padding="max_length",
87
+ max_length=pipe.tokenizer.model_max_length,
88
+ truncation=True,
89
+ return_tensors="pt",
90
+ )
91
+ text_input_ids = text_inputs.input_ids.to(device)
92
+ encoder_hidden_states = pipe.text_encoder(text_input_ids)[0]
93
+
94
+ # Encode imagem (latentes)
95
+ latents = pipe.vae.encode(batch["pixel_values"].to(device, dtype=torch.float16)).latent_dist.sample()
96
+ latents = latents * 0.18215
97
+
98
+ # Simula timestep e ruído (simplificado para demonstração)
99
+ noise = torch.randn_like(latents)
100
+ timesteps = torch.randint(0, 1000, (latents.shape[0],), device=latents.device).long()
101
+ noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
102
+
103
+ # Predição
104
+ noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
105
+
106
+ # Loss e backward
107
+ loss = torch.nn.functional.mse_loss(noise_pred, noise)
108
+ loss.backward()
109
+ optimizer.step()
110
+
111
+ print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}")
112
+
113
+ # Salva LoRA
114
+ lora_weights = {}
115
+ for name, module in pipe.unet.named_modules():
116
+ if hasattr(module, "lora_A"):
117
+ lora_weights[f"lora_unet_{name}.lora_A.weight"] = module.lora_A.default.weight
118
+ lora_weights[f"lora_unet_{name}.lora_B.weight"] = module.lora_B.default.weight
119
+
120
+ for name, module in pipe.text_encoder.named_modules():
121
+ if hasattr(module, "lora_A"):
122
+ lora_weights[f"lora_te_{name}.lora_A.weight"] = module.lora_A.default.weight
123
+ lora_weights[f"lora_te_{name}.lora_B.weight"] = module.lora_B.default.weight
124
+
125
+ lora_path = os.path.join(OUTPUT_DIR, "lora_model.safetensors")
126
+ safetensors.torch.save_file(lora_weights, lora_path)
127
+
128
+ del pipe
129
+ torch.cuda.empty_cache()
130
+
131
  return lora_path
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  # Interface Gradio
134
+ with gr.Blocks(title="Treinador LoRA Simplificado") as demo:
135
+ gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion (Hugging Face)")
136
+ gr.Markdown("Faça upload de 3-10 imagens do mesmo conceito. Use um 'trigger word' único (ex: `shs_dog`).")
137
+
138
  with gr.Row():
139
  with gr.Column():
140
+ image_input = gr.File(label="📁 Faça upload das imagens (JPG/PNG)", file_count="multiple", file_types=["image"])
141
+ trigger_word = gr.Textbox(label="🔤 Trigger Word (ex: my_cat)", placeholder="shs_dog")
142
+ epochs = gr.Slider(1, 50, value=10, step=1, label="🔁 Número de Epochs")
143
+ lr = gr.Number(value=1e-4, label="📈 Taxa de Aprendizado")
144
+ rank = gr.Slider(2, 32, value=4, step=2, label="📊 Rank da LoRA")
145
+ batch = gr.Slider(1, 4, value=1, step=1, label="📦 Batch Size (mantenha 1 no HF)")
146
+ train_btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
147
+
148
  with gr.Column():
149
+ output_file = gr.File(label="💾 Download da LoRA Treinada (.safetensors)")
150
+ log_box = gr.Textbox(label="📋 Log de Treinamento", lines=10)
151
+
152
+ train_btn.click(
153
+ fn=train_lora,
154
+ inputs=[image_input, trigger_word, epochs, lr, rank, batch],
155
+ outputs=output_file
156
  )
157
 
158
+ demo.launch()