Spaces:
Running on Zero
Running on Zero
import contextlib
import gc
import os
import random
import textwrap

import gradio as gr
import numpy as np
import torch
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    EulerAncestralDiscreteScheduler,
)
from huggingface_hub import login
from PIL import Image
from PIL import ImageDraw
# ============================================================
# GPU decorator (optional)
# ============================================================
try:
    import spaces

    GPU_DECORATOR = spaces.GPU
except Exception:
    # Best-effort: any failure importing `spaces` falls back to a no-op
    # decorator so the app still starts.
    def GPU_DECORATOR(fn=None, **_options):
        """No-op stand-in used when the `spaces` package cannot be imported.

        Works both bare (``@GPU_DECORATOR``) and parameterized
        (``@GPU_DECORATOR(duration=60)``); keyword options are ignored and
        the decorated function is returned unchanged.
        """
        if fn is None:
            # Parameterized use: hand back a decorator that returns its
            # argument untouched.
            return lambda wrapped: wrapped
        return fn
| from compel import CompelForSDXL | |
# Model repository and the revision loaded from it.
MODEL_ID = "telcom/dee-unlearning-tiny-sd"
REVISION = "main"

# Log in to the Hub only when a non-blank token is present in the environment.
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)

# ============================================================
# Detect device
# ============================================================
cuda_available = torch.cuda.is_available()
if cuda_available:
    device = torch.device("cuda")
    dtype = torch.float16
    MAX_IMAGE_SIZE = 1216
else:
    device = torch.device("cpu")
    dtype = torch.float32
    MAX_IMAGE_SIZE = 768  # CPU smaller

MAX_SEED = np.iinfo(np.int32).max

# Populated by the model-loading section; they stay None/False if loading fails.
pipe_txt2img = None
pipe_img2img = None
compel = None
model_loaded = False
load_error = None
fallback_msg = ""
# ============================================================
# Load model (txt2img + img2img sharing weights)
# ============================================================
try:
    from_pretrained_kwargs = {"torch_dtype": dtype, "use_safetensors": True}
    if cuda_available:
        from_pretrained_kwargs["variant"] = "fp16"
    if HF_TOKEN:
        from_pretrained_kwargs["token"] = HF_TOKEN

    # Base txt2img pipeline, loaded at REVISION.
    pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
        MODEL_ID, revision=REVISION, **from_pretrained_kwargs
    )
    pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_txt2img.scheduler.config
    )
    pipe_txt2img = pipe_txt2img.to(device)

    # Memory opts: each is attempted independently and any failure is
    # deliberately ignored so one missing feature does not abort loading.
    for _opt_name in (
        "enable_vae_slicing",
        "enable_attention_slicing",
        "enable_xformers_memory_efficient_attention",
    ):
        try:
            getattr(pipe_txt2img, _opt_name)()
        except Exception:
            pass

    pipe_txt2img.set_progress_bar_config(disable=True)

    # Create img2img pipeline from txt2img components (no extra weights)
    pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
    pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_img2img.scheduler.config
    )
    pipe_img2img = pipe_img2img.to(device)

    # Try the constructor with an explicit device first; fall back to the
    # one-argument form if that signature raises TypeError.
    try:
        compel = CompelForSDXL(pipe_txt2img, device=str(device))
    except TypeError:
        compel = CompelForSDXL(pipe_txt2img)

    model_loaded = True
except Exception as exc:
    # Keep the failure around so the UI and infer() can report it.
    load_error = repr(exc)
    model_loaded = False

if not cuda_available:
    fallback_msg = "GPU unavailable. Running in CPU fallback mode (slower, smaller images)."
# ============================================================
# Error image
# ============================================================
def _make_error_image(w, h, text):
    """Return a dark placeholder image of size ``w`` x ``h`` showing ``text``.

    Args:
        w: Image width in pixels.
        h: Image height in pixels.
        text: Message to draw on the image. Nothing is drawn when it is
            empty or ``None``.

    Returns:
        An RGB ``PIL.Image.Image`` with a near-black background and, when
        ``text`` is given, the word-wrapped message in the top-left corner.
    """
    img = Image.new("RGB", (w, h), (18, 18, 22))
    if text:
        margin = 16
        # Rough characters-per-line estimate for Pillow's default font
        # (about 7 px per glyph is an approximation, not a measured value).
        chars_per_line = max(10, (w - 2 * margin) // 7)
        wrapped = textwrap.fill(str(text), width=chars_per_line)
        ImageDraw.Draw(img).multiline_text(
            (margin, margin), wrapped, fill=(235, 235, 240)
        )
    return img
# ============================================================
# Inference (txt2img or img2img depending on init_image)
# ============================================================
# Wrapped with GPU_DECORATOR (spaces.GPU when available, otherwise a no-op)
# so that a ZeroGPU Space attaches a GPU for the duration of the call.
@GPU_DECORATOR
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    init_image,  # optional image; when given, img2img is used
    strength,  # img2img strength
):
    """Generate one image from a prompt, optionally starting from an image.

    Args:
        prompt: Positive prompt passed to the Compel encoder.
        negative_prompt: Negative prompt passed to the Compel encoder.
        seed: Seed for the torch generator (coerced to ``int``).
        randomize_seed: When truthy, ``seed`` is replaced by a random value
            in ``[0, MAX_SEED]``.
        width: Output width for txt2img (coerced to ``int``); also used to
            size the error image.
        height: Output height for txt2img (coerced to ``int``); also used to
            size the error image.
        guidance_scale: Guidance scale forwarded to the pipeline as a float.
        num_inference_steps: Step count forwarded to the pipeline as an int.
        init_image: Optional source image. ``None`` selects txt2img; any
            other value selects img2img, in which case ``width``/``height``
            are not forwarded to the pipeline.
        strength: img2img strength, forwarded as a float (img2img only).

    Returns:
        A ``(image, status)`` tuple. On success ``status`` reports the seed
        (plus the CPU-fallback notice, if any). If the model did not load or
        generation raises, ``image`` is an error placeholder and ``status``
        is the error message.
    """
    width = int(width)
    height = int(height)
    seed = int(seed)

    if not model_loaded or pipe_txt2img is None or pipe_img2img is None or compel is None:
        msg = "Model failed to load."
        if load_error:
            msg += f" (details: {load_error})"
        return _make_error_image(width, height, msg), msg

    # Randomize seed if requested
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if device.type == "cuda":
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = torch.Generator().manual_seed(seed)

    status = f"Seed: {seed}"
    if fallback_msg:
        status += f" | {fallback_msg}"

    try:
        # Pick the pipeline and its mode-specific arguments once, instead of
        # repeating the call for every device/mode combination.
        if init_image is not None:
            pipeline = pipe_img2img
            mode_kwargs = dict(image=init_image, strength=float(strength))
        else:
            pipeline = pipe_txt2img
            mode_kwargs = dict(width=width, height=height)

        # Autocast is only used on CUDA; on CPU the context is a no-op.
        if device.type == "cuda":
            autocast_ctx = torch.autocast("cuda", dtype=dtype)
        else:
            autocast_ctx = contextlib.nullcontext()

        with torch.inference_mode():
            # Prompt encoding runs outside the autocast context.
            conditioning = compel(prompt, negative_prompt=negative_prompt)
            common_kwargs = dict(
                prompt_embeds=conditioning.embeds,
                pooled_prompt_embeds=conditioning.pooled_embeds,
                negative_prompt_embeds=conditioning.negative_embeds,
                negative_pooled_prompt_embeds=conditioning.negative_pooled_embeds,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            )
            with autocast_ctx:
                image = pipeline(**mode_kwargs, **common_kwargs).images[0]
        return image, status
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        msg = f"Error during generation: {type(e).__name__}: {e}"
        return _make_error_image(width, height, msg), msg
    finally:
        # Release Python garbage and cached CUDA memory after every request.
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()
# ============================================================
# UI
# ============================================================
CSS = """
body{
background:#000;
color:#fff;
}
"""

# NOTE(review): the nesting below was reconstructed from a listing without
# indentation (Accordion assumed inside the Column, the two banner checks
# assumed independent) - confirm against the intended layout.
with gr.Blocks(title="Text to Image / Image to Image") as demo:
    gr.HTML(f"<style>{CSS}</style>")

    with gr.Column():
        # banner first
        if fallback_msg:
            gr.Markdown(f"**{fallback_msg}**")
        if not model_loaded:
            gr.Markdown(
                f"⚠️ **Model failed to load.**\n\nDetails: {load_error}",
                elem_classes=["small-note"],
            )

        gr.Markdown("## SDXL Generator (txt2img + img2img)")

        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt...", lines=2)
        # Optional initial image; supplying one switches infer() to img2img.
        init_image = gr.Image(label="Initial image (optional)", type="pil")
        run_button = gr.Button("Generate")
        result = gr.Image(label="Result")
        status = gr.Markdown("")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(label="Negative prompt", value="")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Width and height share the same range, step and default.
            _size_slider = dict(minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
            width = gr.Slider(label="Width", **_size_slider)
            height = gr.Slider(label="Height", **_size_slider)

            guidance_scale = gr.Slider(
                label="Guidance scale", minimum=0, maximum=20, step=0.1, value=7
            )
            num_inference_steps = gr.Slider(
                label="Steps", minimum=1, maximum=40, step=1, value=20
            )
            # Strength for img2img
            strength = gr.Slider(
                label="Image strength (for img2img)",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.7,
            )

    # Order must match infer()'s positional parameters.
    _infer_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        init_image,
        strength,
    ]
    run_button.click(fn=infer, inputs=_infer_inputs, outputs=[result, status])

demo.queue().launch(ssr_mode=False)