import os
import random
import gc
import gradio as gr
import numpy as np
from PIL import Image
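# On ZeroGPU Spaces the GPU-bound function must be wrapped with spaces.GPU;
# elsewhere the `spaces` import fails, so fall back to a no-op decorator.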
try:
    import spaces
    GPU_DECORATOR = spaces.GPU
except Exception:
    def GPU_DECORATOR(fn):
        return fn
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    EulerAncestralDiscreteScheduler,
)
from transformers import CLIPTokenizer, CLIPTextModel
from huggingface_hub import login
# ============================================================
# Config
# ============================================================
MODEL_ID = "telcom/dee-unlearning-tiny-sd"
REVISION = "main"
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)
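# Half precision on CUDA roughly halves memory use; CPU inference stays in float32.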
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
dtype = torch.float16 if cuda_available else torch.float32
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024 if cuda_available else 768
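# Pipeline state lives in module globals so a failed load is reported in the UI
# instead of crashing the Space at import time.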
pipe_txt2img = None
pipe_img2img = None
model_loaded = False
load_error = None
# ============================================================
# Load model (FORCED tokenizer fix)
# ============================================================
try:
    pipe_txt2img = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=REVISION,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)
    # 🔑 FORCE tokenizer + text encoder: reload both from the repo's own
    # subfolders so the pipeline cannot fall back to mismatched defaults
    pipe_txt2img.tokenizer = CLIPTokenizer.from_pretrained(
        MODEL_ID, subfolder="tokenizer"
    )
    pipe_txt2img.text_encoder = CLIPTextModel.from_pretrained(
        MODEL_ID,
        subfolder="text_encoder",
        torch_dtype=dtype,
    ).to(device)
    # Scheduler: Euler Ancestral is fast and behaves well at low step counts
    pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_txt2img.scheduler.config
    )
    # Memory optimisations (best effort; availability varies by diffusers
    # version, and xformers may simply not be installed)
    try:
        pipe_txt2img.enable_attention_slicing()
        pipe_txt2img.enable_vae_slicing()
    except Exception:
        pass
    try:
        pipe_txt2img.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    pipe_txt2img.set_progress_bar_config(disable=True)
    # Img2Img pipeline shares the txt2img components, so the UNet, VAE and
    # text encoder weights are loaded only once
    pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_txt2img.components).to(device)
    pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_img2img.scheduler.config
    )
    # Defensive checks
    assert pipe_txt2img.tokenizer is not None
    assert pipe_txt2img.text_encoder is not None
    model_loaded = True
except Exception as e:
    load_error = repr(e)
    model_loaded = False
# ============================================================
# Helpers
# ============================================================
def _make_error_image(w, h):
    # Dark placeholder returned when the model is missing or generation fails
    return Image.new("RGB", (w, h), (30, 30, 40))
# ============================================================
# Inference
# ============================================================
@GPU_DECORATOR  # request a GPU slice on ZeroGPU Spaces; no-op elsewhere
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    init_image,
    strength,
):
    width = int(width)
    height = int(height)
    if not model_loaded:
        return _make_error_image(width, height), load_error
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)  # sliders can deliver floats; manual_seed requires an int
    generator = torch.Generator(device=device).manual_seed(seed)
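    # Seeding the Generator makes runs reproducible: the same seed and
    # settings regenerate the same image.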
    try:
        with torch.inference_mode():
            if init_image is not None:
                # An uploaded image routes the request through img2img
                image = pipe_img2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=init_image,
                    strength=float(strength),
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).images[0]
            else:
                image = pipe_txt2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).images[0]
        return image, f"Seed: {seed}"
    except Exception as e:
        return _make_error_image(width, height), str(e)
    finally:
        # Free transient allocations between requests
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()
# ============================================================
# UI
# ============================================================
with gr.Blocks(title="Stable Diffusion (Unlearning Model)") as demo:
    gr.Markdown("## Stable Diffusion Generator")
    if not model_loaded:
        gr.Markdown(f"⚠️ **Model failed to load**\n\n{load_error}")

    prompt = gr.Textbox(label="Prompt", lines=2)
    init_image = gr.Image(label="Initial image (optional)", type="pil")
    run_button = gr.Button("Generate")
    result = gr.Image(label="Result")
    status = gr.Markdown("")

    with gr.Accordion("Advanced Settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt", value="")
        seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
        randomize_seed = gr.Checkbox(True, label="Randomize seed")
        width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Width")
        height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Height")
        guidance_scale = gr.Slider(1, 20, step=0.5, value=7.5, label="Guidance scale")
        num_inference_steps = gr.Slider(1, 40, step=1, value=20, label="Steps")
        strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength")
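    # The inputs list below must match infer's positional parameter order.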
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            init_image,
            strength,
        ],
        outputs=[result, status],
    )
demo.queue().launch(ssr_mode=False)