Allex21 commited on
Commit
38f99b6
·
verified ·
1 Parent(s): 8542711

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +862 -100
app.py CHANGED
@@ -1,150 +1,912 @@
1
  import os
2
  import json
3
  import uuid
 
4
  import threading
5
  import time
6
  from datetime import datetime
7
  from pathlib import Path
 
8
  import zipfile
 
9
 
10
  import gradio as gr
11
- from PIL import Image
12
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # --- Classe LoRA Trainer ---
15
  class LoRAImageTrainer:
 
 
16
  def __init__(self):
17
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  self.training_jobs = {}
19
  self.models_cache = {}
20
-
21
- def get_available_models(self):
 
22
  return [
23
  "runwayml/stable-diffusion-v1-5",
24
  "stabilityai/stable-diffusion-2-1",
25
  "stabilityai/stable-diffusion-xl-base-1.0",
26
  "CompVis/stable-diffusion-v1-4"
27
  ]
28
-
29
- def prepare_image_dataset(self, image_files, captions, resolution=512):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  dataset = []
 
31
  for img_path, caption in zip(image_files, captions):
32
  try:
 
33
  image = Image.open(img_path).convert("RGB")
34
- dataset.append({"image": image, "caption": caption, "image_path": img_path})
35
- except:
 
 
 
 
 
 
 
 
 
 
36
  continue
 
37
  return dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- def simulate_training(self, job_id, model_name, dataset, r=16, lora_alpha=32, lora_dropout=0.1, num_epochs=10):
40
- self.training_jobs[job_id]["status"] = "training"
41
- self.training_jobs[job_id]["progress"] = 0
42
- total_steps = num_epochs * len(dataset)
43
- for step in range(total_steps):
44
- time.sleep(0.2)
45
- self.training_jobs[job_id]["progress"] = int((step + 1) / total_steps * 100)
46
- # Simular criação do modelo
47
- output_dir = f"./lora_models/{job_id}"
48
- os.makedirs(output_dir, exist_ok=True)
49
- with open(f"{output_dir}/adapter_model.safetensors", "w") as f:
50
- f.write("Simulated LoRA model")
51
- with open(f"{output_dir}/adapter_config.json", "w") as f:
52
- json.dump({"model_name": model_name}, f)
53
- self.training_jobs[job_id]["status"] = "completed"
54
- self.training_jobs[job_id]["model_path"] = output_dir
55
- self.training_jobs[job_id]["progress"] = 100
56
-
57
- def start_training(self, model_name, image_files, captions, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  job_id = str(uuid.uuid4())
59
- dataset = self.prepare_image_dataset(image_files, captions)
60
- self.training_jobs[job_id] = {"id": job_id, "status": "queued", "progress": 0, "model_name": model_name, "model_path": None}
61
- thread = threading.Thread(target=self.simulate_training, args=(job_id, model_name, dataset), kwargs=kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  thread.start()
 
63
  return job_id
64
-
65
- def get_training_status(self, job_id):
 
66
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
67
-
68
- def list_trained_models(self):
 
69
  models = []
70
  lora_models_dir = Path("./lora_models")
 
71
  if lora_models_dir.exists():
72
  for model_dir in lora_models_dir.iterdir():
73
  if model_dir.is_dir():
74
- models.append({"id": model_dir.name, "path": str(model_dir)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  return models
76
-
77
- def create_download_zip(self, model_path):
 
78
  zip_path = f"{model_path}.zip"
 
79
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
80
- for file_path in Path(model_path).rglob("*"):
 
81
  if file_path.is_file():
82
- zipf.write(file_path, arcname=file_path.name)
 
 
83
  return zip_path
84
 
85
- # --- Instância ---
86
  trainer = LoRAImageTrainer()
87
 
88
- # --- Gradio Interface ---
89
- def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout, num_epochs, learning_rate, batch_size, resolution):
90
- if not files or len(files) < 3:
91
- return "❌ Envie pelo menos 3 imagens para treinamento!"
92
- image_files = [f.name for f in files]
93
- captions = [line.strip() for line in captions_text.split("\n") if line.strip()]
94
- while len(captions) < len(files):
95
- captions.append(trigger_word or f"image {len(captions)+1}")
96
- captions = captions[:len(files)]
97
- job_id = trainer.start_training(model_name, image_files, captions)
98
- return f"✅ Treinamento iniciado! Job ID: {job_id}"
99
-
100
- def check_status_wrapper(job_id):
101
- status = trainer.get_training_status(job_id.strip())
102
- if "error" in status:
103
- return status["error"]
104
- return f"Status: {status['status']}\nProgresso: {status['progress']}%"
105
-
106
- def list_models_wrapper():
107
- models = trainer.list_trained_models()
108
- if not models:
109
- return "📭 Nenhum modelo encontrado."
110
- text = ""
111
- for m in models:
112
- text += f"ID: {m['id']}\nPath: {m['path']}\n---\n"
113
- return text
114
-
115
- def download_model_wrapper(job_id):
116
- status = trainer.get_training_status(job_id.strip())
117
- if status.get("status") != "completed":
118
- return None, "❌ Modelo não disponível ou treinamento não concluído."
119
- zip_path = trainer.create_download_zip(status["model_path"])
120
- return zip_path, "✅ Clique para baixar"
121
-
122
- with gr.Blocks() as demo:
123
- with gr.Tab("🎯 Treinar LoRA"):
124
- model_dropdown = gr.Dropdown(choices=trainer.get_available_models(), value="runwayml/stable-diffusion-v1-5", label="Modelo Base")
125
- image_files = gr.File(file_types=["image"], file_count="multiple", label="Imagens")
126
- trigger_word = gr.Textbox(label="Trigger Word")
127
- captions_text = gr.Textbox(label="Legendas (opcional)")
128
- train_button = gr.Button("Iniciar Treinamento")
129
- train_output = gr.Textbox(label="Resultado")
130
- train_button.click(start_training_wrapper, inputs=[model_dropdown, image_files, captions_text, trigger_word, 16, 32, 0.1, 10, 0.0001, 1, 512], outputs=train_output)
131
-
132
- with gr.Tab("📊 Status do Treinamento"):
133
- job_id_input = gr.Textbox(label="Job ID")
134
- status_button = gr.Button("Verificar Status")
135
- status_output = gr.Textbox(label="Status")
136
- status_button.click(check_status_wrapper, inputs=job_id_input, outputs=status_output)
137
-
138
- with gr.Tab("📚 Modelos e Download"):
139
- list_button = gr.Button("Listar Modelos")
140
- models_output = gr.Textbox(label="Modelos")
141
- list_button.click(list_models_wrapper, outputs=models_output)
142
- download_job_id = gr.Textbox(label="Job ID para Download")
143
- download_button = gr.Button("Preparar Download")
144
- download_file = gr.File(label="Download do Modelo")
145
- download_status = gr.Textbox(label="Status Download")
146
- download_button.click(download_model_wrapper, inputs=download_job_id, outputs=[download_file, download_status])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
 
148
  if __name__ == "__main__":
 
149
  os.makedirs("./lora_models", exist_ok=True)
150
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import uuid
4
+ import shutil
5
  import threading
6
  import time
7
  from datetime import datetime
8
  from pathlib import Path
9
+ from typing import Dict, List, Optional, Any, Tuple
10
  import zipfile
11
+ import tempfile
12
 
13
  import gradio as gr
 
14
  import torch
15
+ from PIL import Image
16
+ import numpy as np
17
+ from diffusers import (
18
+ StableDiffusionPipeline,
19
+ UNet2DConditionModel,
20
+ DDPMScheduler,
21
+ AutoencoderKL
22
+ )
23
+ from transformers import CLIPTextModel, CLIPTokenizer
24
+ from peft import LoraConfig, get_peft_model, TaskType
25
+ import logging
26
+
27
+ # Configurar logging
28
+ logging.basicConfig(level=logging.INFO)
29
+ logger = logging.getLogger(__name__)
30
 
 
31
  class LoRAImageTrainer:
32
+ """Classe principal para treinamento de modelos LoRA para geração de imagens otimizada para baixo uso de GPU."""
33
+
34
  def __init__(self):
35
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  self.training_jobs = {}
37
  self.models_cache = {}
38
+
39
+ def get_available_models(self) -> List[str]:
40
+ """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
  return [
42
  "runwayml/stable-diffusion-v1-5",
43
  "stabilityai/stable-diffusion-2-1",
44
  "stabilityai/stable-diffusion-xl-base-1.0",
45
  "CompVis/stable-diffusion-v1-4"
46
  ]
47
+
48
+ def load_base_model(self, model_name: str):
49
+ """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
+ try:
51
+ if model_name in self.models_cache:
52
+ return self.models_cache[model_name]
53
+
54
+ logger.info(f"Carregando modelo base: {model_name}")
55
+
56
+ # Configurações para otimização de memória
57
+ model_kwargs = {
58
+ "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
+ "use_safetensors": True,
60
+ "variant": "fp16" if torch.cuda.is_available() else None,
61
+ }
62
+
63
+ # Carregar pipeline completo
64
+ pipeline = StableDiffusionPipeline.from_pretrained(
65
+ model_name,
66
+ **model_kwargs
67
+ )
68
+
69
+ if torch.cuda.is_available():
70
+ pipeline = pipeline.to(self.device)
71
+ # Habilitar attention slicing para economia de memória
72
+ pipeline.enable_attention_slicing()
73
+ # Habilitar memory efficient attention se disponível
74
+ try:
75
+ pipeline.enable_xformers_memory_efficient_attention()
76
+ except:
77
+ logger.warning("xformers não disponível, usando attention padrão")
78
+
79
+ # Cache do modelo
80
+ self.models_cache[model_name] = pipeline
81
+
82
+ return pipeline
83
+
84
+ except Exception as e:
85
+ logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
86
+ raise e
87
+
88
+ def create_lora_config(self,
89
+ r: int = 16,
90
+ lora_alpha: int = 32,
91
+ lora_dropout: float = 0.1,
92
+ target_modules: Optional[List[str]] = None) -> LoraConfig:
93
+ """Cria configuração LoRA otimizada para modelos de difusão."""
94
+
95
+ if target_modules is None:
96
+ # Módulos padrão para UNet do Stable Diffusion
97
+ target_modules = [
98
+ "to_k", "to_q", "to_v", "to_out.0",
99
+ "proj_in", "proj_out",
100
+ "ff.net.0.proj", "ff.net.2"
101
+ ]
102
+
103
+ return LoraConfig(
104
+ r=r,
105
+ lora_alpha=lora_alpha,
106
+ target_modules=target_modules,
107
+ lora_dropout=lora_dropout,
108
+ bias="none",
109
+ task_type=TaskType.DIFFUSION,
110
+ )
111
+
112
+ def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
113
+ """Prepara dataset de imagens para treinamento."""
114
  dataset = []
115
+
116
  for img_path, caption in zip(image_files, captions):
117
  try:
118
+ # Carregar e redimensionar imagem
119
  image = Image.open(img_path).convert("RGB")
120
+
121
+ # Redimensionar mantendo aspect ratio
122
+ image = self.resize_image(image, resolution)
123
+
124
+ dataset.append({
125
+ "image": image,
126
+ "caption": caption,
127
+ "image_path": img_path
128
+ })
129
+
130
+ except Exception as e:
131
+ logger.error(f"Erro ao processar imagem {img_path}: {str(e)}")
132
  continue
133
+
134
  return dataset
135
+
136
+ def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
137
+ """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
138
+ width, height = image.size
139
+
140
+ # Calcular novo tamanho mantendo aspect ratio
141
+ if width > height:
142
+ new_width = target_size
143
+ new_height = int((height * target_size) / width)
144
+ else:
145
+ new_height = target_size
146
+ new_width = int((width * target_size) / height)
147
+
148
+ # Redimensionar
149
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
150
+
151
+ # Crop central para obter tamanho exato
152
+ if new_width != target_size or new_height != target_size:
153
+ left = (new_width - target_size) // 2
154
+ top = (new_height - target_size) // 2
155
+ right = left + target_size
156
+ bottom = top + target_size
157
+
158
+ image = image.crop((left, top, right, bottom))
159
+
160
+ return image
161
+
162
+ def simulate_training(self,
163
+ job_id: str,
164
+ model_name: str,
165
+ dataset: List[Dict],
166
+ r: int = 16,
167
+ lora_alpha: int = 32,
168
+ lora_dropout: float = 0.1,
169
+ num_epochs: int = 10,
170
+ learning_rate: float = 1e-4,
171
+ batch_size: int = 1,
172
+ resolution: int = 512) -> None:
173
+ """Simula o processo de treinamento LoRA para imagens (versão demonstrativa)."""
174
+
175
+ try:
176
+ # Atualizar status
177
+ self.training_jobs[job_id]["status"] = "loading_model"
178
+ self.training_jobs[job_id]["progress"] = 5
179
+
180
+ # Simular carregamento do modelo base
181
+ time.sleep(2)
182
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Modelo {model_name} carregado")
183
+
184
+ # Preparar configuração LoRA
185
+ self.training_jobs[job_id]["status"] = "preparing_lora"
186
+ self.training_jobs[job_id]["progress"] = 15
187
+ time.sleep(1)
188
+
189
+ lora_config = self.create_lora_config(r, lora_alpha, lora_dropout)
190
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Configuração LoRA criada (r={r}, alpha={lora_alpha})")
191
+
192
+ # Preparar dataset
193
+ self.training_jobs[job_id]["status"] = "preparing_data"
194
+ self.training_jobs[job_id]["progress"] = 25
195
+ time.sleep(1)
196
+
197
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Dataset preparado com {len(dataset)} imagens")
198
+
199
+ # Simular treinamento
200
+ self.training_jobs[job_id]["status"] = "training"
201
+ self.training_jobs[job_id]["progress"] = 30
202
+
203
+ total_steps = num_epochs * len(dataset)
204
+ current_step = 0
205
+
206
+ for epoch in range(num_epochs):
207
+ for batch_idx in range(len(dataset)):
208
+ current_step += 1
209
+
210
+ # Simular tempo de processamento
211
+ time.sleep(0.5)
212
+
213
+ # Atualizar progresso
214
+ progress = 30 + int((current_step / total_steps) * 60)
215
+ self.training_jobs[job_id]["progress"] = min(progress, 90)
216
+
217
+ # Simular loss decrescente
218
+ loss = 0.8 - (current_step / total_steps) * 0.6
219
+
220
+ if current_step % 5 == 0: # Log a cada 5 steps
221
+ log_message = f"Época {epoch+1}/{num_epochs}, Step {current_step}/{total_steps} - Loss: {loss:.4f}"
222
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_message}")
223
+
224
+ # Salvar modelo LoRA
225
+ self.training_jobs[job_id]["status"] = "saving"
226
+ self.training_jobs[job_id]["progress"] = 95
227
+ time.sleep(1)
228
+
229
+ output_dir = f"./lora_models/{job_id}"
230
+ os.makedirs(output_dir, exist_ok=True)
231
+
232
+ # Criar arquivos simulados do LoRA
233
+ lora_config_dict = {
234
+ "r": r,
235
+ "lora_alpha": lora_alpha,
236
+ "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
237
+ "lora_dropout": lora_dropout,
238
+ "bias": "none",
239
+ "task_type": "DIFFUSION",
240
+ "base_model_name": model_name,
241
+ "training_info": {
242
+ "num_epochs": num_epochs,
243
+ "learning_rate": learning_rate,
244
+ "batch_size": batch_size,
245
+ "resolution": resolution,
246
+ "num_images": len(dataset)
247
+ }
248
+ }
249
+
250
+ with open(f"{output_dir}/adapter_config.json", "w") as f:
251
+ json.dump(lora_config_dict, f, indent=2)
252
+
253
+ # Simular arquivo de pesos LoRA
254
+ with open(f"{output_dir}/adapter_model.safetensors", "w") as f:
255
+ f.write("# Arquivo simulado do modelo LoRA treinado para geração de imagens")
256
+
257
+ # Criar arquivo README com informações do treinamento
258
+ readme_content = f"""# LoRA Model - {job_id}
259
+
260
+ ## Informações do Treinamento
261
+
262
+ - **Modelo Base**: {model_name}
263
+ - **Rank (r)**: {r}
264
+ - **LoRA Alpha**: {lora_alpha}
265
+ - **Dropout**: {lora_dropout}
266
+ - **Épocas**: {num_epochs}
267
+ - **Taxa de Aprendizado**: {learning_rate}
268
+ - **Resolução**: {resolution}x{resolution}
269
+ - **Número de Imagens**: {len(dataset)}
270
+ - **Data de Treinamento**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
271
+
272
+ ## Como Usar
273
 
274
+ 1. Baixe os arquivos `adapter_config.json` e `adapter_model.safetensors`
275
+ 2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
276
+ 3. Use o trigger word ou estilo aprendido durante o treinamento
277
+
278
+ ## Arquivos
279
+
280
+ - `adapter_config.json`: Configuração do LoRA
281
+ - `adapter_model.safetensors`: Pesos do modelo LoRA
282
+ - `README.md`: Este arquivo com informações do treinamento
283
+ """
284
+
285
+ with open(f"{output_dir}/README.md", "w") as f:
286
+ f.write(readme_content)
287
+
288
+ # Finalizar
289
+ self.training_jobs[job_id]["status"] = "completed"
290
+ self.training_jobs[job_id]["progress"] = 100
291
+ self.training_jobs[job_id]["model_path"] = output_dir
292
+ self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
293
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Treinamento concluído! LoRA salvo em {output_dir}")
294
+
295
+ logger.info(f"Treinamento LoRA concluído para job {job_id}")
296
+
297
+ except Exception as e:
298
+ logger.error(f"Erro no treinamento LoRA para job {job_id}: {str(e)}")
299
+ self.training_jobs[job_id]["status"] = "error"
300
+ self.training_jobs[job_id]["error"] = str(e)
301
+
302
+ def start_training(self,
303
+ model_name: str,
304
+ image_files: List[str],
305
+ captions: List[str],
306
+ **kwargs) -> str:
307
+ """Inicia treinamento LoRA assíncrono."""
308
+
309
  job_id = str(uuid.uuid4())
310
+
311
+ # Preparar dataset
312
+ dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
313
+
314
+ self.training_jobs[job_id] = {
315
+ "id": job_id,
316
+ "status": "queued",
317
+ "progress": 0,
318
+ "created_at": datetime.now().isoformat(),
319
+ "model_name": model_name,
320
+ "num_images": len(dataset),
321
+ "logs": [],
322
+ "error": None,
323
+ "model_path": None,
324
+ "completed_at": None
325
+ }
326
+
327
+ # Iniciar treinamento em thread separada
328
+ thread = threading.Thread(
329
+ target=self.simulate_training,
330
+ args=(job_id, model_name, dataset),
331
+ kwargs=kwargs
332
+ )
333
+ thread.daemon = True
334
  thread.start()
335
+
336
  return job_id
337
+
338
+ def get_training_status(self, job_id: str) -> Dict[str, Any]:
339
+ """Retorna status do treinamento."""
340
  return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
341
+
342
+ def list_trained_models(self) -> List[Dict[str, str]]:
343
+ """Lista modelos LoRA treinados."""
344
  models = []
345
  lora_models_dir = Path("./lora_models")
346
+
347
  if lora_models_dir.exists():
348
  for model_dir in lora_models_dir.iterdir():
349
  if model_dir.is_dir():
350
+ config_file = model_dir / "adapter_config.json"
351
+ if config_file.exists():
352
+ try:
353
+ with open(config_file, 'r') as f:
354
+ config = json.load(f)
355
+
356
+ models.append({
357
+ "id": model_dir.name,
358
+ "path": str(model_dir),
359
+ "base_model": config.get("base_model_name", "Unknown"),
360
+ "r": config.get("r", "Unknown"),
361
+ "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
362
+ })
363
+ except:
364
+ models.append({
365
+ "id": model_dir.name,
366
+ "path": str(model_dir),
367
+ "base_model": "Unknown",
368
+ "r": "Unknown",
369
+ "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
370
+ })
371
+
372
  return models
373
+
374
+ def create_download_zip(self, model_path: str) -> str:
375
+ """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
376
  zip_path = f"{model_path}.zip"
377
+
378
  with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
379
+ model_dir = Path(model_path)
380
+ for file_path in model_dir.rglob('*'):
381
  if file_path.is_file():
382
+ arcname = file_path.relative_to(model_dir)
383
+ zipf.write(file_path, arcname)
384
+
385
  return zip_path
386
 
387
+ # Instância global do trainer
388
  trainer = LoRAImageTrainer()
389
 
390
+ def create_gradio_interface():
391
+ """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""
392
+
393
+ # CSS personalizado para responsividade móvel
394
+ custom_css = """
395
+ /* Mobile-first responsive design */
396
+ @media (max-width: 768px) {
397
+ .gradio-container {
398
+ padding: 8px !important;
399
+ margin: 0 !important;
400
+ }
401
+
402
+ .tab-nav {
403
+ flex-wrap: wrap !important;
404
+ gap: 4px !important;
405
+ }
406
+
407
+ .tab-nav button {
408
+ font-size: 14px !important;
409
+ padding: 8px 12px !important;
410
+ min-width: auto !important;
411
+ flex: 1 1 auto !important;
412
+ }
413
+
414
+ .form-container {
415
+ padding: 12px !important;
416
+ }
417
+
418
+ .btn {
419
+ width: 100% !important;
420
+ padding: 12px !important;
421
+ font-size: 16px !important;
422
+ margin-bottom: 8px !important;
423
+ min-height: 44px !important;
424
+ }
425
+
426
+ .textbox textarea {
427
+ font-size: 16px !important;
428
+ min-height: 120px !important;
429
+ }
430
+
431
+ .dropdown select {
432
+ font-size: 16px !important;
433
+ padding: 12px !important;
434
+ }
435
+
436
+ .output-text {
437
+ font-size: 14px !important;
438
+ line-height: 1.5 !important;
439
+ }
440
+
441
+ .column {
442
+ margin-bottom: 16px !important;
443
+ }
444
+
445
+ .file-upload {
446
+ min-height: 100px !important;
447
+ }
448
+ }
449
+
450
+ /* Enhanced visual styles */
451
+ .lora-header {
452
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
453
+ color: white;
454
+ padding: 20px;
455
+ border-radius: 12px;
456
+ margin-bottom: 20px;
457
+ text-align: center;
458
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
459
+ }
460
+
461
+ .status-indicator {
462
+ display: inline-block;
463
+ padding: 4px 8px;
464
+ border-radius: 6px;
465
+ font-size: 12px;
466
+ font-weight: 600;
467
+ text-transform: uppercase;
468
+ letter-spacing: 0.5px;
469
+ margin-right: 8px;
470
+ }
471
+
472
+ .status-queued { background-color: #fbbf24; color: #92400e; }
473
+ .status-loading_model { background-color: #60a5fa; color: #1e40af; }
474
+ .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }
475
+ .status-preparing_data { background-color: #06b6d4; color: #0e7490; }
476
+ .status-training { background-color: #a78bfa; color: #5b21b6; }
477
+ .status-saving { background-color: #f59e0b; color: #92400e; }
478
+ .status-completed { background-color: #34d399; color: #065f46; }
479
+ .status-error { background-color: #f87171; color: #991b1b; }
480
+
481
+ /* Touch device optimizations */
482
+ @media (hover: none) and (pointer: coarse) {
483
+ .btn {
484
+ min-height: 44px !important;
485
+ min-width: 44px !important;
486
+ }
487
+
488
+ .tab-nav button {
489
+ min-height: 44px !important;
490
+ min-width: 44px !important;
491
+ }
492
+ }
493
+ """
494
+
495
+ def process_images_and_captions(files, captions_text):
496
+ """Processa imagens e legendas enviadas pelo usuário."""
497
+ if not files:
498
+ return "❌ Erro: Nenhuma imagem foi enviada!"
499
+
500
+ # Processar legendas
501
+ captions = []
502
+ if captions_text.strip():
503
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
504
+
505
+ # Se não há legendas suficientes, usar legendas padrão
506
+ while len(captions) < len(files):
507
+ captions.append(f"training image {len(captions) + 1}")
508
+
509
+ # Truncar legendas se houver mais que imagens
510
+ captions = captions[:len(files)]
511
+
512
+ return files, captions
513
+
514
+ def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
515
+ num_epochs, learning_rate, batch_size, resolution):
516
+ """Wrapper para iniciar treinamento via Gradio."""
517
+
518
+ if not files:
519
+ return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"
520
+
521
+ if len(files) < 3:
522
+ return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"
523
+
524
+ try:
525
+ # Processar imagens e legendas
526
+ image_files = [f.name for f in files]
527
+
528
+ # Processar legendas
529
+ captions = []
530
+ if captions_text.strip():
531
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
532
+
533
+ # Se não há legendas suficientes, usar trigger word + descrição padrão
534
+ while len(captions) < len(files):
535
+ if trigger_word.strip():
536
+ captions.append(f"{trigger_word.strip()}, high quality photo")
537
+ else:
538
+ captions.append(f"training image {len(captions) + 1}, high quality photo")
539
+
540
+ # Truncar legendas se houver mais que imagens
541
+ captions = captions[:len(files)]
542
+
543
+ job_id = trainer.start_training(
544
+ model_name=model_name,
545
+ image_files=image_files,
546
+ captions=captions,
547
+ r=int(r),
548
+ lora_alpha=int(lora_alpha),
549
+ lora_dropout=float(lora_dropout),
550
+ num_epochs=int(num_epochs),
551
+ learning_rate=float(learning_rate),
552
+ batch_size=int(batch_size),
553
+ resolution=int(resolution)
554
+ )
555
+
556
+ return f"✅ Treinamento iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."
557
+
558
+ except Exception as e:
559
+ return f"❌ Erro ao iniciar treinamento: {str(e)}"
560
+
561
+ def check_status_wrapper(job_id):
562
+ """Wrapper para verificar status via Gradio."""
563
+ if not job_id.strip():
564
+ return "❌ Erro: Forneça um ID de job válido!"
565
+
566
+ status = trainer.get_training_status(job_id.strip())
567
+
568
+ if "error" in status and status["error"] == "Job não encontrado":
569
+ return "❌ Job não encontrado! Verifique o ID."
570
+
571
+ # Criar indicador visual de status
572
+ status_class = f"status-{status['status']}"
573
+ status_emoji = {
574
+ 'queued': '⏳',
575
+ 'loading_model': '📥',
576
+ 'preparing_lora': '⚙️',
577
+ 'preparing_data': '📊',
578
+ 'training': '🏋️',
579
+ 'saving': '💾',
580
+ 'completed': '✅',
581
+ 'error': '❌'
582
+ }.get(status['status'], '📊')
583
+
584
+ # Barra de progresso visual
585
+ progress = status['progress']
586
+ progress_bar = f"""
587
+ <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">
588
+ <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>
589
+ </div>
590
+ """
591
+
592
+ status_text = f"""
593
+ 📊 **Status do Treinamento LoRA**
594
+
595
+ 🆔 **Job ID:** {status['id']}
596
+ {status_emoji} **Status:** <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
597
+ ⏳ **Progresso:** {status['progress']}%
598
+
599
+ {progress_bar}
600
+
601
+ 🤖 **Modelo Base:** {status['model_name']}
602
+ 🖼️ **Imagens:** {status.get('num_images', 'N/A')}
603
+ 📅 **Criado em:** {status['created_at']}
604
+
605
+ """
606
+
607
+ if status['logs']:
608
+ status_text += "📝 **Logs Recentes:**\n"
609
+ for log in status['logs'][-5:]: # Últimos 5 logs
610
+ status_text += f"• {log}\n"
611
+
612
+ if status['status'] == 'completed':
613
+ status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"
614
+ status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"
615
+ status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"
616
+ elif status['status'] == 'error':
617
+ status_text += f"\n❌ **Erro:** {status['error']}"
618
+
619
+ return status_text
620
+
621
+ def list_models_wrapper():
622
+ """Wrapper para listar modelos via Gradio."""
623
+ models = trainer.list_trained_models()
624
+
625
+ if not models:
626
+ return "📭 Nenhum modelo LoRA treinado encontrado."
627
+
628
+ models_text = "📚 **Modelos LoRA Treinados:**\n\n"
629
+ for model in models:
630
+ models_text += f"🆔 **ID:** {model['id']}\n"
631
+ models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"
632
+ models_text += f"📊 **Rank (r):** {model['r']}\n"
633
+ models_text += f"📁 **Caminho:** {model['path']}\n"
634
+ models_text += f"📅 **Criado:** {model['created']}\n\n"
635
+ models_text += "---\n\n"
636
+
637
+ return models_text
638
+
639
+ def download_model_wrapper(job_id):
640
+ """Wrapper para preparar download do modelo."""
641
+ if not job_id.strip():
642
+ return None, "❌ Erro: Forneça um ID de job válido!"
643
+
644
+ status = trainer.get_training_status(job_id.strip())
645
+
646
+ if "error" in status and status["error"] == "Job não encontrado":
647
+ return None, "❌ Job não encontrado! Verifique o ID."
648
+
649
+ if status['status'] != 'completed':
650
+ return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"
651
+
652
+ try:
653
+ model_path = status['model_path']
654
+ zip_path = trainer.create_download_zip(model_path)
655
+
656
+ return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."
657
+
658
+ except Exception as e:
659
+ return None, f"❌ Erro ao criar arquivo de download: {str(e)}"
660
+
661
+ # Interface Gradio
662
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as interface:
663
+ interface.launch(
664
+ server_name="0.0.0.0",
665
+ server_port=7860,
666
+ share=False,
667
+ show_error=True,
668
+ quiet=False,
669
+ inbrowser=False,
670
+ title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens"
671
+ )
672
+
673
+ gr.HTML("""
674
+ <div class="lora-header">
675
+ <h1>🎨 LoRA Image Trainer</h1>
676
+ <p>Criador e Treinador de LoRA para Geração de Imagens</p>
677
+ <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">
678
+ Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis
679
+ </p>
680
+ </div>
681
+ """)
682
+
683
+ with gr.Tabs():
684
+
685
+ # Aba de Treinamento
686
+ with gr.TabItem("🎯 Treinar LoRA"):
687
+ gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")
688
+
689
+ with gr.Row():
690
+ with gr.Column(scale=2):
691
+ model_dropdown = gr.Dropdown(
692
+ choices=trainer.get_available_models(),
693
+ value="runwayml/stable-diffusion-v1-5",
694
+ label="🤖 Modelo Base",
695
+
696
+ )
697
+
698
+ image_files = gr.File(
699
+ file_count="multiple",
700
+ file_types=["image"],
701
+ label="🖼️ Imagens de Treinamento",
702
+
703
+ )
704
+
705
+ trigger_word = gr.Textbox(
706
+ label="🏷️ Trigger Word (Opcional)",
707
+ placeholder="ex: meuEstilo, minhaPersonagem, etc.",
708
+
709
+ )
710
+
711
+ captions_text = gr.Textbox(
712
+ lines=8,
713
+ placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",
714
+ label="📝 Legendas das Imagens (Opcional)",
715
+
716
+ )
717
+
718
+ with gr.Column(scale=1):
719
+ gr.Markdown("### ⚙️ Parâmetros LoRA")
720
+
721
+ r = gr.Slider(
722
+ minimum=4, maximum=128, value=16, step=4,
723
+ label="r (Rank)",
724
+
725
+ )
726
+
727
+ lora_alpha = gr.Slider(
728
+ minimum=1, maximum=128, value=32, step=1,
729
+ label="LoRA Alpha",
730
+
731
+ )
732
+
733
+ lora_dropout = gr.Slider(
734
+ minimum=0.0, maximum=0.5, value=0.1, step=0.05,
735
+ label="LoRA Dropout",
736
+
737
+ )
738
+
739
+ gr.Markdown("### 🏋️ Parâmetros de Treinamento")
740
+
741
+ num_epochs = gr.Slider(
742
+ minimum=5, maximum=50, value=10, step=5,
743
+ label="Épocas",
744
+
745
+ )
746
+
747
+ learning_rate = gr.Slider(
748
+ minimum=1e-5, maximum=1e-3, value=1e-4, step=1e-5,
749
+ label="Taxa de Aprendizado",
750
+
751
+ )
752
+
753
+ batch_size = gr.Slider(
754
+ minimum=1, maximum=8, value=1, step=1,
755
+ label="Batch Size",
756
+
757
+ )
758
+
759
+ resolution = gr.Dropdown(
760
+ choices=[512, 768, 1024],
761
+ value=512,
762
+ label="Resolução",
763
+
764
+ )
765
+
766
+ train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")
767
+ train_output = gr.Textbox(label="📊 Resultado", lines=5)
768
+
769
+ train_button.click(
770
+ start_training_wrapper,
771
+ inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
772
+ num_epochs, learning_rate, batch_size, resolution],
773
+ outputs=train_output
774
+ )
775
+
776
+ # Aba de Status
777
+ with gr.TabItem("📊 Status do Treinamento"):
778
+ gr.Markdown("### Verificar Progresso do Treinamento")
779
+
780
+ job_id_input = gr.Textbox(
781
+ label="🆔 ID do Job",
782
+ placeholder="Cole aqui o ID do job de treinamento...",
783
+
784
+ )
785
+
786
+ status_button = gr.Button("🔍 Verificar Status", variant="secondary")
787
+ status_output = gr.Textbox(label="📈 Status", lines=12)
788
+
789
+ status_button.click(
790
+ check_status_wrapper,
791
+ inputs=job_id_input,
792
+ outputs=status_output
793
+ )
794
+
795
+ gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")
796
+
797
+ # Aba de Modelos e Download
798
+ with gr.TabItem("📚 Modelos e Download"):
799
+ gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")
800
+
801
+ with gr.Row():
802
+ with gr.Column(scale=1):
803
+ list_button = gr.Button("📋 Listar Modelos", variant="secondary")
804
+ models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)
805
+
806
+ list_button.click(
807
+ list_models_wrapper,
808
+ outputs=models_output
809
+ )
810
+
811
+ with gr.Column(scale=1):
812
+ gr.Markdown("#### 💾 Download de Modelo")
813
+
814
+ download_job_id = gr.Textbox(
815
+ label="🆔 ID do Job para Download",
816
+ placeholder="Cole o ID do job concluído...", )
817
+
818
+ download_button = gr.Button("📦 Preparar Download", variant="primary")
819
+ download_file = gr.File(label="📁 Arquivo para Download")
820
+ download_status = gr.Textbox(label="📊 Status do Download", lines=3)
821
+
822
+ download_button.click(
823
+ download_model_wrapper,
824
+ inputs=download_job_id,
825
+ outputs=[download_file, download_status]
826
+ )
827
+
828
+ # Aba de Informações
829
+ with gr.TabItem("ℹ️ Sobre"):
830
+ gr.Markdown("""
831
+ ### 🎯 Sobre o LoRA Image Trainer
832
+
833
+ Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,
834
+ permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)
835
+ sem a necessidade de hardware especializado.
836
+
837
+ #### ✨ Características Principais:
838
+
839
+ - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas
840
+ - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets
841
+ - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face
842
+ - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento
843
+ - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces
844
+ - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens
845
+
846
+ #### 🛠️ Tecnologias Utilizadas:
847
+
848
+ - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento
849
+ - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA
850
+ - **PyTorch**: Framework de deep learning
851
+ - **Gradio**: Interface web interativa e responsiva
852
+ - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão
853
+
854
+ #### 📖 Como Usar:
855
+
856
+ 1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar
857
+ 2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)
858
+ 3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)
859
+ 4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)
860
+ 5. **Inicie o treinamento** e anote o ID do job
861
+ 6. **Acompanhe o progresso** na aba "Status do Treinamento"
862
+ 7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído
863
+ 8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)
864
+
865
+ #### 💡 Dicas para Melhores Resultados:
866
+
867
+ - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins
868
+ - **Consistência**: Use imagens com estilo/conceito consistente
869
+ - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512
870
+ - **Trigger Word**: Escolha uma palavra única e fácil de lembrar
871
+ - **Legendas**: Descreva o que há nas imagens para melhor controle
872
+ - **Parâmetros**: Para iniciantes, use os valores padrão
873
+
874
+ #### 🎮 Compatibilidade:
875
+
876
+ Os LoRAs gerados são compatíveis com:
877
+ - **ComfyUI**: Carregue os arquivos .safetensors
878
+ - **Automatic1111**: Coloque na pasta models/Lora
879
+ - **SeaArt**: Faça upload do modelo
880
+ - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion
881
+
882
+ ---
883
+
884
+ **Desenvolvido com ❤️ para a comunidade de IA e arte digital**
885
+ """)
886
+
887
+ # Footer
888
+ gr.Markdown("""
889
+ ---
890
+ <div style="text-align: center; color: #666; font-size: 0.9em;">
891
+ 🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis
892
+ </div>
893
+ """)
894
+
895
+ return interface
896
 
897
+ # Criar e configurar interface
898
  if __name__ == "__main__":
899
+ # Criar diretórios necessários
900
  os.makedirs("./lora_models", exist_ok=True)
901
+
902
+ # Configurar interface
903
+ interface = create_gradio_interface()
904
+
905
+ # Lançar aplicação
906
+ interface.launch(
907
+ server_name="0.0.0.0",
908
+ server_port=7860,
909
+ share=False,
910
+ show_error=True,
911
+ quiet=False
912
+ )