Allex21 committed on
Commit
b458dab
·
verified ·
1 Parent(s): c13c01c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -114
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import torch
3
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel
4
  from peft import LoraConfig, get_peft_model
5
  from transformers import CLIPTextModel
6
  from PIL import Image
@@ -9,7 +9,7 @@ from torch.utils.data import Dataset, DataLoader
9
  import gradio as gr
10
  import safetensors.torch
11
 
12
- # Configurações básicas
13
  MODEL_NAME = "runwayml/stable-diffusion-v1-5"
14
  OUTPUT_DIR = "lora_output"
15
  os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -34,125 +34,149 @@ class ImageDataset(Dataset):
34
  image = self.transform(image)
35
  return {"pixel_values": image, "caption": self.caption}
36
 
37
def train_lora(images, trigger_word, num_epochs=10, learning_rate=1e-4, lora_rank=4, batch_size=1):
    """Fine-tune LoRA adapters on Stable Diffusion and save them to disk.

    Args:
        images: Uploaded files from the Gradio ``gr.File`` input. Items may
            be path strings / ``os.PathLike`` objects or tempfile-like
            objects exposing ``.name``.
        trigger_word: Token inserted into the caption ``"a photo of <word>"``.
        num_epochs: Number of passes over the uploaded images.
        learning_rate: AdamW learning rate.
        lora_rank: LoRA rank; also used as ``lora_alpha``.
        batch_size: DataLoader batch size.

    Returns:
        Path of the written ``.safetensors`` file inside ``OUTPUT_DIR``.

    Raises:
        ValueError: If no image was uploaded.
    """
    from contextlib import nullcontext

    # Validate cheap inputs before the multi-gigabyte model is loaded.
    image_paths = [
        os.fspath(item) if isinstance(item, (str, os.PathLike)) else item.name
        for item in (images or [])
    ]
    if not image_paths:
        raise ValueError("Nenhuma imagem foi enviada.")

    # Gradio sliders may deliver floats; range() and LoraConfig need ints.
    num_epochs = int(num_epochs)
    lora_rank = int(lora_rank)
    batch_size = int(batch_size)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only used on GPU; on CPU the model stays in float32.
    use_half = device == "cuda"
    weight_dtype = torch.float16 if use_half else torch.float32

    pipe = None
    try:
        # Load the base model.
        pipe = StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=weight_dtype)
        pipe.to(device)

        # Freeze the base weights; only the LoRA matrices added below train.
        pipe.vae.requires_grad_(False)
        pipe.unet.requires_grad_(False)
        pipe.text_encoder.requires_grad_(False)

        # LoRA on the UNet attention projections.
        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            target_modules=["to_q", "to_v", "to_k", "to_out.0"],
            lora_dropout=0.0,
            bias="none",
        )
        pipe.unet = get_peft_model(pipe.unet, unet_lora_config)

        # LoRA on the text encoder (optional, but recommended).
        text_encoder_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.0,
            bias="none",
        )
        pipe.text_encoder = get_peft_model(pipe.text_encoder, text_encoder_lora_config)

        # Dataset: every image shares the same caption.
        dataset = ImageDataset(image_paths, f"a photo of {trigger_word}")
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Optimise only the trainable (LoRA) parameters, kept in float32 so
        # AdamW does not step half-precision weights directly.
        candidates = list(pipe.unet.parameters()) + list(pipe.text_encoder.parameters())
        trainable_params = [p for p in candidates if p.requires_grad]
        if use_half:
            for param in trainable_params:
                param.data = param.data.float()
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

        pipe.unet.train()
        pipe.text_encoder.train()

        # Read model constants from the configs instead of hard-coding them.
        latent_scale = pipe.vae.config.scaling_factor
        num_train_timesteps = pipe.scheduler.config.num_train_timesteps

        for epoch in range(num_epochs):
            for batch in dataloader:
                optimizer.zero_grad()

                # Tokenise the captions.
                text_inputs = pipe.tokenizer(
                    batch["caption"],
                    padding="max_length",
                    max_length=pipe.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)

                # The VAE is not trained, so no autograd graph is needed.
                with torch.no_grad():
                    pixel_values = batch["pixel_values"].to(device, dtype=weight_dtype)
                    latents = pipe.vae.encode(pixel_values).latent_dist.sample()
                    latents = latents * latent_scale

                # Sample noise and a random timestep per image.
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, num_train_timesteps, (latents.shape[0],), device=latents.device
                ).long()
                noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

                # Autocast lets fp32 LoRA weights mix with the fp16 base model.
                amp_ctx = (
                    torch.autocast(device_type="cuda", dtype=torch.float16)
                    if use_half
                    else nullcontext()
                )
                with amp_ctx:
                    encoder_hidden_states = pipe.text_encoder(text_input_ids)[0]
                    noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Compute the loss in float32 for numerical stability.
                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())
                loss.backward()
                optimizer.step()

            # Reports the loss of the last batch of the epoch.
            print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}")

        # Collect the LoRA matrices of the "default" adapter.
        lora_weights = {}
        for name, module in pipe.unet.named_modules():
            if hasattr(module, "lora_A"):
                lora_weights[f"lora_unet_{name}.lora_A.weight"] = (
                    module.lora_A.default.weight.detach().cpu().contiguous()
                )
                lora_weights[f"lora_unet_{name}.lora_B.weight"] = (
                    module.lora_B.default.weight.detach().cpu().contiguous()
                )

        for name, module in pipe.text_encoder.named_modules():
            if hasattr(module, "lora_A"):
                lora_weights[f"lora_te_{name}.lora_A.weight"] = (
                    module.lora_A.default.weight.detach().cpu().contiguous()
                )
                lora_weights[f"lora_te_{name}.lora_B.weight"] = (
                    module.lora_B.default.weight.detach().cpu().contiguous()
                )

        lora_path = os.path.join(OUTPUT_DIR, "lora_model.safetensors")
        safetensors.torch.save_file(lora_weights, lora_path)
        return lora_path
    finally:
        # Release the pipeline even when training fails part-way through.
        pipe = None
        torch.cuda.empty_cache()
132
-
133
# Gradio interface
with gr.Blocks(title="Treinador LoRA Simplificado") as demo:
    gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion (Hugging Face)")
    gr.Markdown("Faça upload de 3-10 imagens do mesmo conceito. Use um 'trigger word' único (ex: `shs_dog`).")

    with gr.Row():
        # Left column: training inputs.
        with gr.Column():
            image_input = gr.File(label="📁 Faça upload das imagens (JPG/PNG)", file_count="multiple", file_types=["image"])
            trigger_word = gr.Textbox(label="🔤 Trigger Word (ex: my_cat)", placeholder="shs_dog")
            epochs = gr.Slider(1, 50, value=10, step=1, label="🔁 Número de Epochs")
            lr = gr.Number(value=1e-4, label="📈 Taxa de Aprendizado")
            rank = gr.Slider(2, 32, value=4, step=2, label="📊 Rank da LoRA")
            batch = gr.Slider(1, 4, value=1, step=1, label="📦 Batch Size (mantenha 1 no HF)")
            train_btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")

        # Right column: results.
        with gr.Column():
            output_file = gr.File(label="💾 Download da LoRA Treinada (.safetensors)")
            # NOTE(review): log_box is displayed but not wired to any
            # callback output below, so it stays empty.
            log_box = gr.Textbox(label="📋 Log de Treinamento", lines=10)

    # The input order must match the positional parameters of train_lora.
    train_btn.click(
        fn=train_lora,
        inputs=[image_input, trigger_word, epochs, lr, rank, batch],
        outputs=output_file,
    )

# Training runs for minutes; the request queue keeps the call from being
# cut off by the plain HTTP request timeout.
demo.queue().launch()
 
 
1
  import os
2
  import torch
3
+ from diffusers import StableDiffusionPipeline
4
  from peft import LoraConfig, get_peft_model
5
  from transformers import CLIPTextModel
6
  from PIL import Image
 
9
  import gradio as gr
10
  import safetensors.torch
11
 
12
+ # Configurações
13
  MODEL_NAME = "runwayml/stable-diffusion-v1-5"
14
  OUTPUT_DIR = "lora_output"
15
  os.makedirs(OUTPUT_DIR, exist_ok=True)
 
34
  image = self.transform(image)
35
  return {"pixel_values": image, "caption": self.caption}
36
 
37
def _resolve_image_paths(images):
    """Return filesystem paths for the uploaded files (empty list if none).

    Accepts path strings / ``os.PathLike`` objects as well as tempfile-like
    objects that expose the path through ``.name``.
    """
    return [
        os.fspath(item) if isinstance(item, (str, os.PathLike)) else item.name
        for item in (images or [])
    ]


def train_lora(images, trigger_word, num_epochs=5, learning_rate=1e-4, lora_rank=4):
    """Fine-tune LoRA adapters on Stable Diffusion and save them to disk.

    Args:
        images: Uploaded files from the Gradio ``gr.File`` input.
        trigger_word: Token inserted into the caption ``"a photo of <word>"``.
        num_epochs: Number of passes over the uploaded images.
        learning_rate: AdamW learning rate.
        lora_rank: LoRA rank; also used as ``lora_alpha``.

    Returns:
        Path of the written ``.safetensors`` file inside ``OUTPUT_DIR``.

    Raises:
        gr.Error: On any failure (invalid input included), so the message is
            shown in the UI; the original exception is chained as the cause.
    """
    from contextlib import nullcontext

    pipe = None
    optimizer = None
    try:
        # Validate cheap inputs before the multi-gigabyte model is loaded.
        image_paths = _resolve_image_paths(images)
        if not image_paths:
            raise ValueError("Nenhuma imagem foi enviada.")
        trigger_word = (trigger_word or "").strip()
        if not trigger_word:
            raise ValueError("O trigger word não pode ficar vazio.")

        # Gradio sliders may deliver floats; range() and LoraConfig need ints.
        num_epochs = int(num_epochs)
        lora_rank = int(lora_rank)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Usando dispositivo: {device}")

        # Half precision saves memory on GPU; on CPU the model stays float32.
        use_half = device == "cuda"
        weight_dtype = torch.float16 if use_half else torch.float32

        pipe = StableDiffusionPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=weight_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(device)

        # Freeze the base weights; only the LoRA matrices added below train.
        pipe.vae.requires_grad_(False)
        pipe.unet.requires_grad_(False)
        pipe.text_encoder.requires_grad_(False)

        # LoRA on the UNet attention projections.
        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            target_modules=["to_q", "to_v", "to_k", "to_out.0"],
            lora_dropout=0.0,
            bias="none",
        )
        pipe.unet.add_adapter(unet_lora_config)
        pipe.unet.enable_adapters()

        # LoRA on the text encoder.
        text_encoder_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_rank,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.0,
            bias="none",
        )
        pipe.text_encoder.add_adapter(text_encoder_lora_config)
        pipe.text_encoder.enable_adapters()

        # Dataset: every image shares the same caption.
        dataset = ImageDataset(image_paths, f"a photo of {trigger_word}")
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

        # Optimise only the trainable (LoRA) parameters, kept in float32 so
        # AdamW does not step half-precision weights directly.
        candidates = list(pipe.unet.parameters()) + list(pipe.text_encoder.parameters())
        trainable_params = [p for p in candidates if p.requires_grad]
        if use_half:
            for param in trainable_params:
                param.data = param.data.float()
        optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

        pipe.unet.train()
        pipe.text_encoder.train()

        # Read model constants from the configs instead of hard-coding them.
        latent_scale = pipe.vae.config.scaling_factor
        num_train_timesteps = pipe.scheduler.config.num_train_timesteps

        for epoch in range(num_epochs):
            total_loss = 0.0
            for step, batch in enumerate(dataloader):
                optimizer.zero_grad()

                # Tokenise the caption.
                text_inputs = pipe.tokenizer(
                    batch["caption"],
                    padding="max_length",
                    max_length=pipe.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids.to(device)

                # The VAE is not trained, so no autograd graph is needed.
                with torch.no_grad():
                    pixel_values = batch["pixel_values"].to(device, dtype=weight_dtype)
                    latents = pipe.vae.encode(pixel_values).latent_dist.sample()
                    latents = latents * latent_scale

                # Sample noise and a random timestep per image.
                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, num_train_timesteps, (latents.shape[0],), device=latents.device
                ).long()
                noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

                # Autocast lets fp32 LoRA weights mix with the fp16 base model.
                amp_ctx = (
                    torch.autocast(device_type="cuda", dtype=torch.float16)
                    if use_half
                    else nullcontext()
                )
                with amp_ctx:
                    encoder_hidden_states = pipe.text_encoder(text_input_ids)[0]
                    noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Compute the loss in float32 for numerical stability.
                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                print(f"Epoch {epoch+1}, Step {step+1}, Loss: {loss.item():.4f}")

            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{num_epochs} finalizado. Loss média: {avg_loss:.4f}")

        # Collect the LoRA matrices of the "default" adapter from both models.
        lora_weights = {}
        for prefix, model in (("lora_unet_", pipe.unet), ("lora_te_", pipe.text_encoder)):
            for name, module in model.named_modules():
                if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
                    for part in ("lora_A", "lora_B"):
                        weight = getattr(module, part)["default"].weight
                        # Detach and move to CPU so the file holds plain tensors.
                        lora_weights[f"{prefix}{name}.{part}.weight"] = (
                            weight.detach().cpu().contiguous()
                        )

        if not lora_weights:
            raise RuntimeError("No LoRA weights were found to save.")

        lora_path = os.path.join(OUTPUT_DIR, "lora_model.safetensors")
        safetensors.torch.save_file(lora_weights, lora_path)
        return lora_path

    except Exception as e:
        error_msg = f"Erro durante o treinamento: {e}"
        print(error_msg)
        # gr.Error surfaces the message in the UI; keep the cause for logs.
        raise gr.Error(error_msg) from e
    finally:
        # Release the model and optimizer even when training fails.
        pipe = None
        optimizer = None
        torch.cuda.empty_cache()
157
+
158
# User interface
with gr.Blocks(title="Treinador LoRA HF") as demo:
    gr.Markdown("# 🧠 Treinador LoRA para Stable Diffusion")
    gr.Markdown("Envie 3-8 imagens do mesmo objeto. Use um trigger word único (ex: `my_cat`).")

    with gr.Row():
        # Left column: training inputs.
        with gr.Column():
            image_input = gr.File(label="📁 Upload de Imagens (JPG/PNG)", file_count="multiple", file_types=["image"])
            trigger_word = gr.Textbox(label="🔤 Trigger Word", placeholder="ex: my_dog")
            epochs = gr.Slider(1, 10, value=3, step=1, label="🔁 Epochs (recomendado: 3-5)")
            lr = gr.Number(value=1e-4, label="📈 Learning Rate", precision=6)
            rank = gr.Slider(2, 16, value=4, step=2, label="📊 LoRA Rank")
            train_btn = gr.Button("🚀 Treinar LoRA", variant="primary")

        # Right column: the trained adapter file offered for download.
        with gr.Column():
            output_file = gr.File(label="💾 Download LoRA (.safetensors)")

    # The input order must match the positional parameters of train_lora.
    train_btn.click(
        fn=train_lora,
        inputs=[image_input, trigger_word, epochs, lr, rank],
        outputs=output_file,
    )

if __name__ == "__main__":
    # Training runs for minutes; the request queue keeps the call from being
    # cut off by the plain HTTP request timeout.
    demo.queue().launch()