Fix black images - Improve model loading, fp16 variant handling, and verification of generated images
app.py
CHANGED
@@ -73,9 +73,8 @@ MODELS = {
     "Midjourney Style (prompthero/openjourney)": "prompthero/openjourney",
     "Orange Mixs (WarriorMama777/OrangeMixs)": "WarriorMama777/OrangeMixs",
     "Kohaku V2.1 (KBlueLeaf/kohaku-v2.1)": "KBlueLeaf/kohaku-v2.1",
-    # Advanced models that take advantage of the H200
+    # Advanced models that take advantage of the H200 (only those that actually exist)
     "SDXL Lightning (ByteDance/SDXL-Lightning)": "ByteDance/SDXL-Lightning",
-    "SDXL Lightning 4Step (ByteDance/SDXL-Lightning-4Step)": "ByteDance/SDXL-Lightning-4Step",
     "FLUX.1-Kontext-Dev (API External)": "api_external",
 }
@@ -84,8 +83,6 @@ if HF_TOKEN:
     FLUX_MODELS = {
         "FLUX.1-dev (black-forest-labs/FLUX.1-dev)": "black-forest-labs/FLUX.1-dev",
         "FLUX.1-schnell (black-forest-labs/FLUX.1-schnell)": "black-forest-labs/FLUX.1-schnell",
-        # Additional FLUX models that take advantage of the H200
-        "FLUX.1-pro (black-forest-labs/FLUX.1-pro)": "black-forest-labs/FLUX.1-pro",
     }
     MODELS.update(FLUX_MODELS)
     print("🔓 FLUX models enabled with authentication")
@@ -116,22 +113,38 @@ def load_model(model_id):
     try:
         start_time = time.time()

+        # Decide whether to use the fp16 variant based on the model
+        use_fp16_variant = False
+        if torch.cuda.is_available():
+            # Only use the fp16 variant for models that support it
+            fp16_supported_models = [
+                "stabilityai/sdxl-turbo",
+                "stabilityai/sd-turbo",
+                "stabilityai/stable-diffusion-xl-base-1.0",
+                "runwayml/stable-diffusion-v1-5",
+                "CompVis/stable-diffusion-v1-4"
+            ]
+            use_fp16_variant = any(model in model_id for model in fp16_supported_models)
+            print(f"🔧 FP16 variant: {'✅ Enabled' if use_fp16_variant else '❌ Disabled'} for {model_id}")
+
         # Use the authentication token if available
         if HF_TOKEN and ("flux" in model_id.lower() or "black-forest" in model_id.lower()):
             print(f"🔐 Loading gated model: {model_id}")
             print(f"🔑 Using authentication token...")
+
+            # For FLUX models, do not use the fp16 variant
             pipe = DiffusionPipeline.from_pretrained(
                 model_id,
                 torch_dtype=torch_dtype,
                 use_auth_token=HF_TOKEN,
-                variant="fp16" if
+                variant="fp16" if use_fp16_variant else None
             )
         else:
             print(f"📦 Loading public model: {model_id}")
             pipe = DiffusionPipeline.from_pretrained(
                 model_id,
                 torch_dtype=torch_dtype,
-                variant="fp16" if
+                variant="fp16" if use_fp16_variant else None
             )

         load_time = time.time() - start_time
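Note on the fp16-variant handling above: variant="fp16" only works when the repository actually publishes fp16 weight files (e.g. *.fp16.safetensors); requesting it for a repo that lacks them makes from_pretrained raise, which is why the allowlist here and the retry-without-variant path further down exist. A minimal sketch of the same idea using try/except instead of an allowlist (the helper name is illustrative, not part of app.py; the exact exception type for a missing variant differs across diffusers versions, and newer versions prefer token= over the deprecated use_auth_token=):

import torch
from diffusers import DiffusionPipeline

def load_pipeline_with_fp16_fallback(model_id: str):
    # Illustrative helper, not the app's actual loader.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Works only for repos that ship fp16 weights (*.fp16.safetensors).
        return DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, variant="fp16")
    except Exception:
        # diffusers versions raise different error types for a missing variant,
        # so catch broadly and fall back to the default weights, still cast to fp16.
        return DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)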
@@ -144,34 +157,65 @@ def load_model(model_id):
         if torch.cuda.is_available():
             print("🔧 Applying optimizations for H200...")

-            # Enable memory optimizations
+            # Enable memory optimizations (more conservative)
             if hasattr(pipe, 'enable_attention_slicing'):
                 pipe.enable_attention_slicing()
                 print("✅ Attention slicing enabled")

-            if hasattr(pipe, 'enable_model_cpu_offload'):
+            # Only use CPU offload for large models
+            if hasattr(pipe, 'enable_model_cpu_offload') and "sdxl" in model_id.lower():
                 pipe.enable_model_cpu_offload()
-                print("✅ CPU offload enabled")
+                print("✅ CPU offload enabled (large model)")

             if hasattr(pipe, 'enable_vae_slicing'):
                 pipe.enable_vae_slicing()
                 print("✅ VAE slicing enabled")

+            # XFormers only if it is available and the model supports it
             if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
                 try:
                     pipe.enable_xformers_memory_efficient_attention()
                     print("✅ XFormers memory efficient attention enabled")
-                except:
-                    print("⚠️ XFormers not available")
+                except Exception as e:
+                    print(f"⚠️ XFormers not available: {e}")
+                    print("🔄 Using standard attention")

         current_model_id = model_id
         print(f"✅ Model {model_id} loaded successfully")
+
+        if torch.cuda.is_available():
+            memory_used = torch.cuda.memory_allocated() / 1024**3
+            memory_reserved = torch.cuda.memory_reserved() / 1024**3
+            print(f"💾 GPU memory used: {memory_used:.2f} GB")
+            print(f"💾 GPU memory reserved: {memory_reserved:.2f} GB")
+
+            # Check whether memory usage is suspiciously low
+            if memory_used < 0.1:
+                print("⚠️ WARNING: Very low GPU memory usage - possible loading problem")
+        else:
+            print("💾 CPU memory")

     except Exception as e:
         print(f"❌ Error loading model {model_id}: {e}")
         print(f"🔍 Error type: {type(e).__name__}")
+
+        # Retry without the fp16 variant if that is what failed
+        if "variant" in str(e) and "fp16" in str(e):
+            print("🔄 Retrying without fp16 variant...")
+            try:
+                pipe = DiffusionPipeline.from_pretrained(
+                    model_id,
+                    torch_dtype=torch_dtype,
+                    use_auth_token=HF_TOKEN if HF_TOKEN and ("flux" in model_id.lower() or "black-forest" in model_id.lower()) else None
+                )
+                pipe = pipe.to(device)
+                current_model_id = model_id
+                print(f"✅ Model {model_id} loaded successfully (without fp16 variant)")
+            except Exception as e2:
+                print(f"❌ Error on second attempt: {e2}")
+                raise e2
+        else:
+            raise e
     else:
         print(f"♻️ Model {model_id} is already loaded, reusing...")
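The memory sanity check added above relies on torch.cuda.memory_allocated() and torch.cuda.memory_reserved(), which report bytes for the current device; dividing by 1024**3 converts to GiB. As a standalone sketch (the function name is illustrative; the 0.1 GiB threshold comes from the diff):

import torch

def gpu_memory_looks_plausible(threshold_gib: float = 0.1) -> bool:
    # Returns False when allocated memory is suspiciously low for a loaded pipeline.
    if not torch.cuda.is_available():
        return True  # nothing to check on CPU
    used = torch.cuda.memory_allocated() / 1024**3      # bytes -> GiB, tensors actually allocated
    reserved = torch.cuda.memory_reserved() / 1024**3   # GiB held by the caching allocator
    print(f"💾 GPU memory used: {used:.2f} GiB (reserved: {reserved:.2f} GiB)")
    # Near-zero usage right after loading suggests the weights never reached the GPU.
    return used >= threshold_gib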
@@ -417,7 +461,8 @@ def infer(
             with torch.autocast(device_type='cuda', dtype=torch.float16):
                 print("⚡ Using mixed precision for H200")

-                image = pipe(
+                # Generate the image
+                result = pipe(
                     prompt=prompt,
                     negative_prompt=negative_prompt,
                     guidance_scale=final_guidance_scale,
@@ -426,10 +471,47 @@ def infer(
                     height=height,
                     generator=generator,
                     **additional_params
-                )
+                )
+
+                # Verify that the image was generated correctly
+                if hasattr(result, 'images') and len(result.images) > 0:
+                    image = result.images[0]
+
+                    # Check that the image is not completely black
+                    if image is not None:
+                        # Convert to numpy to inspect it
+                        img_array = np.array(image)
+                        if img_array.size > 0:
+                            # Check whether the image is completely black
+                            if np.all(img_array == 0) or np.all(img_array < 10):
+                                print("⚠️ WARNING: Generated image is completely black")
+                                print("🔄 Retrying with adjusted parameters...")
+
+                                # Retry with more conservative parameters
+                                result = pipe(
+                                    prompt=prompt,
+                                    negative_prompt=negative_prompt,
+                                    guidance_scale=max(1.0, final_guidance_scale * 0.8),
+                                    num_inference_steps=max(10, final_inference_steps),
+                                    width=width,
+                                    height=height,
+                                    generator=generator
+                                )
+                                image = result.images[0]
+                            else:
+                                print("✅ Image generated correctly")
+                        else:
+                            print("❌ Error: Empty image")
+                            raise Exception("Empty image generated")
+                    else:
+                        print("❌ Error: Image is None")
+                        raise Exception("Image is None")
+                else:
+                    print("❌ Error: No images were generated")
+                    raise Exception("No images were generated")
         else:
             # Fallback for CPU
-            image = pipe(
+            result = pipe(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 guidance_scale=final_guidance_scale,
@@ -438,7 +520,8 @@ def infer(
                 height=height,
                 generator=generator,
                 **additional_params
-            )
+            )
+            image = result.images[0]

         inference_time = time.time() - inference_start
         total_time = time.time() - start_time
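On the black-image check above: np.all(img_array < 10) already subsumes np.all(img_array == 0), and a mean-brightness test is more robust to a few stray bright pixels. A standalone sketch of such a detector (function name and threshold are illustrative, not part of app.py):

import numpy as np
from PIL import Image

def is_black_image(image: Image.Image, threshold: float = 10.0) -> bool:
    # True when the image is empty or (nearly) uniformly black.
    arr = np.asarray(image, dtype=np.float32)
    if arr.size == 0:
        return True  # treat an empty buffer as a failed generation
    return float(arr.mean()) < threshold

# Usage after a pipeline call, mirroring the diff's retry logic:
# result = pipe(prompt=prompt, ...)
# if is_black_image(result.images[0]):
#     ...retry with adjusted parameters...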