# ============================================================
# IMPORTANT: import order matters on Hugging Face Spaces
# ============================================================
import contextlib
import gc
import os
import random
from typing import Dict, Optional

# ---- Spaces GPU decorator (must be imported early) ----------
try:
    import spaces

    SPACES_AVAILABLE = True
except Exception:
    SPACES_AVAILABLE = False

import gradio as gr
import numpy as np
from PIL import Image
import torch
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    EulerAncestralDiscreteScheduler,
)
from huggingface_hub import login, hf_hub_download
from compel import Compel, ReturnedEmbeddingsType

# ============================================================
# Helpers for env flags
# ============================================================
def _env_flag(name: str, default: bool = False) -> bool:
    raw = os.getenv(name)
    if raw is None:
        return default
    raw = raw.strip().lower()
    return raw in ("1", "true", "yes", "y", "on")
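
# Usage sketch (hypothetical value): with UNLEARN_PATCH_ENABLED=" YES " in the
# environment, _env_flag("UNLEARN_PATCH_ENABLED") returns True after the
# strip()/lower() normalisation; when the variable is unset, `default` wins.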

# ============================================================
# Unlearn patch: download from the Hub (model repo + revision), then apply.
# UNLEARN_PATCH_FORMAT=absolute|delta controls how the patch tensors are
# interpreted (default: delta).
# ============================================================
MODEL_ID = "telcom/deewaiREALCN"
REVISION = "main"  # base model revision for the pipeline
UNLEARN_REPO_ID = MODEL_ID
UNLEARN_REVISION = "main-safe"  # branch that carries the unlearn patch
UNLEARN_FILENAME = "unlearnt/NSFW_wa2.safetensors"  # path inside that model repo


def _strip_known_prefixes(k: str) -> str:
    for p in ("unet.", "model.unet.", "diffusion_model.", "module.", "state_dict."):
        if k.startswith(p):
            return k[len(p):]
    return k
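
# Example (illustrative key): "unet.down_blocks.0.resnets.0.conv1.weight" is
# mapped to "down_blocks.0.resnets.0.conv1.weight", matching the names in the
# UNet's own state_dict(); keys without a known prefix pass through unchanged.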


def _apply_unlearn_patch_to_unet_from_hub(
    unet: torch.nn.Module,
    hf_token: str = "",
) -> Dict[str, object]:
    enabled = _env_flag("UNLEARN_PATCH_ENABLED", default=True)
    alpha_raw = os.getenv("UNLEARN_PATCH_ALPHA", "0.2").strip()
    try:
        alpha = float(alpha_raw)
    except ValueError as exc:
        raise ValueError(
            f"Invalid UNLEARN_PATCH_ALPHA={alpha_raw!r}. Must be a float."
        ) from exc
    mode = os.getenv("UNLEARN_PATCH_MODE", "blend").strip().lower()
    strict = _env_flag("UNLEARN_PATCH_STRICT", default=False)

    # Absolute vs delta interpretation of the patch tensors.
    patch_format = os.getenv("UNLEARN_PATCH_FORMAT", "delta").strip().lower()
    if patch_format not in ("absolute", "delta"):
        msg = f"Unsupported UNLEARN_PATCH_FORMAT={patch_format!r}. Use 'absolute' or 'delta'."
        if strict:
            raise ValueError(msg)
        patch_format = "delta"  # fall back to the documented default

    repo_id = os.getenv("UNLEARN_PATCH_REPO_ID", UNLEARN_REPO_ID).strip()
    revision = os.getenv("UNLEARN_PATCH_REVISION", UNLEARN_REVISION).strip()
    filename = os.getenv("UNLEARN_PATCH_FILENAME", UNLEARN_FILENAME).strip()
    alpha = max(0.0, min(1.0, alpha))

    details: Dict[str, object] = {
        "enabled": enabled,
        "alpha": alpha,
        "mode": mode,
        "format": patch_format,
        "strict": strict,
        "repo_id": repo_id,
        "revision": revision,
        "filename": filename,
        "downloaded_path": "",
        "applied": False,
        "applied_keys": 0,
        "unexpected_keys": 0,
        "mismatched_shapes": 0,
        "errors": "",
    }
    if not enabled or alpha <= 0.0:
        return details

    if mode not in ("blend", "replace"):
        msg = f"Unsupported UNLEARN_PATCH_MODE={mode!r}. Use 'blend' or 'replace'."
        details["errors"] = msg
        if strict:
            raise ValueError(msg)
        return details

    # 1) Download the patch from the model repo + revision into the HF cache.
    try:
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            revision=revision,
            token=hf_token if hf_token else None,
        )
        details["downloaded_path"] = downloaded_path
    except Exception as e:
        msg = f"Failed to download patch from hub: {type(e).__name__}: {e}"
        details["errors"] = msg
        if strict:
            raise
        return details

    # 2) Apply the patch.
    try:
        from safetensors.torch import safe_open
    except Exception as e:
        msg = f"safetensors not available: {e}"
        details["errors"] = msg
        if strict:
            raise
        return details

    target_tensors = unet.state_dict()
    if not target_tensors:
        msg = "UNet state_dict is empty."
        details["errors"] = msg
        if strict:
            raise RuntimeError(msg)
        return details

    ref_param = next(unet.parameters())
    target_device = ref_param.device
    applied = 0
    unexpected = 0
    mismatched = 0
    with torch.no_grad():
        # Keep device=str(target_device) for speed. If you hit GPU OOM while
        # applying the patch, switch to device="cpu" and move per-tensor.
        with safe_open(downloaded_path, framework="pt", device=str(target_device)) as f:
            for raw_key in f.keys():
                key = _strip_known_prefixes(raw_key)
                if key not in target_tensors:
                    unexpected += 1
                    if strict:
                        raise KeyError(f"Unexpected key in patch: {raw_key} (mapped to {key})")
                    continue
                patch_tensor = f.get_tensor(raw_key)
                tgt = target_tensors[key]
                if patch_tensor.shape != tgt.shape:
                    mismatched += 1
                    if strict:
                        raise ValueError(
                            f"Shape mismatch for {key}: patch {tuple(patch_tensor.shape)} "
                            f"vs target {tuple(tgt.shape)}"
                        )
                    continue
                if patch_tensor.dtype != tgt.dtype:
                    patch_tensor = patch_tensor.to(dtype=tgt.dtype)
                # ============================================================
                # UPDATE RULES
                # - format=absolute:
                #     mode=blend   -> new = (1 - a) * old + a * patch
                #     mode=replace -> new = patch
                # - format=delta:
                #     mode=blend   -> new = old + a * delta
                #     mode=replace -> new = old + delta (alpha ignored)
                # ============================================================
                if patch_format == "absolute":
                    if mode == "replace":
                        new_t = patch_tensor
                    else:
                        new_t = (1.0 - alpha) * tgt + alpha * patch_tensor
                else:  # delta
                    if mode == "replace":
                        new_t = tgt + patch_tensor
                    else:
                        new_t = tgt + alpha * patch_tensor
                tgt.copy_(new_t)
                applied += 1

    details["applied"] = applied > 0
    details["applied_keys"] = applied
    details["unexpected_keys"] = unexpected
    details["mismatched_shapes"] = mismatched
    return details
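

# A minimal sketch (never called by the app): the four UPDATE RULES above as a
# pure function, so the blend/replace semantics can be checked in isolation.
# `old` and `patch` are hypothetical same-shape, same-dtype tensors.
def _update_rule_example(
    old: torch.Tensor,
    patch: torch.Tensor,
    patch_format: str,
    mode: str,
    alpha: float,
) -> torch.Tensor:
    if patch_format == "absolute":
        # "replace" swaps the weight wholesale; "blend" linearly interpolates.
        return patch if mode == "replace" else (1.0 - alpha) * old + alpha * patch
    # "delta": "replace" adds the full delta; "blend" scales the delta by alpha.
    return old + patch if mode == "replace" else old + alpha * patch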


# ============================================================
# Auth (optional)
# ============================================================
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)

MAX_SEED = np.iinfo(np.int32).max

# ============================================================
# Device & dtype
# ============================================================
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
dtype = torch.float16 if cuda_available else torch.float32
MAX_IMAGE_SIZE = 1216 if cuda_available else 768

fallback_msg = ""
if not cuda_available:
    fallback_msg = "GPU unavailable. Running in CPU fallback mode."

# ============================================================
# Load pipelines
# ============================================================
pipe_txt2img = None
pipe_img2img = None
compel = None
model_loaded = False
load_error = None
unlearn_details: Optional[Dict[str, object]] = None

try:
    from_pretrained_kwargs = {
        "torch_dtype": dtype,
        "use_safetensors": True,
    }
    if cuda_available:
        from_pretrained_kwargs["variant"] = "fp16"
    if HF_TOKEN:
        from_pretrained_kwargs["token"] = HF_TOKEN

    pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
        MODEL_ID,
        revision=REVISION,
        **from_pretrained_kwargs,
    )
    pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_txt2img.scheduler.config
    )
    pipe_txt2img = pipe_txt2img.to(device)

    # Apply the unlearn patch from the model repo ("main-safe" branch).
    # Set UNLEARN_PATCH_FORMAT=delta if the safetensors file stores deltas.
    unlearn_details = _apply_unlearn_patch_to_unet_from_hub(
        pipe_txt2img.unet,
        hf_token=HF_TOKEN,
    )

    # Memory optimisations
    pipe_txt2img.enable_vae_slicing()
    pipe_txt2img.enable_attention_slicing()
    try:
        pipe_txt2img.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    pipe_txt2img.set_progress_bar_config(disable=True)

    # The img2img pipeline shares the txt2img weights.
    pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
    pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_img2img.scheduler.config
    )
    pipe_img2img = pipe_img2img.to(device)

    # Prompt weighting via compel's SDXL setup: both tokenizers and text
    # encoders, penultimate hidden states, pooled output from encoder 2.
    compel = Compel(
        tokenizer=[pipe_txt2img.tokenizer, pipe_txt2img.tokenizer_2],
        text_encoder=[pipe_txt2img.text_encoder, pipe_txt2img.text_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],
        device=str(device),
    )
    model_loaded = True
except Exception as e:
    load_error = repr(e)
    model_loaded = False


# ============================================================
# Utility: error image
# ============================================================
def make_error_image(w, h):
    return Image.new("RGB", (w, h), (18, 18, 22))


def _format_unlearn_details(d: Optional[Dict[str, object]]) -> str:
    if not d:
        return "Unlearn patch: (no info)"
    lines = [
        f"enabled: {d.get('enabled')}",
        f"repo_id: {d.get('repo_id')}",
        f"revision: {d.get('revision')}",
        f"filename: {d.get('filename')}",
        f"downloaded_path: {d.get('downloaded_path')}",
        f"format: {d.get('format')}",
        f"mode: {d.get('mode')} | alpha: {d.get('alpha')} | strict: {d.get('strict')}",
        f"applied: {d.get('applied')} | applied_keys: {d.get('applied_keys')} | "
        f"unexpected_keys: {d.get('unexpected_keys')} | mismatched_shapes: {d.get('mismatched_shapes')}",
    ]
    if d.get("errors"):
        lines.append(f"errors: {d.get('errors')}")
    return "\n".join(lines)


# ============================================================
# Inference function
# ============================================================
def _infer_impl(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    init_image,
    strength,
):
    width = int(width)
    height = int(height)
    seed = int(seed)
    if not model_loaded:
        return make_error_image(width, height), f"Model load failed: {load_error}"
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    status = f"Seed: {seed}"
    if fallback_msg:
        status += f" | {fallback_msg}"
    try:
        with torch.inference_mode():
            # compel returns (embeds, pooled) because requires_pooled=[False, True].
            conditioning, pooled = compel(prompt)
            negative_conditioning, negative_pooled = compel(negative_prompt or "")
            common_kwargs = dict(
                prompt_embeds=conditioning,
                pooled_prompt_embeds=pooled,
                negative_prompt_embeds=negative_conditioning,
                negative_pooled_prompt_embeds=negative_pooled,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            )
            # Autocast only on CUDA; on CPU run under a null context instead of
            # duplicating the txt2img/img2img dispatch below.
            autocast_ctx = (
                torch.autocast("cuda", dtype=dtype)
                if device.type == "cuda"
                else contextlib.nullcontext()
            )
            with autocast_ctx:
                if init_image is not None:
                    image = pipe_img2img(
                        image=init_image,
                        strength=float(strength),
                        **common_kwargs,
                    ).images[0]
                else:
                    image = pipe_txt2img(
                        width=width,
                        height=height,
                        **common_kwargs,
                    ).images[0]
        return image, status
    except Exception as e:
        return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
    finally:
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()


if SPACES_AVAILABLE:
    # On Spaces, request a GPU via the ZeroGPU decorator; this is why the
    # `spaces` module had to be imported before torch and gradio.
    @spaces.GPU
    def infer(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        init_image,
        strength,
    ):
        return _infer_impl(
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            init_image,
            strength,
        )
else:
    # Outside Spaces the wrapper adds nothing, so expose the implementation.
    infer = _infer_impl


# ============================================================
# UI
# ============================================================
CSS = """
body {
    background: #000;
    color: #fff;
}
"""

with gr.Blocks(title="SDXL txt2img + img2img") as demo:
    gr.HTML(f"<style>{CSS}</style>")
    if fallback_msg:
        gr.Markdown(f"**{fallback_msg}**")
    if not model_loaded:
        gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
    else:
        gr.Markdown("### Unlearn patch status")
        gr.Markdown(f"```\n{_format_unlearn_details(unlearn_details)}\n```")
        gr.Markdown(
            "Tip: set `UNLEARN_PATCH_FORMAT=delta` in the Space env vars if the "
            "safetensors file stores deltas.\n"
            "You can also override the patch source with "
            "`UNLEARN_PATCH_REPO_ID`, `UNLEARN_PATCH_REVISION`, and `UNLEARN_PATCH_FILENAME`."
        )
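        # Example Space configuration (values mirror the module defaults above;
        # set them under Settings -> "Variables and secrets"):
        #   UNLEARN_PATCH_FORMAT=delta
        #   UNLEARN_PATCH_REPO_ID=telcom/deewaiREALCN
        #   UNLEARN_PATCH_REVISION=main-safe
        #   UNLEARN_PATCH_FILENAME=unlearnt/NSFW_wa2.safetensors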
| gr.Markdown("## SDXL Generator (txt2img + img2img)") | |
| prompt = gr.Textbox(label="Prompt", lines=2) | |
| init_image = gr.Image(label="Initial image (optional)", type="pil") | |
| run_button = gr.Button("Generate") | |
| result = gr.Image(label="Result") | |
| status = gr.Markdown("") | |
| with gr.Accordion("Advanced Settings", open=False): | |
| negative_prompt = gr.Textbox(label="Negative prompt") | |
| seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed") | |
| randomize_seed = gr.Checkbox(value=True, label="Randomize seed") | |
| width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Width") | |
| height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Height") | |
| guidance_scale = gr.Slider(0, 20, step=0.1, value=7, label="Guidance scale") | |
| num_inference_steps = gr.Slider(1, 40, step=1, value=20, label="Steps") | |
| strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength") | |

    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            init_image,
            strength,
        ],
        outputs=[result, status],
    )

demo.queue().launch(ssr_mode=False)