# ============================================================
# IMPORTANT: imports order matters for Hugging Face Spaces
# ============================================================
import contextlib
import gc
import os
import random
from typing import Dict, Optional
# ---- Spaces GPU decorator (must be imported early) ----------
try:
import spaces
SPACES_AVAILABLE = True
except Exception:
SPACES_AVAILABLE = False
import gradio as gr
import numpy as np
from PIL import Image
import torch
from diffusers import (
StableDiffusionXLPipeline,
StableDiffusionXLImg2ImgPipeline,
EulerAncestralDiscreteScheduler,
)
from huggingface_hub import login, hf_hub_download
from compel import Compel, ReturnedEmbeddingsType
# ============================================================
# Helpers for env flags
# ============================================================
def _env_flag(name: str, default: bool = False) -> bool:
raw = os.getenv(name)
if raw is None:
return default
raw = raw.strip().lower()
return raw in ("1", "true", "yes", "y", "on")
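# For example, _env_flag("UNLEARN_PATCH_ENABLED") returns True for "1"/"Yes"/"ON",
# False for any other set value (e.g. "0"), and the given default when unset.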
# ============================================================
# Unlearn patch: download from Hub (model repo + revision) then apply
# Adds UNLEARN_PATCH_FORMAT=absolute|delta (default delta)
# ============================================================
MODEL_ID = "telcom/deewaiREALCN"
REVISION = "main" # base model revision for the pipeline
UNLEARN_REPO_ID = MODEL_ID
UNLEARN_REVISION = "main-safe"  # branch hosting the unlearn patch weights
UNLEARN_FILENAME = "unlearnt/NSFW_wa2.safetensors" # path inside that model repo
def _strip_known_prefixes(k: str) -> str:
for p in ("unet.", "model.unet.", "diffusion_model.", "module.", "state_dict."):
if k.startswith(p):
return k[len(p):]
return k
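# For example, a patch exported with a "unet." prefix maps onto the bare UNet
# state_dict: "unet.down_blocks.0.resnets.0.conv1.weight" ->
# "down_blocks.0.resnets.0.conv1.weight". Keys without a known prefix pass through.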
def _apply_unlearn_patch_to_unet_from_hub(
unet: torch.nn.Module,
hf_token: str = "",
) -> Dict[str, object]:
enabled = _env_flag("UNLEARN_PATCH_ENABLED", default=True)
alpha_raw = os.getenv("UNLEARN_PATCH_ALPHA", "0.2").strip()
try:
alpha = float(alpha_raw)
except ValueError as exc:
raise ValueError(f"Invalid UNLEARN_PATCH_ALPHA={alpha_raw!r}. Must be a float.") from exc
mode = os.getenv("UNLEARN_PATCH_MODE", "blend").strip().lower()
strict = _env_flag("UNLEARN_PATCH_STRICT", default=False)
    # Interpret patch tensors as absolute weights or additive deltas (see UPDATE RULES below)
    patch_format = os.getenv("UNLEARN_PATCH_FORMAT", "delta").strip().lower()
    if patch_format not in ("absolute", "delta"):
        msg = f"Unsupported UNLEARN_PATCH_FORMAT={patch_format!r}. Use 'absolute' or 'delta'."
        if strict:
            raise ValueError(msg)
        patch_format = "delta"  # fall back to the documented default
repo_id = os.getenv("UNLEARN_PATCH_REPO_ID", UNLEARN_REPO_ID).strip()
revision = os.getenv("UNLEARN_PATCH_REVISION", UNLEARN_REVISION).strip()
filename = os.getenv("UNLEARN_PATCH_FILENAME", UNLEARN_FILENAME).strip()
alpha = max(0.0, min(1.0, alpha))
details: Dict[str, object] = {
"enabled": enabled,
"alpha": alpha,
"mode": mode,
"format": patch_format,
"strict": strict,
"repo_id": repo_id,
"revision": revision,
"filename": filename,
"downloaded_path": "",
"applied": False,
"applied_keys": 0,
"unexpected_keys": 0,
"mismatched_shapes": 0,
"errors": "",
}
if not enabled or alpha <= 0.0:
return details
if mode not in ("blend", "replace"):
msg = f"Unsupported UNLEARN_PATCH_MODE={mode!r}. Use 'blend' or 'replace'."
details["errors"] = msg
if strict:
raise ValueError(msg)
return details
# 1) Download patch from the model repo + revision into HF cache
try:
downloaded_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
revision=revision,
token=hf_token if hf_token else None,
)
details["downloaded_path"] = downloaded_path
except Exception as e:
msg = f"Failed to download patch from hub: {type(e).__name__}: {e}"
details["errors"] = msg
if strict:
raise
return details
# 2) Apply patch
try:
from safetensors.torch import safe_open
except Exception as e:
msg = f"safetensors not available: {e}"
details["errors"] = msg
if strict:
raise
return details
target_tensors = unet.state_dict()
if not target_tensors:
msg = "UNet state_dict is empty."
details["errors"] = msg
if strict:
raise RuntimeError(msg)
return details
ref_param = next(unet.parameters())
target_device = ref_param.device
applied = 0
unexpected = 0
mismatched = 0
with torch.no_grad():
# Keep device=str(target_device) for speed.
# If you hit GPU OOM during patch apply, switch to device="cpu" and move per-tensor.
with safe_open(downloaded_path, framework="pt", device=str(target_device)) as f:
for raw_key in f.keys():
key = _strip_known_prefixes(raw_key)
if key not in target_tensors:
unexpected += 1
if strict:
raise KeyError(f"Unexpected key in patch: {raw_key} (mapped to {key})")
continue
patch_tensor = f.get_tensor(raw_key)
tgt = target_tensors[key]
if patch_tensor.shape != tgt.shape:
mismatched += 1
if strict:
raise ValueError(
f"Shape mismatch for {key}: patch {tuple(patch_tensor.shape)} vs target {tuple(tgt.shape)}"
)
continue
if patch_tensor.dtype != tgt.dtype:
patch_tensor = patch_tensor.to(dtype=tgt.dtype)
# ============================================================
# UPDATE RULES
# - format=absolute:
# mode=blend -> new = (1-a)*old + a*patch
# mode=replace -> new = patch
# - format=delta:
# mode=blend -> new = old + a*delta
# mode=replace -> new = old + delta (alpha ignored)
# ============================================================
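                # Worked example with alpha=0.2, old=1.0, patch/delta=3.0:
                #   absolute+blend   -> 0.8*1.0 + 0.2*3.0 = 1.4
                #   absolute+replace -> 3.0
                #   delta+blend      -> 1.0 + 0.2*3.0 = 1.6
                #   delta+replace    -> 1.0 + 3.0 = 4.0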
if patch_format == "absolute":
if mode == "replace":
new_t = patch_tensor
else:
new_t = (1.0 - alpha) * tgt + alpha * patch_tensor
else:
# delta
if mode == "replace":
new_t = tgt + patch_tensor
else:
new_t = tgt + alpha * patch_tensor
tgt.copy_(new_t)
applied += 1
details["applied"] = applied > 0
details["applied_keys"] = applied
details["unexpected_keys"] = unexpected
details["mismatched_shapes"] = mismatched
return details
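# A minimal sketch of how a compatible delta patch could be produced offline:
# subtract the base UNet weights from an edited UNet and save the differences.
# This helper is illustrative only (it is never called in this app), and the
# name and output filename are assumptions, not part of the original code.
def _make_delta_patch_sketch(
    base_unet: torch.nn.Module,
    edited_unet: torch.nn.Module,
    out_path: str = "delta_patch.safetensors",  # hypothetical filename
) -> None:
    from safetensors.torch import save_file

    base_sd = base_unet.state_dict()
    edited_sd = edited_unet.state_dict()
    # Keep only keys present in both models with matching shapes.
    deltas = {
        k: (edited_sd[k] - base_sd[k]).contiguous().cpu()
        for k in base_sd
        if k in edited_sd and edited_sd[k].shape == base_sd[k].shape
    }
    save_file(deltas, out_path)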
# ============================================================
# Auth (optional)
# ============================================================
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
login(token=HF_TOKEN)
MAX_SEED = np.iinfo(np.int32).max
# ============================================================
# Device & dtype
# ============================================================
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
dtype = torch.float16 if cuda_available else torch.float32
MAX_IMAGE_SIZE = 1216 if cuda_available else 768
fallback_msg = ""
if not cuda_available:
fallback_msg = "GPU unavailable. Running in CPU fallback mode."
# ============================================================
# Load pipelines
# ============================================================
pipe_txt2img = None
pipe_img2img = None
compel = None
model_loaded = False
load_error = None
unlearn_details: Optional[Dict[str, object]] = None
try:
from_pretrained_kwargs = {
"torch_dtype": dtype,
"use_safetensors": True,
}
if cuda_available:
from_pretrained_kwargs["variant"] = "fp16"
if HF_TOKEN:
from_pretrained_kwargs["token"] = HF_TOKEN
pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
MODEL_ID,
revision=REVISION,
**from_pretrained_kwargs,
)
pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe_txt2img.scheduler.config
)
pipe_txt2img = pipe_txt2img.to(device)
    # Apply the unlearn patch from the model repo (main-safe branch).
    # Set UNLEARN_PATCH_FORMAT=delta if the safetensors file stores deltas.
unlearn_details = _apply_unlearn_patch_to_unet_from_hub(
pipe_txt2img.unet,
hf_token=HF_TOKEN,
)
# Memory optimisations
pipe_txt2img.enable_vae_slicing()
pipe_txt2img.enable_attention_slicing()
try:
pipe_txt2img.enable_xformers_memory_efficient_attention()
except Exception:
pass
pipe_txt2img.set_progress_bar_config(disable=True)
# img2img pipeline shares weights
pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe_img2img.scheduler.config
)
pipe_img2img = pipe_img2img.to(device)
    compel = Compel(
        tokenizer=[pipe_txt2img.tokenizer, pipe_txt2img.tokenizer_2],
        text_encoder=[pipe_txt2img.text_encoder, pipe_txt2img.text_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],
        device=str(device),
    )
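    # Compel parses weighted prompt syntax, e.g. "a castle on a hill++" to boost
    # "hill", or "(low quality, blurry)0.5" to down-weight a phrase (assumed
    # example strings, not from the original app).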
model_loaded = True
except Exception as e:
load_error = repr(e)
model_loaded = False
# ============================================================
# Utility: error image
# ============================================================
def make_error_image(w, h):
return Image.new("RGB", (w, h), (18, 18, 22))
def _format_unlearn_details(d: Optional[Dict[str, object]]) -> str:
if not d:
return "Unlearn patch: (no info)"
lines = [
f"enabled: {d.get('enabled')}",
f"repo_id: {d.get('repo_id')}",
f"revision: {d.get('revision')}",
f"filename: {d.get('filename')}",
f"downloaded_path: {d.get('downloaded_path')}",
f"format: {d.get('format')}",
f"mode: {d.get('mode')} | alpha: {d.get('alpha')} | strict: {d.get('strict')}",
f"applied: {d.get('applied')} | applied_keys: {d.get('applied_keys')} | "
f"unexpected_keys: {d.get('unexpected_keys')} | mismatched_shapes: {d.get('mismatched_shapes')}",
]
if d.get("errors"):
lines.append(f"errors: {d.get('errors')}")
return "\n".join(lines)
# ============================================================
# Inference function
# ============================================================
def _infer_impl(
prompt,
negative_prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
init_image,
strength,
):
width = int(width)
height = int(height)
seed = int(seed)
if not model_loaded:
return make_error_image(width, height), f"Model load failed: {load_error}"
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device=device).manual_seed(seed)
status = f"Seed: {seed}"
if fallback_msg:
status += f" | {fallback_msg}"
try:
with torch.inference_mode():
            # Batch prompt and negative prompt so Compel pads them to equal
            # length; row 0 is the prompt, row 1 the negative prompt.
            conditioning, pooled = compel([prompt, negative_prompt or ""])
            common_kwargs = dict(
                prompt_embeds=conditioning[0:1],
                pooled_prompt_embeds=pooled[0:1],
                negative_prompt_embeds=conditioning[1:2],
                negative_pooled_prompt_embeds=pooled[1:2],
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            )
            # One shared code path; autocast is only meaningful on CUDA.
            autocast_ctx = (
                torch.autocast("cuda", dtype=dtype)
                if device.type == "cuda"
                else contextlib.nullcontext()
            )
            with autocast_ctx:
                if init_image is not None:
                    image = pipe_img2img(
                        image=init_image,
                        strength=float(strength),
                        **common_kwargs,
                    ).images[0]
                else:
                    image = pipe_txt2img(
                        width=width,
                        height=height,
                        **common_kwargs,
                    ).images[0]
return image, status
except Exception as e:
return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
finally:
gc.collect()
if device.type == "cuda":
torch.cuda.empty_cache()
# On Spaces, wrap inference with the ZeroGPU decorator; otherwise expose it directly.
if SPACES_AVAILABLE:
    infer = spaces.GPU(_infer_impl)
else:
    infer = _infer_impl
# ============================================================
# UI
# ============================================================
CSS = """
body {
background: #000;
color: #fff;
}
"""
with gr.Blocks(title="SDXL txt2img + img2img") as demo:
gr.HTML(f"<style>{CSS}</style>")
if fallback_msg:
gr.Markdown(f"**{fallback_msg}**")
if not model_loaded:
gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
else:
gr.Markdown("### Unlearn patch status")
gr.Markdown(f"```\n{_format_unlearn_details(unlearn_details)}\n```")
        gr.Markdown(
            "Tip: set `UNLEARN_PATCH_FORMAT=delta` in the Space env vars if the safetensors file stores deltas.\n"
            "You can also override the patch source with\n"
            "`UNLEARN_PATCH_REPO_ID`, `UNLEARN_PATCH_REVISION`, and `UNLEARN_PATCH_FILENAME`."
        )
gr.Markdown("## SDXL Generator (txt2img + img2img)")
prompt = gr.Textbox(label="Prompt", lines=2)
init_image = gr.Image(label="Initial image (optional)", type="pil")
run_button = gr.Button("Generate")
result = gr.Image(label="Result")
status = gr.Markdown("")
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Textbox(label="Negative prompt")
seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Width")
height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Height")
guidance_scale = gr.Slider(0, 20, step=0.1, value=7, label="Guidance scale")
num_inference_steps = gr.Slider(1, 40, step=1, value=20, label="Steps")
strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength")
run_button.click(
fn=infer,
inputs=[
prompt,
negative_prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
init_image,
strength,
],
outputs=[result, status],
)
demo.queue().launch(ssr_mode=False)