psurmreqmer
.
af1df58
raw
history blame
5.51 kB
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image
# --- Configuración del Modelo ---
# NOTA IMPORTANTE: SDXL Refiner (stabilityai/stable-diffusion-xl-refiner-1.0)
# REQUIERE mucha VRAM (aprox. 12 GB). Si tu GPU no es potente o estás usando CPU,
# la carga fallará, el pipe será None, y por eso verás la imagen roja de error.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Usar float16/bfloat16 es solo para acelerar en GPU; si falla, volvemos a float32
dtype_config = torch.bfloat16 if device == "cuda" and torch.cuda.is_available() else torch.float32
# Modelo Stable Diffusion XL Refiner
model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
pipe = None
try:
if device == "cuda":
# Intentar cargar con aceleración
pipe = DiffusionPipeline.from_pretrained(
model_id,
torch_dtype=dtype_config,
use_safetensors=True
).to(device)
else:
# Si es CPU, cargar solo en float32 (será LENTO, pero intentará funcionar)
pipe = DiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float32,
use_safetensors=True
).to("cpu")
print(f"Modelo SDXL Refiner cargado con éxito en: {device.upper()}")
except Exception as e:
print(f"Error CRÍTICO al cargar el modelo SDXL: {e}")
print("El modelo NO ha podido cargarse. Esto es la causa de la imagen roja.")
print("Solución: Usar un modelo mucho más pequeño o una GPU con más VRAM (min 8-12 GB).")
# --- Función de Procesamiento con Difusión (i2i) ---
# Se eliminó 'prompt_base' de los argumentos, y lo definimos internamente
def procesar_con_sdxl(imagen_entrada, estilo_radial, strength_slider):
"""
Aplica transformación i2i guiada por el estilo radial seleccionado (sin prompt de usuario).
"""
# -----------------------------------------------------------------
# SOLUCIÓN DE IMAGEN ROJA: Verificar si el modelo se cargó primero.
if pipe is None:
error_text = "ERROR: El modelo SDXL no se pudo cargar (Falta VRAM o GPU potente)."
print(error_text)
return Image.new('RGB', (1024, 1024), color = 'red')
# -----------------------------------------------------------------
if imagen_entrada is None:
return None
# 1. Prompt Base Fijo (ya no lo escribe el usuario)
prompt_base = "fotografía de alta calidad, retrato detallado"
estilo_prompts = {
"Blanco y Negro (Monocromático)": ", monocromático, alto contraste, película de 35mm, dramático",
"Alto Contraste y Saturación": ", colores vívidos, alto contraste, HDR, saturación extrema, cinematográfico",
"Original (Poco Ruido)": ", fotografía de alta calidad, realista, colores naturales, sutil",
}
# Combinar el prompt base fijo con el estilo del radial
full_prompt = prompt_base + estilo_prompts.get(estilo_radial, "")
# 2. Preprocesar la imagen
init_image = imagen_entrada.convert("RGB").resize((1024, 1024))
try:
# 3. Ejecutar el pipeline de difusión i2i
image = pipe(
prompt=full_prompt,
image=init_image,
strength=strength_slider,
guidance_scale=7.5
).images[0]
return image
except Exception as e:
print(f"Error durante la ejecución del pipeline: {e}")
# Devuelve un cuadro de error si el proceso falla
return Image.new('RGB', (1024, 1024), color = 'red')
# --- Interfaz Gradio con gr.Blocks() ---
with gr.Blocks(title="SDXL Refiner con Estilos Fijos") as demo:
gr.Markdown(
"""
# 🌟 Tarea con SDXL Refiner (Image-to-Image)
Carga una imagen y selecciona un **Estilo Radial** para que el modelo de difusión la transforme.
El Prompt ahora es **fijo** en el código.
"""
)
with gr.Row():
# Lado izquierdo: Inputs y Controles
with gr.Column(scale=1):
image_input = gr.Image(
type="pil",
label="1. Cargar Imagen Inicial",
)
# ELIMINAMOS EL TEXTBOX DE PROMPT AQUÍ
# --- Control Radial (Radio Buttons) ---
estilo_radial = gr.Radio(
["Original (Poco Ruido)", "Blanco y Negro (Monocromático)", "Alto Contraste y Saturación"],
label="2. Selecciona el Estilo de Transformación",
value="Original (Poco Ruido)"
)
# -----------------------------------
strength_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.6,
step=0.05,
label="3. Fuerza de Difusión (Strength): 0.1=sutil, 1.0=cambio total"
)
process_button = gr.Button("✨ Aplicar Difusión SDXL", variant="primary")
# Lado derecho: Output
with gr.Column(scale=1):
image_output = gr.Image(
type="pil",
label="Imagen Transformada por SDXL",
height=512
)
# Conexión de la acción: ¡Cambiamos los inputs!
process_button.click(
fn=procesar_con_sdxl,
inputs=[image_input, estilo_radial, strength_slider], # Eliminamos el prompt_input
outputs=image_output
)
demo.launch(inbrowser=True)