Ntdeseb commited on
Commit
e5f6fa0
·
1 Parent(s): 04fd479

Arreglar imágenes negras - Mejorar carga de modelos, manejo de variants fp16, verificación de imágenes generadas

Browse files
Files changed (1) hide show
  1. app.py +100 -17
app.py CHANGED
@@ -73,9 +73,8 @@ MODELS = {
73
  "Midjourney Style (prompthero/openjourney)": "prompthero/openjourney",
74
  "Orange Mixs (WarriorMama777/OrangeMixs)": "WarriorMama777/OrangeMixs",
75
  "Kohaku V2.1 (KBlueLeaf/kohaku-v2.1)": "KBlueLeaf/kohaku-v2.1",
76
- # Modelos avanzados que aprovechan H200
77
  "SDXL Lightning (ByteDance/SDXL-Lightning)": "ByteDance/SDXL-Lightning",
78
- "SDXL Lightning 4Step (ByteDance/SDXL-Lightning-4Step)": "ByteDance/SDXL-Lightning-4Step",
79
  "FLUX.1-Kontext-Dev (API External)": "api_external",
80
  }
81
 
@@ -84,8 +83,6 @@ if HF_TOKEN:
84
  FLUX_MODELS = {
85
  "FLUX.1-dev (black-forest-labs/FLUX.1-dev)": "black-forest-labs/FLUX.1-dev",
86
  "FLUX.1-schnell (black-forest-labs/FLUX.1-schnell)": "black-forest-labs/FLUX.1-schnell",
87
- # Modelos FLUX adicionales que aprovechan H200
88
- "FLUX.1-pro (black-forest-labs/FLUX.1-pro)": "black-forest-labs/FLUX.1-pro",
89
  }
90
  MODELS.update(FLUX_MODELS)
91
  print("🔓 Modelos FLUX habilitados con autenticación")
@@ -116,22 +113,38 @@ def load_model(model_id):
116
  try:
117
  start_time = time.time()
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  # Usar token de autenticación si está disponible
120
  if HF_TOKEN and ("flux" in model_id.lower() or "black-forest" in model_id.lower()):
121
  print(f"🔐 Cargando modelo gated: {model_id}")
122
  print(f"🔑 Usando token de autenticación...")
 
 
123
  pipe = DiffusionPipeline.from_pretrained(
124
  model_id,
125
  torch_dtype=torch_dtype,
126
  use_auth_token=HF_TOKEN,
127
- variant="fp16" if torch.cuda.is_available() else None
128
  )
129
  else:
130
  print(f"📦 Cargando modelo público: {model_id}")
131
  pipe = DiffusionPipeline.from_pretrained(
132
  model_id,
133
  torch_dtype=torch_dtype,
134
- variant="fp16" if torch.cuda.is_available() else None
135
  )
136
 
137
  load_time = time.time() - start_time
@@ -144,34 +157,65 @@ def load_model(model_id):
144
  if torch.cuda.is_available():
145
  print("🔧 Aplicando optimizaciones para H200...")
146
 
147
- # Habilitar optimizaciones de memoria
148
  if hasattr(pipe, 'enable_attention_slicing'):
149
  pipe.enable_attention_slicing()
150
  print("✅ Attention slicing habilitado")
151
 
152
- if hasattr(pipe, 'enable_model_cpu_offload'):
 
153
  pipe.enable_model_cpu_offload()
154
- print("✅ CPU offload habilitado")
155
 
156
  if hasattr(pipe, 'enable_vae_slicing'):
157
  pipe.enable_vae_slicing()
158
  print("✅ VAE slicing habilitado")
159
 
 
160
  if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
161
  try:
162
  pipe.enable_xformers_memory_efficient_attention()
163
  print("✅ XFormers memory efficient attention habilitado")
164
- except:
165
- print("⚠️ XFormers no disponible, usando atención estándar")
 
166
 
167
  current_model_id = model_id
168
  print(f"✅ Modelo {model_id} cargado exitosamente")
169
- print(f"💾 Memoria utilizada: {torch.cuda.memory_allocated() / 1024**3:.2f} GB" if torch.cuda.is_available() else "💾 Memoria CPU")
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  except Exception as e:
172
  print(f"❌ Error cargando modelo {model_id}: {e}")
173
  print(f"🔍 Tipo de error: {type(e).__name__}")
174
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  else:
176
  print(f"♻️ Modelo {model_id} ya está cargado, reutilizando...")
177
 
@@ -417,7 +461,8 @@ def infer(
417
  with torch.autocast(device_type='cuda', dtype=torch.float16):
418
  print("⚡ Usando mixed precision para H200")
419
 
420
- image = pipe(
 
421
  prompt=prompt,
422
  negative_prompt=negative_prompt,
423
  guidance_scale=final_guidance_scale,
@@ -426,10 +471,47 @@ def infer(
426
  height=height,
427
  generator=generator,
428
  **additional_params
429
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  else:
431
  # Fallback para CPU
432
- image = pipe(
433
  prompt=prompt,
434
  negative_prompt=negative_prompt,
435
  guidance_scale=final_guidance_scale,
@@ -438,7 +520,8 @@ def infer(
438
  height=height,
439
  generator=generator,
440
  **additional_params
441
- ).images[0]
 
442
 
443
  inference_time = time.time() - inference_start
444
  total_time = time.time() - start_time
 
73
  "Midjourney Style (prompthero/openjourney)": "prompthero/openjourney",
74
  "Orange Mixs (WarriorMama777/OrangeMixs)": "WarriorMama777/OrangeMixs",
75
  "Kohaku V2.1 (KBlueLeaf/kohaku-v2.1)": "KBlueLeaf/kohaku-v2.1",
76
+ # Modelos avanzados que aprovechan H200 (solo los que existen)
77
  "SDXL Lightning (ByteDance/SDXL-Lightning)": "ByteDance/SDXL-Lightning",
 
78
  "FLUX.1-Kontext-Dev (API External)": "api_external",
79
  }
80
 
 
83
  FLUX_MODELS = {
84
  "FLUX.1-dev (black-forest-labs/FLUX.1-dev)": "black-forest-labs/FLUX.1-dev",
85
  "FLUX.1-schnell (black-forest-labs/FLUX.1-schnell)": "black-forest-labs/FLUX.1-schnell",
 
 
86
  }
87
  MODELS.update(FLUX_MODELS)
88
  print("🔓 Modelos FLUX habilitados con autenticación")
 
113
  try:
114
  start_time = time.time()
115
 
116
+ # Determinar si usar variant fp16 basado en el modelo
117
+ use_fp16_variant = False
118
+ if torch.cuda.is_available():
119
+ # Solo usar fp16 variant para modelos que lo soportan
120
+ fp16_supported_models = [
121
+ "stabilityai/sdxl-turbo",
122
+ "stabilityai/sd-turbo",
123
+ "stabilityai/stable-diffusion-xl-base-1.0",
124
+ "runwayml/stable-diffusion-v1-5",
125
+ "CompVis/stable-diffusion-v1-4"
126
+ ]
127
+ use_fp16_variant = any(model in model_id for model in fp16_supported_models)
128
+ print(f"🔧 FP16 variant: {'✅ Habilitado' if use_fp16_variant else '❌ Deshabilitado'} para {model_id}")
129
+
130
  # Usar token de autenticación si está disponible
131
  if HF_TOKEN and ("flux" in model_id.lower() or "black-forest" in model_id.lower()):
132
  print(f"🔐 Cargando modelo gated: {model_id}")
133
  print(f"🔑 Usando token de autenticación...")
134
+
135
+ # Para modelos FLUX, no usar variant fp16
136
  pipe = DiffusionPipeline.from_pretrained(
137
  model_id,
138
  torch_dtype=torch_dtype,
139
  use_auth_token=HF_TOKEN,
140
+ variant="fp16" if use_fp16_variant else None
141
  )
142
  else:
143
  print(f"📦 Cargando modelo público: {model_id}")
144
  pipe = DiffusionPipeline.from_pretrained(
145
  model_id,
146
  torch_dtype=torch_dtype,
147
+ variant="fp16" if use_fp16_variant else None
148
  )
149
 
150
  load_time = time.time() - start_time
 
157
  if torch.cuda.is_available():
158
  print("🔧 Aplicando optimizaciones para H200...")
159
 
160
+ # Habilitar optimizaciones de memoria (más conservadoras)
161
  if hasattr(pipe, 'enable_attention_slicing'):
162
  pipe.enable_attention_slicing()
163
  print("✅ Attention slicing habilitado")
164
 
165
+ # Solo usar CPU offload para modelos grandes
166
+ if hasattr(pipe, 'enable_model_cpu_offload') and "sdxl" in model_id.lower():
167
  pipe.enable_model_cpu_offload()
168
+ print("✅ CPU offload habilitado (modelo grande)")
169
 
170
  if hasattr(pipe, 'enable_vae_slicing'):
171
  pipe.enable_vae_slicing()
172
  print("✅ VAE slicing habilitado")
173
 
174
+ # XFormers solo si está disponible y el modelo lo soporta
175
  if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
176
  try:
177
  pipe.enable_xformers_memory_efficient_attention()
178
  print("✅ XFormers memory efficient attention habilitado")
179
+ except Exception as e:
180
+ print(f"⚠️ XFormers no disponible: {e}")
181
+ print("🔄 Usando atención estándar")
182
 
183
  current_model_id = model_id
184
  print(f"✅ Modelo {model_id} cargado exitosamente")
185
+
186
+ if torch.cuda.is_available():
187
+ memory_used = torch.cuda.memory_allocated() / 1024**3
188
+ memory_reserved = torch.cuda.memory_reserved() / 1024**3
189
+ print(f"💾 Memoria GPU utilizada: {memory_used:.2f} GB")
190
+ print(f"💾 Memoria GPU reservada: {memory_reserved:.2f} GB")
191
+
192
+ # Verificar si la memoria es sospechosamente baja
193
+ if memory_used < 0.1:
194
+ print("⚠️ ADVERTENCIA: Memoria GPU muy baja - posible problema de carga")
195
+ else:
196
+ print("💾 Memoria CPU")
197
 
198
  except Exception as e:
199
  print(f"❌ Error cargando modelo {model_id}: {e}")
200
  print(f"🔍 Tipo de error: {type(e).__name__}")
201
+
202
+ # Intentar cargar sin variant fp16 si falló
203
+ if "variant" in str(e) and "fp16" in str(e):
204
+ print("🔄 Reintentando sin variant fp16...")
205
+ try:
206
+ pipe = DiffusionPipeline.from_pretrained(
207
+ model_id,
208
+ torch_dtype=torch_dtype,
209
+ use_auth_token=HF_TOKEN if HF_TOKEN and ("flux" in model_id.lower() or "black-forest" in model_id.lower()) else None
210
+ )
211
+ pipe = pipe.to(device)
212
+ current_model_id = model_id
213
+ print(f"✅ Modelo {model_id} cargado exitosamente (sin fp16 variant)")
214
+ except Exception as e2:
215
+ print(f"❌ Error en segundo intento: {e2}")
216
+ raise e2
217
+ else:
218
+ raise e
219
  else:
220
  print(f"♻️ Modelo {model_id} ya está cargado, reutilizando...")
221
 
 
461
  with torch.autocast(device_type='cuda', dtype=torch.float16):
462
  print("⚡ Usando mixed precision para H200")
463
 
464
+ # Generar la imagen
465
+ result = pipe(
466
  prompt=prompt,
467
  negative_prompt=negative_prompt,
468
  guidance_scale=final_guidance_scale,
 
471
  height=height,
472
  generator=generator,
473
  **additional_params
474
+ )
475
+
476
+ # Verificar que la imagen se generó correctamente
477
+ if hasattr(result, 'images') and len(result.images) > 0:
478
+ image = result.images[0]
479
+
480
+ # Verificar que la imagen no sea completamente negra
481
+ if image is not None:
482
+ # Convertir a numpy para verificar
483
+ img_array = np.array(image)
484
+ if img_array.size > 0:
485
+ # Verificar si la imagen es completamente negra
486
+ if np.all(img_array == 0) or np.all(img_array < 10):
487
+ print("⚠️ ADVERTENCIA: Imagen generada es completamente negra")
488
+ print("🔄 Reintentando con parámetros ajustados...")
489
+
490
+ # Reintentar con parámetros más conservadores
491
+ result = pipe(
492
+ prompt=prompt,
493
+ negative_prompt=negative_prompt,
494
+ guidance_scale=max(1.0, final_guidance_scale * 0.8),
495
+ num_inference_steps=max(10, final_inference_steps),
496
+ width=width,
497
+ height=height,
498
+ generator=generator
499
+ )
500
+ image = result.images[0]
501
+ else:
502
+ print("✅ Imagen generada correctamente")
503
+ else:
504
+ print("❌ Error: Imagen vacía")
505
+ raise Exception("Imagen vacía generada")
506
+ else:
507
+ print("❌ Error: Imagen es None")
508
+ raise Exception("Imagen es None")
509
+ else:
510
+ print("❌ Error: No se generaron imágenes")
511
+ raise Exception("No se generaron imágenes")
512
  else:
513
  # Fallback para CPU
514
+ result = pipe(
515
  prompt=prompt,
516
  negative_prompt=negative_prompt,
517
  guidance_scale=final_guidance_scale,
 
520
  height=height,
521
  generator=generator,
522
  **additional_params
523
+ )
524
+ image = result.images[0]
525
 
526
  inference_time = time.time() - inference_start
527
  total_time = time.time() - start_time