Allex21 commited on
Commit
299e2b8
·
verified ·
1 Parent(s): 45f293b

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -243
app.py DELETED
@@ -1,243 +0,0 @@
1
-
2
- import gradio as gr
3
- import os
4
- import torch
5
- from accelerate import Accelerator
6
- from accelerate.utils import set_seed
7
- from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline
8
- from diffusers.optimization import get_scheduler
9
- from diffusers.training_utils import EMAModel
10
- from diffusers.models.attention_processor import LoRAAttnProcessor as DiffusersLoRAAttnProcessor
11
- from huggingface_hub import create_repo, upload_folder
12
- from PIL import Image
13
- from torch.utils.data import Dataset
14
- from torchvision import transforms
15
- from tqdm.auto import tqdm
16
- from transformers import CLIPTextModel, CLIPTokenizer
17
- import zipfile
18
- import shutil
19
- from safetensors.torch import save_file
20
-
21
- # Placeholder para o script de treinamento
22
- def train_lora(
23
- instance_data_dir: str,
24
- output_dir: str,
25
- resolution: int = 512,
26
- learning_rate: float = 1e-4,
27
- batch_size: int = 1,
28
- num_epochs: int = 1,
29
- train_prompt: str = "a photo of sks dog",
30
- pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5",
31
- ):
32
- # Configurações básicas
33
- accelerator = Accelerator(
34
- gradient_accumulation_steps=1,
35
- mixed_precision="fp16",
36
- )
37
-
38
- # Carregar tokenizer e modelo base
39
- tokenizer = CLIPTokenizer.from_pretrained(
40
- pretrained_model_name_or_path, subfolder="tokenizer"
41
- )
42
- text_encoder = CLIPTextModel.from_pretrained(
43
- pretrained_model_name_or_path, subfolder="text_encoder"
44
- )
45
- vae = AutoencoderKL.from_pretrained(
46
- pretrained_model_name_or_path, subfolder="vae"
47
- )
48
- unet = UNet2DConditionModel.from_pretrained(
49
- pretrained_model_name_or_path, subfolder="unet"
50
- )
51
-
52
- # Congelar parâmetros do VAE e Text Encoder
53
- vae.requires_grad_(False)
54
- text_encoder.requires_grad_(False)
55
-
56
- # Configurar LoRA
57
- # Adicionar adaptadores LoRA ao UNet
58
- # A função `add_adapter` do diffusers já configura os módulos LoRA e os torna treináveis.
59
- unet.add_adapter(DiffusersLoRAAttnProcessor)
60
-
61
- # Otimizador
62
- # Apenas os parâmetros do LoRA devem ser treináveis
63
- # O `add_adapter` já faz isso, então podemos simplesmente pegar os parâmetros treináveis do UNet.
64
- lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
65
-
66
- optimizer = torch.optim.AdamW(
67
- lora_parameters,
68
- lr=learning_rate,
69
- )
70
-
71
- # Scheduler
72
- lr_scheduler = get_scheduler(
73
- "constant",
74
- optimizer=optimizer,
75
- num_warmup_steps=0,
76
- num_training_steps=num_epochs * len(os.listdir(instance_data_dir)),
77
- )
78
-
79
- # Dataset e DataLoader (simplificado para o exemplo)
80
- class DreamBoothDataset(Dataset):
81
- def __init__(self, instance_data_root, tokenizer, size=512, train_prompt="a photo of sks dog"):
82
- self.instance_data_root = instance_data_root
83
- self.tokenizer = tokenizer
84
- self.size = size
85
- self.train_prompt = train_prompt
86
- self.instance_images_path = [os.path.join(instance_data_root, file_path) for file_path in os.listdir(instance_data_root) if file_path.endswith((".png", ".jpg", ".jpeg"))]
87
- self.transform = transforms.Compose(
88
- [
89
- transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
90
- transforms.CenterCrop(size),
91
- transforms.ToTensor(),
92
- transforms.Normalize([0.5], [0.5]),
93
- ]
94
- )
95
-
96
- def __len__(self):
97
- return len(self.instance_images_path)
98
-
99
- def __getitem__(self, index):
100
- instance_image = Image.open(self.instance_images_path[index])
101
- if not instance_image.mode == "RGB":
102
- instance_image = instance_image.convert("RGB")
103
- example = {}
104
- example["instance_images"] = self.transform(instance_image)
105
- example["instance_prompt_ids"] = self.tokenizer(self.train_prompt,
106
- truncation=True,
107
- padding="max_length",
108
- max_length=self.tokenizer.model_max_length,
109
- return_tensors="pt",
110
- ).input_ids[0]
111
- return example
112
-
113
- train_dataset = DreamBoothDataset(instance_data_dir, tokenizer, resolution, train_prompt)
114
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
115
-
116
- # Preparar para treinamento com Accelerator
117
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
118
- unet, optimizer, train_dataloader, lr_scheduler
119
- )
120
-
121
- # Loop de treinamento
122
- for epoch in range(num_epochs):
123
- unet.train()
124
- for step, batch in enumerate(train_dataloader):
125
- with accelerator.accumulate(unet):
126
- # Forward pass
127
- latents = vae.encode(batch["instance_images"]).latent_dist.sample()
128
- latents = latents * vae.config.scaling_factor
129
-
130
- noise = torch.randn_like(latents)
131
- timesteps = torch.randint(0, 1000, (batch_size,), device=latents.device).long()
132
-
133
- noisy_latents = DDPMScheduler().add_noise(latents, noise, timesteps)
134
-
135
- encoder_hidden_states = text_encoder(batch["instance_prompt_ids"])[0]
136
-
137
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
138
-
139
- # Calcular perda
140
- loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
141
-
142
- # Backward pass
143
- accelerator.backward(loss)
144
- optimizer.step()
145
- lr_scheduler.step()
146
- optimizer.zero_grad()
147
-
148
- accelerator.log({"loss": loss.item()}, step=epoch * len(train_dataloader) + step)
149
- print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")
150
-
151
- # Salvar o modelo treinado
152
- # Salvar apenas os pesos LoRA
153
- lora_state_dict = {}
154
- for name, param in unet.named_parameters():
155
- if "lora" in name:
156
- lora_state_dict[name] = param
157
-
158
- lora_path = os.path.join(output_dir, "lora_model.safetensors")
159
- # Usar safetensors para salvar o modelo
160
- save_file(lora_state_dict, lora_path)
161
-
162
- return lora_path
163
-
164
-
165
- def run_training(
166
- dataset_zip,
167
- resolution,
168
- learning_rate,
169
- batch_size,
170
- num_epochs,
171
- train_prompt,
172
- ):
173
- if dataset_zip is None:
174
- return "Por favor, faça o upload de um arquivo ZIP com seu dataset.", None
175
-
176
- # Limpar diretórios anteriores
177
- if os.path.exists("./data/dataset"):
178
- shutil.rmtree("./data/dataset")
179
- if os.path.exists("./outputs"):
180
- shutil.rmtree("./outputs")
181
- os.makedirs("./data/dataset", exist_ok=True)
182
- os.makedirs("./outputs", exist_ok=True)
183
-
184
- # Salvar e extrair o dataset
185
- dataset_dir = "./data/dataset"
186
- zip_path = os.path.join("./data", os.path.basename(dataset_zip.name))
187
- with open(zip_path, "wb") as f:
188
- f.write(dataset_zip.read())
189
-
190
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
191
- zip_ref.extractall(dataset_dir)
192
-
193
- # Iniciar treinamento
194
- output_dir = "./outputs"
195
-
196
- lora_model_path = train_lora(
197
- instance_data_dir=dataset_dir,
198
- output_dir=output_dir,
199
- resolution=resolution,
200
- learning_rate=learning_rate,
201
- batch_size=batch_size,
202
- num_epochs=num_epochs,
203
- train_prompt=train_prompt,
204
- )
205
-
206
- return f"Treinamento concluído! Modelo salvo em: {lora_model_path}", lora_model_path
207
-
208
-
209
- with gr.Blocks() as demo:
210
- gr.Markdown("# Treinador LoRA para Hugging Face Spaces")
211
-
212
- with gr.Row():
213
- with gr.Column():
214
- dataset_zip = gr.File(label="Upload do Dataset (ZIP)", file_types=[".zip"])
215
- resolution = gr.Slider(minimum=128, maximum=1024, value=512, step=128, label="Resolução da Imagem")
216
- learning_rate = gr.Number(value=1e-4, label="Learning Rate")
217
- batch_size = gr.Slider(minimum=1, maximum=8, value=1, step=1, label="Batch Size")
218
- num_epochs = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Número de Epochs")
219
- train_prompt = gr.Textbox(label="Prompt de Treinamento (ex: a photo of sks dog)", value="a photo of sks dog")
220
- train_button = gr.Button("Iniciar Treinamento")
221
-
222
- with gr.Column():
223
- output_text = gr.Textbox(label="Status do Treinamento")
224
- output_file = gr.File(label="Modelo LoRA Treinado")
225
-
226
- train_button.click(
227
- run_training,
228
- inputs=[
229
- dataset_zip,
230
- resolution,
231
- learning_rate,
232
- batch_size,
233
- num_epochs,
234
- train_prompt,
235
- ],
236
- outputs=[output_text, output_file],
237
- )
238
-
239
-
240
- if __name__ == "__main__":
241
- demo.launch(debug=True)
242
-
243
-