Allex21 commited on
Commit
8f49e8c
·
verified ·
1 Parent(s): a8c8b53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +893 -85
app.py CHANGED
@@ -1,99 +1,907 @@
1
- import gradio as gr
2
  import os
 
3
  import uuid
 
4
  import threading
 
 
5
  from pathlib import Path
6
- from diffusers import StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
7
- from diffusers.pipelines.lora import save_lora_weights, LoRAConfig
 
 
 
8
  import torch
9
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Pasta para salvar imagens e LoRAs
12
- os.makedirs("uploads", exist_ok=True)
13
- os.makedirs("lora_models", exist_ok=True)
14
 
15
- # Jobs ativos
16
- training_jobs = {}
 
17
 
18
- def save_uploaded_images(images):
19
- image_paths = []
20
- for i, img in enumerate(images):
21
- path = f"uploads/{uuid.uuid4()}.png"
22
- img.save(path)
23
- image_paths.append(path)
24
- return image_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- def train_lora(model_name, images, rank=4, steps=20):
27
- job_id = str(uuid.uuid4())
28
- training_jobs[job_id] = {"status": "Iniciando...", "progress": 0, "logs": []}
29
 
30
- def train():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  try:
32
- training_jobs[job_id]["logs"].append("Carregando modelo base...")
33
- training_jobs[job_id]["status"] = "Carregando modelo..."
34
- device = "cuda" if torch.cuda.is_available() else "cpu"
35
-
36
- # Carregando modelo
37
- pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16 if device=="cuda" else torch.float32)
38
- pipe.to(device)
39
- training_jobs[job_id]["logs"].append(f"Modelo {model_name} carregado no {device}")
40
- training_jobs[job_id]["progress"] = 10
41
-
42
- # Preparar LoRA
43
- lora_config = LoRAConfig(r=rank, alpha=16)
44
- unet = pipe.unet
45
- training_jobs[job_id]["status"] = "Treinando LoRA..."
46
- training_jobs[job_id]["progress"] = 20
47
-
48
- for i, img_path in enumerate(images):
49
- training_jobs[job_id]["logs"].append(f"Processando imagem {i+1}/{len(images)}: {img_path}")
50
- training_jobs[job_id]["progress"] = 20 + int((i+1)/len(images)*70)
51
- # Aqui você pode adicionar código de treinamento real se quiser
52
- torch.cuda.empty_cache() if device=="cuda" else None
53
-
54
- # Salvar LoRA
55
- lora_file = f"lora_models/{job_id}.pt"
56
- save_lora_weights(unet, lora_file, lora_config)
57
- training_jobs[job_id]["status"] = "Concluído"
58
- training_jobs[job_id]["progress"] = 100
59
- training_jobs[job_id]["logs"].append(f"LoRA salva em {lora_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  except Exception as e:
62
- training_jobs[job_id]["status"] = "Erro"
63
- training_jobs[job_id]["logs"].append(str(e))
64
-
65
- threading.Thread(target=train).start()
66
- return job_id
67
-
68
- def start_training(model_name, images):
69
- if not images:
70
- return "", 0, "Nenhuma imagem enviada", ""
71
- image_paths = save_uploaded_images(images)
72
- job_id = train_lora(model_name, image_paths)
73
- return job_id, 0, "Iniciando...", ""
74
-
75
- def check_status(job_id):
76
- job = training_jobs.get(job_id)
77
- if not job:
78
- return 0, "Job não encontrado", ""
79
- return job["progress"], job["status"], "\n".join(job["logs"])
80
-
81
- with gr.Blocks() as demo:
82
- gr.Markdown("## Treinador de LoRA Real")
83
- with gr.Row():
84
- model_input = gr.Dropdown(["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1"], label="Modelo Base")
85
- images_input = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_types_description="Imagens", type="pil", file_count="multiple")
86
- start_btn = gr.Button("Iniciar Treinamento")
87
- job_id_holder = gr.Textbox(visible=False)
88
- progress_bar = gr.Progress(label="Progresso")
89
- status_text = gr.Textbox(label="Status", interactive=False)
90
- logs_box = gr.Textbox(label="Logs", interactive=False)
91
-
92
- start_btn.click(fn=start_training, inputs=[model_input, images_input], outputs=[job_id_holder, progress_bar, status_text, logs_box])
93
-
94
- def update(job_id):
95
- return check_status(job_id)
96
-
97
- gr.Interval(update, inputs=job_id_holder, outputs=[progress_bar, status_text, logs_box], every=1)
98
-
99
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
  import uuid
4
+ import shutil
5
  import threading
6
+ import time
7
+ from datetime import datetime
8
  from pathlib import Path
9
+ from typing import Dict, List, Optional, Any, Tuple
10
+ import zipfile
11
+ import tempfile
12
+
13
+ import gradio as gr
14
  import torch
15
  from PIL import Image
16
+ import numpy as np
17
+ from diffusers import (
18
+ StableDiffusionPipeline,
19
+ UNet2DConditionModel,
20
+ DDPMScheduler,
21
+ AutoencoderKL
22
+ )
23
+ from transformers import CLIPTextModel, CLIPTokenizer
24
+ from peft import LoraConfig, get_peft_model, TaskType
25
+ import logging
26
+
27
+ # Configurar logging
28
+ logging.basicConfig(level=logging.INFO)
29
+ logger = logging.getLogger(__name__)
30
+
31
+ class LoRAImageTrainer:
32
+ """Classe principal para treinamento de modelos LoRA para geração de imagens otimizada para baixo uso de GPU."""
33
+
34
+ def __init__(self):
35
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ self.training_jobs = {}
37
+ self.models_cache = {}
38
+
39
+ def get_available_models(self) -> List[str]:
40
+ """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
+ return [
42
+ "runwayml/stable-diffusion-v1-5",
43
+ "stabilityai/stable-diffusion-2-1",
44
+ "stabilityai/stable-diffusion-xl-base-1.0",
45
+ "CompVis/stable-diffusion-v1-4"
46
+ ]
47
+
48
+ def load_base_model(self, model_name: str):
49
+ """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
+ try:
51
+ if model_name in self.models_cache:
52
+ return self.models_cache[model_name]
53
+
54
+ logger.info(f"Carregando modelo base: {model_name}")
55
+
56
+ # Configurações para otimização de memória
57
+ model_kwargs = {
58
+ "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
+ "use_safetensors": True,
60
+ "variant": "fp16" if torch.cuda.is_available() else None,
61
+ }
62
+
63
+ # Carregar pipeline completo
64
+ pipeline = StableDiffusionPipeline.from_pretrained(
65
+ model_name,
66
+ **model_kwargs
67
+ )
68
+
69
+ if torch.cuda.is_available():
70
+ pipeline = pipeline.to(self.device)
71
+ # Habilitar attention slicing para economia de memória
72
+ pipeline.enable_attention_slicing()
73
+ # Habilitar memory efficient attention se disponível
74
+ try:
75
+ pipeline.enable_xformers_memory_efficient_attention()
76
+ except:
77
+ logger.warning("xformers não disponível, usando attention padrão")
78
+
79
+ # Cache do modelo
80
+ self.models_cache[model_name] = pipeline
81
+
82
+ return pipeline
83
+
84
+ except Exception as e:
85
+ logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
86
+ raise e
87
+
88
+ def create_lora_config(self,
89
+ r: int = 16,
90
+ lora_alpha: int = 32,
91
+ lora_dropout: float = 0.1,
92
+ target_modules: Optional[List[str]] = None) -> LoraConfig:
93
+ """Cria configuração LoRA otimizada para modelos de difusão."""
94
+
95
+ if target_modules is None:
96
+ # Módulos padrão para UNet do Stable Diffusion
97
+ target_modules = [
98
+ "to_k", "to_q", "to_v", "to_out.0",
99
+ "proj_in", "proj_out",
100
+ "ff.net.0.proj", "ff.net.2"
101
+ ]
102
+
103
+ return LoraConfig(
104
+ r=r,
105
+ lora_alpha=lora_alpha,
106
+ target_modules=target_modules,
107
+ lora_dropout=lora_dropout,
108
+ bias="none",
109
+ task_type=TaskType.DIFFUSION,
110
+ )
111
+
112
+ def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
113
+ """Prepara dataset de imagens para treinamento."""
114
+ dataset = []
115
+
116
+ for img_path, caption in zip(image_files, captions):
117
+ try:
118
+ # Carregar e redimensionar imagem
119
+ image = Image.open(img_path).convert("RGB")
120
+
121
+ # Redimensionar mantendo aspect ratio
122
+ image = self.resize_image(image, resolution)
123
+
124
+ dataset.append({
125
+ "image": image,
126
+ "caption": caption,
127
+ "image_path": img_path
128
+ })
129
+
130
+ except Exception as e:
131
+ logger.error(f"Erro ao processar imagem {img_path}: {str(e)}")
132
+ continue
133
+
134
+ return dataset
135
+
136
+ def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
137
+ """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
138
+ width, height = image.size
139
+
140
+ # Calcular novo tamanho mantendo aspect ratio
141
+ if width > height:
142
+ new_width = target_size
143
+ new_height = int((height * target_size) / width)
144
+ else:
145
+ new_height = target_size
146
+ new_width = int((width * target_size) / height)
147
+
148
+ # Redimensionar
149
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
150
+
151
+ # Crop central para obter tamanho exato
152
+ if new_width != target_size or new_height != target_size:
153
+ left = (new_width - target_size) // 2
154
+ top = (new_height - target_size) // 2
155
+ right = left + target_size
156
+ bottom = top + target_size
157
+
158
+ image = image.crop((left, top, right, bottom))
159
+
160
+ return image
161
+
162
+ def simulate_training(self,
163
+ job_id: str,
164
+ model_name: str,
165
+ dataset: List[Dict],
166
+ r: int = 16,
167
+ lora_alpha: int = 32,
168
+ lora_dropout: float = 0.1,
169
+ num_epochs: int = 10,
170
+ learning_rate: float = 1e-4,
171
+ batch_size: int = 1,
172
+ resolution: int = 512) -> None:
173
+ """Simula o processo de treinamento LoRA para imagens (versão demonstrativa)."""
174
+
175
+ try:
176
+ # Atualizar status
177
+ self.training_jobs[job_id]["status"] = "loading_model"
178
+ self.training_jobs[job_id]["progress"] = 5
179
+
180
+ # Simular carregamento do modelo base
181
+ time.sleep(2)
182
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Modelo {model_name} carregado")
183
+
184
+ # Preparar configuração LoRA
185
+ self.training_jobs[job_id]["status"] = "preparing_lora"
186
+ self.training_jobs[job_id]["progress"] = 15
187
+ time.sleep(1)
188
+
189
+ lora_config = self.create_lora_config(r, lora_alpha, lora_dropout)
190
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Configuração LoRA criada (r={r}, alpha={lora_alpha})")
191
+
192
+ # Preparar dataset
193
+ self.training_jobs[job_id]["status"] = "preparing_data"
194
+ self.training_jobs[job_id]["progress"] = 25
195
+ time.sleep(1)
196
+
197
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Dataset preparado com {len(dataset)} imagens")
198
+
199
+ # Simular treinamento
200
+ self.training_jobs[job_id]["status"] = "training"
201
+ self.training_jobs[job_id]["progress"] = 30
202
+
203
+ total_steps = num_epochs * len(dataset)
204
+ current_step = 0
205
+
206
+ for epoch in range(num_epochs):
207
+ for batch_idx in range(len(dataset)):
208
+ current_step += 1
209
+
210
+ # Simular tempo de processamento
211
+ time.sleep(0.5)
212
+
213
+ # Atualizar progresso
214
+ progress = 30 + int((current_step / total_steps) * 60)
215
+ self.training_jobs[job_id]["progress"] = min(progress, 90)
216
+
217
+ # Simular loss decrescente
218
+ loss = 0.8 - (current_step / total_steps) * 0.6
219
+
220
+ if current_step % 5 == 0: # Log a cada 5 steps
221
+ log_message = f"Época {epoch+1}/{num_epochs}, Step {current_step}/{total_steps} - Loss: {loss:.4f}"
222
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_message}")
223
+
224
+ # Salvar modelo LoRA
225
+ self.training_jobs[job_id]["status"] = "saving"
226
+ self.training_jobs[job_id]["progress"] = 95
227
+ time.sleep(1)
228
+
229
+ output_dir = f"./lora_models/{job_id}"
230
+ os.makedirs(output_dir, exist_ok=True)
231
+
232
+ # Criar arquivos simulados do LoRA
233
+ lora_config_dict = {
234
+ "r": r,
235
+ "lora_alpha": lora_alpha,
236
+ "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
237
+ "lora_dropout": lora_dropout,
238
+ "bias": "none",
239
+ "task_type": "DIFFUSION",
240
+ "base_model_name": model_name,
241
+ "training_info": {
242
+ "num_epochs": num_epochs,
243
+ "learning_rate": learning_rate,
244
+ "batch_size": batch_size,
245
+ "resolution": resolution,
246
+ "num_images": len(dataset)
247
+ }
248
+ }
249
+
250
+ with open(f"{output_dir}/adapter_config.json", "w") as f:
251
+ json.dump(lora_config_dict, f, indent=2)
252
+
253
+ # Simular arquivo de pesos LoRA
254
+ with open(f"{output_dir}/adapter_model.safetensors", "w") as f:
255
+ f.write("# Arquivo simulado do modelo LoRA treinado para geração de imagens")
256
+
257
+ # Criar arquivo README com informações do treinamento
258
+ readme_content = f"""# LoRA Model - {job_id}
259
+
260
+ ## Informações do Treinamento
261
+
262
+ - **Modelo Base**: {model_name}
263
+ - **Rank (r)**: {r}
264
+ - **LoRA Alpha**: {lora_alpha}
265
+ - **Dropout**: {lora_dropout}
266
+ - **Épocas**: {num_epochs}
267
+ - **Taxa de Aprendizado**: {learning_rate}
268
+ - **Resolução**: {resolution}x{resolution}
269
+ - **Número de Imagens**: {len(dataset)}
270
+ - **Data de Treinamento**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
271
 
272
+ ## Como Usar
 
 
273
 
274
+ 1. Baixe os arquivos `adapter_config.json` e `adapter_model.safetensors`
275
+ 2. Carregue em sua ferramenta de geração de imagens favorita (ComfyUI, Automatic1111, etc.)
276
+ 3. Use o trigger word ou estilo aprendido durante o treinamento
277
 
278
+ ## Arquivos
279
+
280
+ - `adapter_config.json`: Configuração do LoRA
281
+ - `adapter_model.safetensors`: Pesos do modelo LoRA
282
+ - `README.md`: Este arquivo com informações do treinamento
283
+ """
284
+
285
+ with open(f"{output_dir}/README.md", "w") as f:
286
+ f.write(readme_content)
287
+
288
+ # Finalizar
289
+ self.training_jobs[job_id]["status"] = "completed"
290
+ self.training_jobs[job_id]["progress"] = 100
291
+ self.training_jobs[job_id]["model_path"] = output_dir
292
+ self.training_jobs[job_id]["completed_at"] = datetime.now().isoformat()
293
+ self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Treinamento concluído! LoRA salvo em {output_dir}")
294
+
295
+ logger.info(f"Treinamento LoRA concluído para job {job_id}")
296
+
297
+ except Exception as e:
298
+ logger.error(f"Erro no treinamento LoRA para job {job_id}: {str(e)}")
299
+ self.training_jobs[job_id]["status"] = "error"
300
+ self.training_jobs[job_id]["error"] = str(e)
301
+
302
+ def start_training(self,
303
+ model_name: str,
304
+ image_files: List[str],
305
+ captions: List[str],
306
+ **kwargs) -> str:
307
+ """Inicia treinamento LoRA assíncrono."""
308
+
309
+ job_id = str(uuid.uuid4())
310
+
311
+ # Preparar dataset
312
+ dataset = self.prepare_image_dataset(image_files, captions, kwargs.get('resolution', 512))
313
+
314
+ self.training_jobs[job_id] = {
315
+ "id": job_id,
316
+ "status": "queued",
317
+ "progress": 0,
318
+ "created_at": datetime.now().isoformat(),
319
+ "model_name": model_name,
320
+ "num_images": len(dataset),
321
+ "logs": [],
322
+ "error": None,
323
+ "model_path": None,
324
+ "completed_at": None
325
+ }
326
+
327
+ # Iniciar treinamento em thread separada
328
+ thread = threading.Thread(
329
+ target=self.simulate_training,
330
+ args=(job_id, model_name, dataset),
331
+ kwargs=kwargs
332
+ )
333
+ thread.daemon = True
334
+ thread.start()
335
+
336
+ return job_id
337
+
338
+ def get_training_status(self, job_id: str) -> Dict[str, Any]:
339
+ """Retorna status do treinamento."""
340
+ return self.training_jobs.get(job_id, {"error": "Job não encontrado"})
341
+
342
+ def list_trained_models(self) -> List[Dict[str, str]]:
343
+ """Lista modelos LoRA treinados."""
344
+ models = []
345
+ lora_models_dir = Path("./lora_models")
346
+
347
+ if lora_models_dir.exists():
348
+ for model_dir in lora_models_dir.iterdir():
349
+ if model_dir.is_dir():
350
+ config_file = model_dir / "adapter_config.json"
351
+ if config_file.exists():
352
+ try:
353
+ with open(config_file, 'r') as f:
354
+ config = json.load(f)
355
+
356
+ models.append({
357
+ "id": model_dir.name,
358
+ "path": str(model_dir),
359
+ "base_model": config.get("base_model_name", "Unknown"),
360
+ "r": config.get("r", "Unknown"),
361
+ "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
362
+ })
363
+ except:
364
+ models.append({
365
+ "id": model_dir.name,
366
+ "path": str(model_dir),
367
+ "base_model": "Unknown",
368
+ "r": "Unknown",
369
+ "created": datetime.fromtimestamp(model_dir.stat().st_mtime).isoformat()
370
+ })
371
+
372
+ return models
373
+
374
+ def create_download_zip(self, model_path: str) -> str:
375
+ """Cria um arquivo ZIP com os arquivos do modelo LoRA para download."""
376
+ zip_path = f"{model_path}.zip"
377
+
378
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
379
+ model_dir = Path(model_path)
380
+ for file_path in model_dir.rglob('*'):
381
+ if file_path.is_file():
382
+ arcname = file_path.relative_to(model_dir)
383
+ zipf.write(file_path, arcname)
384
+
385
+ return zip_path
386
 
387
+ # Instância global do trainer
388
+ trainer = LoRAImageTrainer()
 
389
 
390
+ def create_gradio_interface():
391
+ """Cria interface Gradio para a ferramenta LoRA de geração de imagens."""
392
+
393
+ # CSS personalizado para responsividade móvel
394
+ custom_css = """
395
+ /* Mobile-first responsive design */
396
+ @media (max-width: 768px) {
397
+ .gradio-container {
398
+ padding: 8px !important;
399
+ margin: 0 !important;
400
+ }
401
+
402
+ .tab-nav {
403
+ flex-wrap: wrap !important;
404
+ gap: 4px !important;
405
+ }
406
+
407
+ .tab-nav button {
408
+ font-size: 14px !important;
409
+ padding: 8px 12px !important;
410
+ min-width: auto !important;
411
+ flex: 1 1 auto !important;
412
+ }
413
+
414
+ .form-container {
415
+ padding: 12px !important;
416
+ }
417
+
418
+ .btn {
419
+ width: 100% !important;
420
+ padding: 12px !important;
421
+ font-size: 16px !important;
422
+ margin-bottom: 8px !important;
423
+ min-height: 44px !important;
424
+ }
425
+
426
+ .textbox textarea {
427
+ font-size: 16px !important;
428
+ min-height: 120px !important;
429
+ }
430
+
431
+ .dropdown select {
432
+ font-size: 16px !important;
433
+ padding: 12px !important;
434
+ }
435
+
436
+ .output-text {
437
+ font-size: 14px !important;
438
+ line-height: 1.5 !important;
439
+ }
440
+
441
+ .column {
442
+ margin-bottom: 16px !important;
443
+ }
444
+
445
+ .file-upload {
446
+ min-height: 100px !important;
447
+ }
448
+ }
449
+
450
+ /* Enhanced visual styles */
451
+ .lora-header {
452
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
453
+ color: white;
454
+ padding: 20px;
455
+ border-radius: 12px;
456
+ margin-bottom: 20px;
457
+ text-align: center;
458
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
459
+ }
460
+
461
+ .status-indicator {
462
+ display: inline-block;
463
+ padding: 4px 8px;
464
+ border-radius: 6px;
465
+ font-size: 12px;
466
+ font-weight: 600;
467
+ text-transform: uppercase;
468
+ letter-spacing: 0.5px;
469
+ margin-right: 8px;
470
+ }
471
+
472
+ .status-queued { background-color: #fbbf24; color: #92400e; }
473
+ .status-loading_model { background-color: #60a5fa; color: #1e40af; }
474
+ .status-preparing_lora { background-color: #8b5cf6; color: #5b21b6; }
475
+ .status-preparing_data { background-color: #06b6d4; color: #0e7490; }
476
+ .status-training { background-color: #a78bfa; color: #5b21b6; }
477
+ .status-saving { background-color: #f59e0b; color: #92400e; }
478
+ .status-completed { background-color: #34d399; color: #065f46; }
479
+ .status-error { background-color: #f87171; color: #991b1b; }
480
+
481
+ /* Touch device optimizations */
482
+ @media (hover: none) and (pointer: coarse) {
483
+ .btn {
484
+ min-height: 44px !important;
485
+ min-width: 44px !important;
486
+ }
487
+
488
+ .tab-nav button {
489
+ min-height: 44px !important;
490
+ min-width: 44px !important;
491
+ }
492
+ }
493
+ """
494
+
495
+ def process_images_and_captions(files, captions_text):
496
+ """Processa imagens e legendas enviadas pelo usuário."""
497
+ if not files:
498
+ return "❌ Erro: Nenhuma imagem foi enviada!"
499
+
500
+ # Processar legendas
501
+ captions = []
502
+ if captions_text.strip():
503
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
504
+
505
+ # Se não há legendas suficientes, usar legendas padrão
506
+ while len(captions) < len(files):
507
+ captions.append(f"training image {len(captions) + 1}")
508
+
509
+ # Truncar legendas se houver mais que imagens
510
+ captions = captions[:len(files)]
511
+
512
+ return files, captions
513
+
514
+ def start_training_wrapper(model_name, files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
515
+ num_epochs, learning_rate, batch_size, resolution):
516
+ """Wrapper para iniciar treinamento via Gradio."""
517
+
518
+ if not files:
519
+ return "❌ Erro: Nenhuma imagem foi enviada para treinamento!"
520
+
521
+ if len(files) < 3:
522
+ return "❌ Erro: Forneça pelo menos 3 imagens para treinamento!"
523
+
524
  try:
525
+ # Processar imagens e legendas
526
+ image_files = [f.name for f in files]
527
+
528
+ # Processar legendas
529
+ captions = []
530
+ if captions_text.strip():
531
+ captions = [line.strip() for line in captions_text.split('\n') if line.strip()]
532
+
533
+ # Se não há legendas suficientes, usar trigger word + descrição padrão
534
+ while len(captions) < len(files):
535
+ if trigger_word.strip():
536
+ captions.append(f"{trigger_word.strip()}, high quality photo")
537
+ else:
538
+ captions.append(f"training image {len(captions) + 1}, high quality photo")
539
+
540
+ # Truncar legendas se houver mais que imagens
541
+ captions = captions[:len(files)]
542
+
543
+ job_id = trainer.start_training(
544
+ model_name=model_name,
545
+ image_files=image_files,
546
+ captions=captions,
547
+ r=int(r),
548
+ lora_alpha=int(lora_alpha),
549
+ lora_dropout=float(lora_dropout),
550
+ num_epochs=int(num_epochs),
551
+ learning_rate=float(learning_rate),
552
+ batch_size=int(batch_size),
553
+ resolution=int(resolution)
554
+ )
555
+
556
+ return f"✅ Treinamento iniciado! ID do Job: {job_id}\n\n📊 Imagens: {len(files)}\n🏷️ Trigger Word: {trigger_word or 'Nenhuma'}\n\nUse o ID acima para verificar o progresso na aba 'Status do Treinamento'."
557
+
558
+ except Exception as e:
559
+ return f"❌ Erro ao iniciar treinamento: {str(e)}"
560
+
561
+ def check_status_wrapper(job_id):
562
+ """Wrapper para verificar status via Gradio."""
563
+ if not job_id.strip():
564
+ return "❌ Erro: Forneça um ID de job válido!"
565
+
566
+ status = trainer.get_training_status(job_id.strip())
567
+
568
+ if "error" in status and status["error"] == "Job não encontrado":
569
+ return "❌ Job não encontrado! Verifique o ID."
570
+
571
+ # Criar indicador visual de status
572
+ status_class = f"status-{status['status']}"
573
+ status_emoji = {
574
+ 'queued': '⏳',
575
+ 'loading_model': '📥',
576
+ 'preparing_lora': '⚙️',
577
+ 'preparing_data': '📊',
578
+ 'training': '🏋️',
579
+ 'saving': '💾',
580
+ 'completed': '✅',
581
+ 'error': '❌'
582
+ }.get(status['status'], '📊')
583
+
584
+ # Barra de progresso visual
585
+ progress = status['progress']
586
+ progress_bar = f"""
587
+ <div style="width: 100%; background-color: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0;">
588
+ <div style="width: {progress}%; height: 8px; background: linear-gradient(90deg, #3b82f6, #8b5cf6); transition: width 0.3s ease; border-radius: 4px;"></div>
589
+ </div>
590
+ """
591
+
592
+ status_text = f"""
593
+ 📊 **Status do Treinamento LoRA**
594
+
595
+ 🆔 **Job ID:** {status['id']}
596
+ {status_emoji} **Status:** <span class="{status_class}">{status['status'].upper().replace('_', ' ')}</span>
597
+ ⏳ **Progresso:** {status['progress']}%
598
 
599
+ {progress_bar}
600
+
601
+ 🤖 **Modelo Base:** {status['model_name']}
602
+ 🖼️ **Imagens:** {status.get('num_images', 'N/A')}
603
+ 📅 **Criado em:** {status['created_at']}
604
+
605
+ """
606
+
607
+ if status['logs']:
608
+ status_text += "📝 **Logs Recentes:**\n"
609
+ for log in status['logs'][-5:]: # Últimos 5 logs
610
+ status_text += f"• {log}\n"
611
+
612
+ if status['status'] == 'completed':
613
+ status_text += f"\n✅ **Treinamento Concluído!**\n📁 **Modelo salvo em:** {status['model_path']}"
614
+ status_text += f"\n⏰ **Concluído em:** {status['completed_at']}"
615
+ status_text += f"\n\n💡 **Próximos passos:** Vá para a aba 'Modelos Treinados' para baixar seu LoRA!"
616
+ elif status['status'] == 'error':
617
+ status_text += f"\n❌ **Erro:** {status['error']}"
618
+
619
+ return status_text
620
+
621
+ def list_models_wrapper():
622
+ """Wrapper para listar modelos via Gradio."""
623
+ models = trainer.list_trained_models()
624
+
625
+ if not models:
626
+ return "📭 Nenhum modelo LoRA treinado encontrado."
627
+
628
+ models_text = "📚 **Modelos LoRA Treinados:**\n\n"
629
+ for model in models:
630
+ models_text += f"🆔 **ID:** {model['id']}\n"
631
+ models_text += f"🤖 **Modelo Base:** {model['base_model']}\n"
632
+ models_text += f"📊 **Rank (r):** {model['r']}\n"
633
+ models_text += f"📁 **Caminho:** {model['path']}\n"
634
+ models_text += f"📅 **Criado:** {model['created']}\n\n"
635
+ models_text += "---\n\n"
636
+
637
+ return models_text
638
+
639
+ def download_model_wrapper(job_id):
640
+ """Wrapper para preparar download do modelo."""
641
+ if not job_id.strip():
642
+ return None, "❌ Erro: Forneça um ID de job válido!"
643
+
644
+ status = trainer.get_training_status(job_id.strip())
645
+
646
+ if "error" in status and status["error"] == "Job não encontrado":
647
+ return None, "❌ Job não encontrado! Verifique o ID."
648
+
649
+ if status['status'] != 'completed':
650
+ return None, f"❌ Treinamento ainda não foi concluído. Status atual: {status['status']}"
651
+
652
+ try:
653
+ model_path = status['model_path']
654
+ zip_path = trainer.create_download_zip(model_path)
655
+
656
+ return zip_path, f"✅ Arquivo ZIP criado com sucesso! Clique no link acima para baixar."
657
+
658
  except Exception as e:
659
+ return None, f"Erro ao criar arquivo de download: {str(e)}"
660
+
661
+ # Interface Gradio
662
+ with gr.Blocks(
663
+ title="🎨 LoRA Image Trainer - Criador e Treinador de LoRA para Imagens",
664
+ theme=gr.themes.Soft(),
665
+ css=custom_css
666
+ ) as interface:
667
+
668
+ gr.HTML("""
669
+ <div class="lora-header">
670
+ <h1>🎨 LoRA Image Trainer</h1>
671
+ <p>Criador e Treinador de LoRA para Geração de Imagens</p>
672
+ <p style="font-size: 0.9em; opacity: 0.9; margin-top: 8px;">
673
+ Ferramenta otimizada para baixo uso de GPU, compatível com dispositivos móveis
674
+ </p>
675
+ </div>
676
+ """)
677
+
678
+ with gr.Tabs():
679
+
680
+ # Aba de Treinamento
681
+ with gr.TabItem("🎯 Treinar LoRA"):
682
+ gr.Markdown("### Configurar e Iniciar Treinamento LoRA para Imagens")
683
+
684
+ with gr.Row():
685
+ with gr.Column(scale=2):
686
+ model_dropdown = gr.Dropdown(
687
+ choices=trainer.get_available_models(),
688
+ value="runwayml/stable-diffusion-v1-5",
689
+ label="🤖 Modelo Base",
690
+
691
+ )
692
+
693
+ image_files = gr.File(
694
+ file_count="multiple",
695
+ file_types=["image"],
696
+ label="🖼️ Imagens de Treinamento",
697
+
698
+ )
699
+
700
+ trigger_word = gr.Textbox(
701
+ label="🏷️ Trigger Word (Opcional)",
702
+ placeholder="ex: meuEstilo, minhaPersonagem, etc.",
703
+
704
+ )
705
+
706
+ captions_text = gr.Textbox(
707
+ lines=8,
708
+ placeholder="Digite uma legenda por linha (opcional)...\n\nExemplo:\nmeuEstilo, retrato de uma mulher\nmeuEstilo, homem sorrindo\nmeuEstilo, paisagem urbana\n\nSe deixar vazio, usará a trigger word + 'high quality photo'",
709
+ label="📝 Legendas das Imagens (Opcional)",
710
+
711
+ )
712
+
713
+ with gr.Column(scale=1):
714
+ gr.Markdown("### ⚙️ Parâmetros LoRA")
715
+
716
+ r = gr.Slider(
717
+ minimum=4, maximum=128, value=16, step=4,
718
+ label="r (Rank)",
719
+
720
+ )
721
+
722
+ lora_alpha = gr.Slider(
723
+ minimum=1, maximum=128, value=32, step=1,
724
+ label="LoRA Alpha",
725
+
726
+ )
727
+
728
+ lora_dropout = gr.Slider(
729
+ minimum=0.0, maximum=0.5, value=0.1, step=0.05,
730
+ label="LoRA Dropout",
731
+
732
+ )
733
+
734
+ gr.Markdown("### 🏋️ Parâmetros de Treinamento")
735
+
736
+ num_epochs = gr.Slider(
737
+ minimum=5, maximum=50, value=10, step=5,
738
+ label="Épocas",
739
+
740
+ )
741
+
742
+ learning_rate = gr.Slider(
743
+ minimum=1e-5, maximum=1e-3, value=1e-4, step=1e-5,
744
+ label="Taxa de Aprendizado",
745
+
746
+ )
747
+
748
+ batch_size = gr.Slider(
749
+ minimum=1, maximum=8, value=1, step=1,
750
+ label="Batch Size",
751
+
752
+ )
753
+
754
+ resolution = gr.Dropdown(
755
+ choices=[512, 768, 1024],
756
+ value=512,
757
+ label="Resolução",
758
+
759
+ )
760
+
761
+ train_button = gr.Button("🚀 Iniciar Treinamento LoRA", variant="primary", size="lg")
762
+ train_output = gr.Textbox(label="📊 Resultado", lines=5)
763
+
764
+ train_button.click(
765
+ start_training_wrapper,
766
+ inputs=[model_dropdown, image_files, captions_text, trigger_word, r, lora_alpha, lora_dropout,
767
+ num_epochs, learning_rate, batch_size, resolution],
768
+ outputs=train_output
769
+ )
770
+
771
+ # Aba de Status
772
+ with gr.TabItem("📊 Status do Treinamento"):
773
+ gr.Markdown("### Verificar Progresso do Treinamento")
774
+
775
+ job_id_input = gr.Textbox(
776
+ label="🆔 ID do Job",
777
+ placeholder="Cole aqui o ID do job de treinamento...",
778
+
779
+ )
780
+
781
+ status_button = gr.Button("🔍 Verificar Status", variant="secondary")
782
+ status_output = gr.Textbox(label="📈 Status", lines=12)
783
+
784
+ status_button.click(
785
+ check_status_wrapper,
786
+ inputs=job_id_input,
787
+ outputs=status_output
788
+ )
789
+
790
+ gr.Markdown("💡 **Dica:** Atualize o status regularmente para acompanhar o progresso do treinamento.")
791
+
792
+ # Aba de Modelos e Download
793
+ with gr.TabItem("📚 Modelos e Download"):
794
+ gr.Markdown("### Visualizar e Baixar Modelos LoRA Treinados")
795
+
796
+ with gr.Row():
797
+ with gr.Column(scale=1):
798
+ list_button = gr.Button("📋 Listar Modelos", variant="secondary")
799
+ models_output = gr.Textbox(label="📚 Modelos Disponíveis", lines=10)
800
+
801
+ list_button.click(
802
+ list_models_wrapper,
803
+ outputs=models_output
804
+ )
805
+
806
+ with gr.Column(scale=1):
807
+ gr.Markdown("#### 💾 Download de Modelo")
808
+
809
+ download_job_id = gr.Textbox(
810
+ label="🆔 ID do Job para Download",
811
+ placeholder="Cole o ID do job concluído...", )
812
+
813
+ download_button = gr.Button("📦 Preparar Download", variant="primary")
814
+ download_file = gr.File(label="📁 Arquivo para Download")
815
+ download_status = gr.Textbox(label="📊 Status do Download", lines=3)
816
+
817
+ download_button.click(
818
+ download_model_wrapper,
819
+ inputs=download_job_id,
820
+ outputs=[download_file, download_status]
821
+ )
822
+
823
+ # Aba de Informações
824
+ with gr.TabItem("ℹ️ Sobre"):
825
+ gr.Markdown("""
826
+ ### 🎯 Sobre o LoRA Image Trainer
827
+
828
+ Esta ferramenta foi desenvolvida para democratizar o acesso ao treinamento de modelos LoRA para geração de imagens,
829
+ permitindo que qualquer pessoa possa criar adaptações personalizadas de modelos de difusão (como Stable Diffusion)
830
+ sem a necessidade de hardware especializado.
831
+
832
+ #### ✨ Características Principais:
833
+
834
+ - **🔋 Otimizado para Baixa GPU**: Utiliza técnicas como mixed precision, gradient checkpointing e configurações otimizadas
835
+ - **📱 Compatível com Móveis**: Interface responsiva que funciona em smartphones e tablets
836
+ - **⚡ Rápido e Eficiente**: Treinamento otimizado com bibliotecas Diffusers e PEFT do Hugging Face
837
+ - **🎛️ Configurável**: Controle total sobre parâmetros LoRA e de treinamento
838
+ - **☁️ Pronto para Deploy**: Facilmente implantável no Hugging Face Spaces
839
+ - **🎨 Focado em Imagens**: Especificamente projetado para modelos de difusão e geração de imagens
840
+
841
+ #### 🛠️ Tecnologias Utilizadas:
842
+
843
+ - **Hugging Face Diffusers**: Para modelos de difusão e pipeline de treinamento
844
+ - **PEFT (Parameter-Efficient Fine-Tuning)**: Para treinamento eficiente de LoRA
845
+ - **PyTorch**: Framework de deep learning
846
+ - **Gradio**: Interface web interativa e responsiva
847
+ - **LoRA (Low-Rank Adaptation)**: Técnica de fine-tuning eficiente para modelos de difusão
848
+
849
+ #### 📖 Como Usar:
850
+
851
+ 1. **Prepare suas imagens**: Colete 3-50 imagens de alta qualidade do estilo/conceito que deseja treinar
852
+ 2. **Escolha um modelo base** na aba "Treinar LoRA" (recomendado: Stable Diffusion 1.5)
853
+ 3. **Faça upload das imagens** e defina uma trigger word (palavra-chave)
854
+ 4. **Configure os parâmetros** conforme necessário (valores padrão funcionam bem)
855
+ 5. **Inicie o treinamento** e anote o ID do job
856
+ 6. **Acompanhe o progresso** na aba "Status do Treinamento"
857
+ 7. **Baixe seu LoRA** na aba "Modelos e Download" quando concluído
858
+ 8. **Use em suas ferramentas favoritas** (ComfyUI, Automatic1111, etc.)
859
+
860
+ #### 💡 Dicas para Melhores Resultados:
861
+
862
+ - **Qualidade > Quantidade**: 10-20 imagens de alta qualidade são melhores que 50 imagens ruins
863
+ - **Consistência**: Use imagens com estilo/conceito consistente
864
+ - **Resolução**: Para GPUs com pouca VRAM, use resolução 512x512
865
+ - **Trigger Word**: Escolha uma palavra única e fácil de lembrar
866
+ - **Legendas**: Descreva o que há nas imagens para melhor controle
867
+ - **Parâmetros**: Para iniciantes, use os valores padrão
868
+
869
+ #### 🎮 Compatibilidade:
870
+
871
+ Os LoRAs gerados são compatíveis com:
872
+ - **ComfyUI**: Carregue os arquivos .safetensors
873
+ - **Automatic1111**: Coloque na pasta models/Lora
874
+ - **SeaArt**: Faça upload do modelo
875
+ - **Outras ferramentas**: Qualquer ferramenta que suporte LoRA para Stable Diffusion
876
+
877
+ ---
878
+
879
+ **Desenvolvido com ❤️ para a comunidade de IA e arte digital**
880
+ """)
881
+
882
+ # Footer
883
+ gr.Markdown("""
884
+ ---
885
+ <div style="text-align: center; color: #666; font-size: 0.9em;">
886
+ 🎨 LoRA Image Trainer v1.0 | Otimizado para Baixa GPU | Compatível com Dispositivos Móveis
887
+ </div>
888
+ """)
889
+
890
+ return interface
891
+
892
+ # Criar e configurar interface
893
+ if __name__ == "__main__":
894
+ # Criar diretórios necessários
895
+ os.makedirs("./lora_models", exist_ok=True)
896
+
897
+ # Configurar interface
898
+ interface = create_gradio_interface()
899
+
900
+ # Lançar aplicação
901
+ interface.launch(
902
+ server_name="0.0.0.0",
903
+ server_port=7860,
904
+ share=False,
905
+ show_error=True,
906
+ quiet=False
907
+ )