"""Gradio demo for a Stable Diffusion unlearning model.

Supports both text-to-image and (when an initial image is supplied)
image-to-image generation, sharing one set of model components.
"""

import gc
import os
import random

import gradio as gr
import numpy as np
import torch
from PIL import Image

# On Hugging Face Spaces (ZeroGPU) the `spaces.GPU` decorator must wrap the
# inference function so a GPU is attached for the call; fall back to a no-op
# decorator when running outside Spaces.
try:
    import spaces

    GPU_DECORATOR = spaces.GPU
except Exception:
    def GPU_DECORATOR(fn):
        return fn

from diffusers import (
    EulerAncestralDiscreteScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from huggingface_hub import login
from transformers import CLIPTextModel, CLIPTokenizer

# ============================================================
# Config
# ============================================================
MODEL_ID = "telcom/dee-unlearning-tiny-sd"
REVISION = "main"

# Optional gated-model access token.
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)

cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
# fp16 only makes sense on GPU; CPU inference stays in fp32.
dtype = torch.float16 if cuda_available else torch.float32

MAX_SEED = np.iinfo(np.int32).max
# Cap resolution lower on CPU to keep generation time tolerable.
MAX_IMAGE_SIZE = 768 if not cuda_available else 1024

pipe_txt2img = None
pipe_img2img = None
model_loaded = False
load_error = None

# ============================================================
# Load model (FORCED tokenizer fix)
# ============================================================
try:
    pipe_txt2img = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision=REVISION,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)

    # 🔑 FORCE tokenizer + text encoder: reload them explicitly from the
    # repo subfolders in case the pipeline config resolves them incorrectly.
    pipe_txt2img.tokenizer = CLIPTokenizer.from_pretrained(
        MODEL_ID, subfolder="tokenizer"
    )
    pipe_txt2img.text_encoder = CLIPTextModel.from_pretrained(
        MODEL_ID,
        subfolder="text_encoder",
        torch_dtype=dtype,
    ).to(device)

    # Scheduler
    pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_txt2img.scheduler.config
    )

    # Memory optimisations — all best-effort; absence of xformers or of a
    # given method on this diffusers version is not fatal.
    try:
        pipe_txt2img.enable_attention_slicing()
        pipe_txt2img.enable_vae_slicing()
    except Exception:
        pass
    try:
        pipe_txt2img.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

    pipe_txt2img.set_progress_bar_config(disable=True)

    # Img2Img pipeline (share components with the txt2img pipeline so the
    # model weights are held in memory only once).
    pipe_img2img = StableDiffusionImg2ImgPipeline(
        **pipe_txt2img.components
    ).to(device)
    pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_img2img.scheduler.config
    )

    # Defensive checks
    assert pipe_txt2img.tokenizer is not None
    assert pipe_txt2img.text_encoder is not None

    model_loaded = True
except Exception as e:
    load_error = repr(e)
    model_loaded = False


# ============================================================
# Helpers
# ============================================================
def _make_error_image(w, h):
    """Return a dark placeholder image shown when generation fails."""
    return Image.new("RGB", (w, h), (30, 30, 40))


# ============================================================
# Inference
# ============================================================
@GPU_DECORATOR  # required on HF Spaces (ZeroGPU) to attach a GPU per call
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    init_image,
    strength,
):
    """Run txt2img (or img2img when ``init_image`` is given).

    Returns a ``(PIL.Image, status_str)`` tuple; on failure the image is a
    placeholder and the status carries the error message.
    """
    width = int(width)
    height = int(height)

    if not model_loaded:
        return _make_error_image(width, height), load_error

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders deliver floats; manual_seed requires an int.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        with torch.inference_mode():
            if init_image is not None:
                image = pipe_img2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=init_image,
                    strength=float(strength),
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).images[0]
            else:
                image = pipe_txt2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).images[0]
        return image, f"Seed: {seed}"
    except Exception as e:
        return _make_error_image(width, height), str(e)
    finally:
        # Free transient allocations after every request.
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()


# ============================================================
# UI
# ============================================================
with gr.Blocks(title="Stable Diffusion (Unlearning Model)") as demo:
    gr.Markdown("## Stable Diffusion Generator")

    if not model_loaded:
        gr.Markdown(f"⚠️ **Model failed to load**\n\n{load_error}")

    prompt = gr.Textbox(label="Prompt", lines=2)
    init_image = gr.Image(label="Initial image (optional)", type="pil")
    run_button = gr.Button("Generate")
    result = gr.Image(label="Result")
    status = gr.Markdown("")

    with gr.Accordion("Advanced Settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt", value="")
        seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
        randomize_seed = gr.Checkbox(True, label="Randomize seed")
        width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Width")
        height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Height")
        guidance_scale = gr.Slider(1, 20, step=0.5, value=7.5, label="Guidance scale")
        num_inference_steps = gr.Slider(1, 40, step=1, value=20, label="Steps")
        strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength")

    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            init_image,
            strength,
        ],
        outputs=[result, status],
    )

demo.queue().launch(ssr_mode=False)