Allex21 commited on
Commit
0eeda66
·
verified ·
1 Parent(s): 56c93f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -239
app.py CHANGED
@@ -1,246 +1,69 @@
1
- import os
2
- import json
3
- import uuid
4
- import shutil
5
- import threading
6
  import time
7
- from datetime import datetime
8
- from pathlib import Path
9
- from typing import Dict, List, Optional, Any, Tuple
10
- import zipfile
11
- import tempfile
12
 
13
- import gradio as gr
14
- import torch
15
- from PIL import Image
16
- import numpy as np
17
- from diffusers import (
18
- StableDiffusionPipeline,
19
- UNet2DConditionModel,
20
- DDPMScheduler,
21
- AutoencoderKL
22
- )
23
- from transformers import CLIPTextModel, CLIPTokenizer
24
- from peft import LoraConfig, get_peft_model, TaskType
25
- import logging
26
 
27
- # Configurar logging
28
- logging.basicConfig(level=logging.INFO)
29
- logger = logging.getLogger(__name__)
30
 
31
- class LoRAImageTrainer:
32
- """Classe principal para treinamento de modelos LoRA para geração de imagens otimizada para baixo uso de GPU."""
33
-
34
- def __init__(self):
35
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
- self.training_jobs = {}
37
- self.models_cache = {}
38
-
39
- def get_available_models(self) -> List[str]:
40
- """Retorna lista de modelos base disponíveis para treinamento LoRA."""
41
- return [
42
- "runwayml/stable-diffusion-v1-5",
43
- "stabilityai/stable-diffusion-2-1",
44
- "stabilityai/stable-diffusion-xl-base-1.0",
45
- "CompVis/stable-diffusion-v1-4"
46
- ]
47
-
48
- def load_base_model(self, model_name: str):
49
- """Carrega modelo base de difusão com otimizações para baixo uso de GPU."""
50
- try:
51
- if model_name in self.models_cache:
52
- return self.models_cache[model_name]
53
-
54
- logger.info(f"Carregando modelo base: {model_name}")
55
-
56
- # Configurações para otimização de memória
57
- model_kwargs = {
58
- "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
59
- "use_safetensors": True,
60
- "variant": "fp16" if torch.cuda.is_available() else None,
61
- }
62
-
63
- # Carregar pipeline completo
64
- pipeline = StableDiffusionPipeline.from_pretrained(
65
- model_name,
66
- **model_kwargs
67
- )
68
-
69
- if torch.cuda.is_available():
70
- pipeline = pipeline.to(self.device)
71
- # Habilitar attention slicing para economia de memória
72
- pipeline.enable_attention_slicing()
73
- # Habilitar memory efficient attention se disponível
74
- try:
75
- pipeline.enable_xformers_memory_efficient_attention()
76
- except:
77
- logger.warning("xformers não disponível, usando attention padrão")
78
-
79
- # Cache do modelo
80
- self.models_cache[model_name] = pipeline
81
-
82
- return pipeline
83
-
84
- except Exception as e:
85
- logger.error(f"Erro ao carregar modelo {model_name}: {str(e)}")
86
- raise e
87
-
88
- def create_lora_config(self,
89
- r: int = 16,
90
- lora_alpha: int = 32,
91
- lora_dropout: float = 0.1,
92
- target_modules: Optional[List[str]] = None) -> LoraConfig:
93
- """Cria configuração LoRA otimizada para modelos de difusão."""
94
-
95
- if target_modules is None:
96
- # Módulos padrão para UNet do Stable Diffusion
97
- target_modules = [
98
- "to_k", "to_q", "to_v", "to_out.0",
99
- "proj_in", "proj_out",
100
- "ff.net.0.proj", "ff.net.2"
101
- ]
102
-
103
- return LoraConfig(
104
- r=r,
105
- lora_alpha=lora_alpha,
106
- target_modules=target_modules,
107
- lora_dropout=lora_dropout,
108
- bias="none",
109
- task_type=TaskType.DIFFUSION,
110
- )
111
-
112
- def prepare_image_dataset(self, image_files: List[str], captions: List[str], resolution: int = 512) -> List[Dict]:
113
- """Prepara dataset de imagens para treinamento."""
114
- dataset = []
115
-
116
- for img_path, caption in zip(image_files, captions):
117
- try:
118
- # Carregar e redimensionar imagem
119
- image = Image.open(img_path).convert("RGB")
120
-
121
- # Redimensionar mantendo aspect ratio
122
- image = self.resize_image(image, resolution)
123
-
124
- dataset.append({
125
- "image": image,
126
- "caption": caption,
127
- "image_path": img_path
128
- })
129
-
130
- except Exception as e:
131
- logger.error(f"Erro ao processar imagem {img_path}: {str(e)}")
132
- continue
133
-
134
- return dataset
135
-
136
- def resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
137
- """Redimensiona imagem mantendo aspect ratio e fazendo crop central se necessário."""
138
- width, height = image.size
139
-
140
- # Calcular novo tamanho mantendo aspect ratio
141
- if width > height:
142
- new_width = target_size
143
- new_height = int((height * target_size) / width)
144
- else:
145
- new_height = target_size
146
- new_width = int((width * target_size) / height)
147
-
148
- # Redimensionar
149
- image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
150
 
151
- # Crop central para obter tamanho exato
152
- if new_width != target_size or new_height != target_size:
153
- left = (new_width - target_size) // 2
154
- top = (new_height - target_size) // 2
155
- right = left + target_size
156
- bottom = top + target_size
157
-
158
- image = image.crop((left, top, right, bottom))
159
 
160
- return image
161
-
162
- def simulate_training(self,
163
- job_id: str,
164
- model_name: str,
165
- dataset: List[Dict],
166
- r: int = 16,
167
- lora_alpha: int = 32,
168
- lora_dropout: float = 0.1,
169
- num_epochs: int = 10,
170
- learning_rate: float = 1e-4,
171
- batch_size: int = 1,
172
- resolution: int = 512) -> None:
173
- """Simula o processo de treinamento LoRA para imagens (versão demonstrativa)."""
174
 
175
- try:
176
- # Atualizar status
177
- self.training_jobs[job_id]["status"] = "loading_model"
178
- self.training_jobs[job_id]["progress"] = 5
179
-
180
- # Simular carregamento do modelo base
181
- time.sleep(2)
182
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Modelo {model_name} carregado")
183
-
184
- # Preparar configuração LoRA
185
- self.training_jobs[job_id]["status"] = "preparing_lora"
186
- self.training_jobs[job_id]["progress"] = 15
187
- time.sleep(1)
188
-
189
- lora_config = self.create_lora_config(r, lora_alpha, lora_dropout)
190
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Configuração LoRA criada (r={r}, alpha={lora_alpha})")
191
-
192
- # Preparar dataset
193
- self.training_jobs[job_id]["status"] = "preparing_data"
194
- self.training_jobs[job_id]["progress"] = 25
195
- time.sleep(1)
196
-
197
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - Dataset preparado com {len(dataset)} imagens")
198
-
199
- # Simular treinamento
200
- self.training_jobs[job_id]["status"] = "training"
201
- self.training_jobs[job_id]["progress"] = 30
202
-
203
- total_steps = num_epochs * len(dataset)
204
- current_step = 0
205
-
206
- for epoch in range(num_epochs):
207
- for batch_idx in range(len(dataset)):
208
- current_step += 1
209
-
210
- # Simular tempo de processamento
211
- time.sleep(0.5)
212
-
213
- # Atualizar progresso
214
- progress = 30 + int((current_step / total_steps) * 60)
215
- self.training_jobs[job_id]["progress"] = min(progress, 90)
216
-
217
- # Simular loss decrescente
218
- loss = 0.8 - (current_step / total_steps) * 0.6
219
-
220
- if current_step % 5 == 0: # Log a cada 5 steps
221
- log_message = f"Época {epoch+1}/{num_epochs}, Step {current_step}/{total_steps} - Loss: {loss:.4f}"
222
- self.training_jobs[job_id]["logs"].append(f"{datetime.now().strftime('%H:%M:%S')} - {log_message}")
223
-
224
- # Salvar modelo LoRA
225
- self.training_jobs[job_id]["status"] = "saving"
226
- self.training_jobs[job_id]["progress"] = 95
227
- time.sleep(1)
228
-
229
- output_dir = f"./lora_models/{job_id}"
230
- os.makedirs(output_dir, exist_ok=True)
231
-
232
- # Criar arquivos simulados do LoRA
233
- lora_config_dict = {
234
- "r": r,
235
- "lora_alpha": lora_alpha,
236
- "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
237
- "lora_dropout": lora_dropout,
238
- "bias": "none",
239
- "task_type": "DIFFUSION",
240
- "base_model_name": model_name,
241
- "training_info": {
242
- "num_epochs": num_epochs,
243
- "learning_rate": learning_rate,
244
- "batch_size": batch_size,
245
- "resolution": resolution,
246
- "num_images": len(dataset)
 
1
+ import gradio as gr
 
 
 
 
2
  import time
3
+ import uuid
4
+ import os
 
 
 
5
 
6
+ # Pasta para salvar LoRAs simuladas
7
+ os.makedirs("lora_models", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Armazena jobs
10
+ training_jobs = {}
 
11
 
12
+ def start_training(model_name, num_images):
13
+ job_id = str(uuid.uuid4())
14
+ training_jobs[job_id] = {"status": "Iniciando...", "progress": 0, "logs": []}
15
+
16
+ def train():
17
+ training_jobs[job_id]["logs"].append("Carregando modelo base...")
18
+ time.sleep(1)
19
+ training_jobs[job_id]["progress"] = 20
20
+ training_jobs[job_id]["logs"].append(f"Modelo {model_name} carregado")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ training_jobs[job_id]["status"] = "Treinando..."
23
+ total_steps = num_images
24
+ for step in range(1, total_steps + 1):
25
+ time.sleep(0.5) # simula processamento
26
+ training_jobs[job_id]["progress"] = int(20 + (step / total_steps) * 70)
27
+ training_jobs[job_id]["logs"].append(f"Treinamento passo {step}/{total_steps}")
 
 
28
 
29
+ training_jobs[job_id]["status"] = "Salvando LoRA..."
30
+ time.sleep(1)
31
+ lora_path = f"lora_models/{job_id}.txt"
32
+ with open(lora_path, "w") as f:
33
+ f.write(f"LoRA simulada para {model_name}, {num_images} imagens")
 
 
 
 
 
 
 
 
 
34
 
35
+ training_jobs[job_id]["progress"] = 100
36
+ training_jobs[job_id]["status"] = "Concluído"
37
+ training_jobs[job_id]["logs"].append(f"LoRA salva em {lora_path}")
38
+
39
+ # Rodar treino em thread separada
40
+ import threading
41
+ threading.Thread(target=train).start()
42
+
43
+ return job_id
44
+
45
+ def check_status(job_id):
46
+ job = training_jobs.get(job_id, None)
47
+ if not job:
48
+ return 0, "Job não encontrado", ""
49
+ return job["progress"], job["status"], "\n".join(job["logs"])
50
+
51
+ with gr.Blocks() as demo:
52
+ gr.Markdown("## Treinador de LoRA Simulado")
53
+ model_input = gr.Dropdown(["stable-diffusion-v1-5", "stable-diffusion-2-1"], label="Modelo Base")
54
+ images_input = gr.Slider(1, 50, step=1, label="Número de imagens")
55
+ start_btn = gr.Button("Iniciar Treinamento")
56
+ status_text = gr.Textbox(label="Status", interactive=False)
57
+ progress_bar = gr.Progress(label="Progresso")
58
+ logs_box = gr.Textbox(label="Logs", interactive=False)
59
+
60
+ job_id_holder = gr.Textbox(visible=False)
61
+
62
+ start_btn.click(fn=start_training, inputs=[model_input, images_input], outputs=job_id_holder)
63
+
64
+ def update_status(job_id):
65
+ return check_status(job_id)
66
+
67
+ status_updater = gr.Interval(update_status, inputs=job_id_holder, outputs=[progress_bar, status_text, logs_box], every=1)
68
+
69
+ demo.launch()