# Source: Hugging Face Space file viewer — app.py by "telcom"
# (commit 4e13a1e, verified; original page chrome: raw / history / blame, 8.93 kB)
import gc
import os
import random
from contextlib import nullcontext, suppress

import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    EulerAncestralDiscreteScheduler,
)
from huggingface_hub import login
# ============================================================
# GPU decorator (optional)
# ============================================================
# On Hugging Face ZeroGPU Spaces the `spaces` package supplies a decorator
# that requests a GPU for the decorated function. Anywhere else the import
# fails and we fall back to a no-op decorator so the rest of the code is
# unchanged. (Indentation restored: the scraped source had it stripped.)
try:
    import spaces

    GPU_DECORATOR = spaces.GPU
except Exception:
    def GPU_DECORATOR(fn):
        """Identity stand-in for spaces.GPU outside a ZeroGPU Space."""
        return fn
# NOTE(review): confirm `CompelForSDXL` exists in the pinned compel version —
# the classic compel API exposes a `Compel` class instead; verify against
# requirements.txt.
from compel import CompelForSDXL

MODEL_ID = "telcom/dee-unlearning-tiny-sd"
REVISION = "main"

# Optional HF access token (e.g. for a private model repo); blank or unset
# means anonymous access. Indentation of the guarded login() restored.
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)
# ============================================================
# Detect device
# ============================================================
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
dtype = torch.float16 if cuda_available else torch.float32

# Upper bounds used by the UI sliders and seed handling.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1216 if cuda_available else 768  # keep CPU renders small

# Module-level state filled in by the model-loading step below.
pipe_txt2img = None
pipe_img2img = None
compel = None
model_loaded = False
load_error = None
fallback_msg = ""
# ============================================================
# Load model (txt2img + img2img sharing weights)
# ============================================================
# Any failure here is captured into `load_error` so the UI can still start
# and show a diagnostic instead of crashing. Indentation restored from the
# scraped (whitespace-stripped) source.
try:
    from_pretrained_kwargs = dict(
        torch_dtype=dtype,
        use_safetensors=True,
    )
    if cuda_available:
        from_pretrained_kwargs["variant"] = "fp16"
    if HF_TOKEN:
        from_pretrained_kwargs["token"] = HF_TOKEN

    # Base txt2img pipeline, pinned to REVISION.
    pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
        MODEL_ID, revision=REVISION, **from_pretrained_kwargs
    )
    pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_txt2img.scheduler.config
    )
    pipe_txt2img = pipe_txt2img.to(device)

    # Memory optimisations — each is best-effort because availability
    # depends on the installed diffusers / xformers build.
    for opt in (
        "enable_vae_slicing",
        "enable_attention_slicing",
        "enable_xformers_memory_efficient_attention",
    ):
        with suppress(Exception):
            getattr(pipe_txt2img, opt)()
    pipe_txt2img.set_progress_bar_config(disable=True)

    # Create img2img pipeline from txt2img components (no extra weights).
    pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
    pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_img2img.scheduler.config
    )
    pipe_img2img = pipe_img2img.to(device)

    # Compel prompt-weighting helper; older builds lack the `device` kwarg,
    # hence the TypeError fallback.
    try:
        compel = CompelForSDXL(pipe_txt2img, device=str(device))
    except TypeError:
        compel = CompelForSDXL(pipe_txt2img)
    model_loaded = True
except Exception as e:
    load_error = repr(e)
    model_loaded = False

if not cuda_available:
    fallback_msg = "GPU unavailable. Running in CPU fallback mode (slower, smaller images)."
# ============================================================
# Error image
# ============================================================
def _make_error_image(w, h, text):
    """Return a dark (w, h) placeholder image with *text* rendered on it.

    Shown instead of a generated image when the model is unavailable or
    generation fails, so the Gradio image output always has valid content.
    Fix: the original accepted `text` but never drew it, so error messages
    were invisible; rendering is best-effort and never raises.
    """
    img = Image.new("RGB", (w, h), (18, 18, 22))
    try:
        # Local imports keep the error path self-contained; the default
        # bitmap font ships with Pillow so no font file is needed.
        import textwrap

        from PIL import ImageDraw

        draw = ImageDraw.Draw(img)
        wrapped = textwrap.fill(str(text), width=max(10, w // 8))
        draw.multiline_text((10, 10), wrapped, fill=(230, 80, 80))
    except Exception:
        # Never let the error-reporting path itself crash the app.
        pass
    return img
# ============================================================
# Inference (txt2img or img2img depending on init_image)
# ============================================================
@GPU_DECORATOR
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    init_image,  # optional PIL image; when given, runs img2img instead
    strength,  # img2img strength in [0, 1]
):
    """Generate an image from *prompt* (txt2img) or *init_image* (img2img).

    Returns a (image, status_message) pair. On any failure the image is a
    placeholder from _make_error_image and the message carries the error
    text, so both Gradio outputs are always valid.

    Changes vs. original: indentation restored (scraped source had it
    stripped) and the byte-identical CUDA/CPU generation branches merged —
    the only difference was the autocast context, now handled with
    nullcontext on CPU.
    """
    width = int(width)
    height = int(height)
    seed = int(seed)

    # Guard: model never loaded — report instead of raising.
    if not model_loaded or pipe_txt2img is None or pipe_img2img is None or compel is None:
        msg = "Model failed to load."
        if load_error:
            msg += f" (details: {load_error})"
        return _make_error_image(width, height, msg), msg

    # Randomize seed if requested.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    if device.type == "cuda":
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = torch.Generator().manual_seed(seed)

    status = f"Seed: {seed}"
    if fallback_msg:
        status += f" | {fallback_msg}"

    try:
        with torch.inference_mode():
            # Compel turns the (possibly weighted) prompts into embeddings.
            conditioning = compel(prompt, negative_prompt=negative_prompt)
            common_kwargs = dict(
                prompt_embeds=conditioning.embeds,
                pooled_prompt_embeds=conditioning.pooled_embeds,
                negative_prompt_embeds=conditioning.negative_embeds,
                negative_pooled_prompt_embeds=conditioning.negative_pooled_embeds,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            )
            # Autocast only exists meaningfully on CUDA; nullcontext keeps
            # a single code path for CPU.
            amp_ctx = (
                torch.autocast("cuda", dtype=dtype)
                if device.type == "cuda"
                else nullcontext()
            )
            with amp_ctx:
                if init_image is not None:
                    image = pipe_img2img(
                        image=init_image,
                        strength=float(strength),
                        **common_kwargs,
                    ).images[0]
                else:
                    image = pipe_txt2img(
                        width=width,
                        height=height,
                        **common_kwargs,
                    ).images[0]
        return image, status
    except Exception as e:
        msg = f"Error during generation: {type(e).__name__}: {e}"
        return _make_error_image(width, height, msg), msg
    finally:
        # Free transient allocations between requests.
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()
# ============================================================
# UI
# ============================================================
CSS = """
body{
background:#000;
color:#fff;
}
"""

# Indentation restored: all component construction and the click wiring must
# live inside the gr.Blocks context for Gradio to register them.
with gr.Blocks(title="Text to Image / Image to Image") as demo:
    gr.HTML(f"<style>{CSS}</style>")
    with gr.Column():
        # Banner first: surface fallback / load problems above the controls.
        if fallback_msg:
            gr.Markdown(f"**{fallback_msg}**")
        if not model_loaded:
            gr.Markdown(
                f"⚠️ **Model failed to load.**\n\nDetails: {load_error}",
                elem_classes=["small-note"],
            )
        gr.Markdown("## SDXL Generator (txt2img + img2img)")
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt...",
            lines=2,
        )
        # Optional initial image: when provided, infer() runs img2img.
        init_image = gr.Image(
            label="Initial image (optional)",
            type="pil",
        )
        run_button = gr.Button("Generate")
        result = gr.Image(label="Result")
        status = gr.Markdown("")
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(label="Negative prompt", value="")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
            guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=7)
            num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=40, step=1, value=20)
            # Strength for img2img: higher = deviate more from the init image.
            strength = gr.Slider(
                label="Image strength (for img2img)",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.7,
            )
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            init_image,
            strength,
        ],
        outputs=[result, status],
    )

demo.queue().launch(ssr_mode=False)