import os import math import torch import spaces import gradio as gr import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance from functools import lru_cache # ============================================================ # 🎨 AI Image Editor - Powered by HuggingFace Models # Features: # 1. ✏️ Instruction-based Editing (InstructPix2Pix) # 2. 🖌️ Inpainting (SDXL Inpainting) # 3. ✂️ Background Removal (RMBG-2.0 / BiRefNet) # 4. 🔍 Image Upscaling (Swin2SR) # ============================================================ # --- Global model holders (lazy loaded) --- _edit_pipe = None _inpaint_pipe = None _rmbg_model = None _rmbg_transform = None _upscale_processor = None _upscale_model = None def get_edit_pipe(): """Lazy load InstructPix2Pix pipeline""" global _edit_pipe if _edit_pipe is None: from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler _edit_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None, ).to("cuda") _edit_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( _edit_pipe.scheduler.config ) return _edit_pipe def get_inpaint_pipe(): """Lazy load SDXL Inpainting pipeline""" global _inpaint_pipe if _inpaint_pipe is None: from diffusers import AutoPipelineForInpainting _inpaint_pipe = AutoPipelineForInpainting.from_pretrained( "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16", ).to("cuda") return _inpaint_pipe def get_rmbg_model(): """Lazy load background removal model""" global _rmbg_model, _rmbg_transform if _rmbg_model is None: from torchvision import transforms from transformers import AutoModelForImageSegmentation _rmbg_model = AutoModelForImageSegmentation.from_pretrained( "ZhengPeng7/BiRefNet", trust_remote_code=True ).to("cuda").eval() _rmbg_transform = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) return _rmbg_model, _rmbg_transform def get_upscale_model(): """Lazy load Swin2SR upscaling model""" global _upscale_processor, _upscale_model if _upscale_model is None: from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution _upscale_processor = AutoImageProcessor.from_pretrained( "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr" ) _upscale_model = Swin2SRForImageSuperResolution.from_pretrained( "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr" ).to("cuda").eval() return _upscale_processor, _upscale_model # ============================================================ # Feature 1: Instruction-based Editing # ============================================================ @spaces.GPU(duration=120) def instruct_edit(input_image, instruction, text_cfg, image_cfg, steps, seed): if input_image is None: raise gr.Error("⚠️ Vui lòng upload ảnh trước!") if not instruction or instruction.strip() == "": raise gr.Error("⚠️ Vui lòng nhập lệnh chỉnh sửa!") pipe = get_edit_pipe() # Resize to be compatible with the model (multiples of 64) width, height = input_image.size factor = 512 / max(width, height) factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height) w = int((width * factor) // 64) * 64 h = int((height * factor) // 64) * 64 input_resized = ImageOps.fit(input_image, (w, h), method=Image.Resampling.LANCZOS) generator = torch.Generator("cuda").manual_seed(int(seed)) result = pipe( instruction, image=input_resized, guidance_scale=text_cfg, image_guidance_scale=image_cfg, num_inference_steps=int(steps), generator=generator, ).images[0] # Resize back to original result = result.resize(input_image.size, Image.Resampling.LANCZOS) return result # ============================================================ # Feature 2: Inpainting # ============================================================ @spaces.GPU(duration=120) def inpaint(input_dict, prompt, negative_prompt, guidance_scale, steps, strength, seed): if input_dict is None: raise gr.Error("⚠️ Vui lòng upload ảnh và vẽ mask!") # Extract image and mask from ImageEditor init_image = input_dict["background"].convert("RGB") # Get mask from the drawn layer if len(input_dict["layers"]) > 0: mask_layer = input_dict["layers"][0] # The alpha channel of the layer IS the mask if mask_layer.mode == "RGBA": mask = mask_layer.getchannel("A") else: mask = mask_layer.convert("L") else: raise gr.Error("⚠️ Vui lòng vẽ vùng cần chỉnh sửa trên ảnh!") # Check if mask has any painted area mask_array = np.array(mask) if mask_array.max() == 0: raise gr.Error("⚠️ Vui lòng vẽ vùng cần chỉnh sửa (brush) trên ảnh!") if not prompt or prompt.strip() == "": raise gr.Error("⚠️ Vui lòng nhập mô tả nội dung muốn tạo!") pipe = get_inpaint_pipe() # Resize to 1024x1024 for SDXL init_resized = init_image.resize((1024, 1024), Image.Resampling.LANCZOS) mask_resized = mask.resize((1024, 1024), Image.Resampling.LANCZOS) generator = torch.Generator("cuda").manual_seed(int(seed)) output = pipe( prompt=prompt, negative_prompt=negative_prompt if negative_prompt else None, image=init_resized, mask_image=mask_resized, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength, generator=generator, ).images[0] # Resize back output = output.resize(init_image.size, Image.Resampling.LANCZOS) return (init_image, output) # ============================================================ # Feature 3: Background Removal # ============================================================ @spaces.GPU(duration=60) def remove_background(input_image): if input_image is None: raise gr.Error("⚠️ Vui lòng upload ảnh!") model, transform = get_rmbg_model() # Prepare input input_tensor = transform(input_image).unsqueeze(0).to("cuda") # Inference with torch.no_grad(): preds = model(input_tensor)[-1].sigmoid().cpu() # Create mask pred_pil = Image.fromarray((preds[0].squeeze().numpy() * 255).astype(np.uint8)) mask = pred_pil.resize(input_image.size, Image.Resampling.LANCZOS) # Apply alpha channel result = input_image.copy().convert("RGBA") result.putalpha(mask) return result # ============================================================ # Feature 4: Image Upscaling # ============================================================ @spaces.GPU(duration=120) def upscale_image(input_image, scale_factor): if input_image is None: raise gr.Error("⚠️ Vui lòng upload ảnh!") processor, model = get_upscale_model() # Limit input size to prevent OOM max_dim = 512 w, h = input_image.size if max(w, h) > max_dim: ratio = max_dim / max(w, h) new_w = int(w * ratio) new_h = int(h * ratio) input_image_resized = input_image.resize((new_w, new_h), Image.Resampling.LANCZOS) else: input_image_resized = input_image inputs = processor(input_image_resized, return_tensors="pt").to("cuda") with torch.no_grad(): outputs = model(**inputs) output = outputs.reconstruction.data.squeeze().float().cpu().clamp(0, 1).numpy() output = np.moveaxis(output, 0, -1) # CHW -> HWC output = (output * 255.0).round().astype(np.uint8) result = Image.fromarray(output) # If user wants 2x instead of 4x, resize down if scale_factor == "2x": target_w = input_image.size[0] * 2 target_h = input_image.size[1] * 2 result = result.resize((target_w, target_h), Image.Resampling.LANCZOS) orig_size = f"{input_image.size[0]}×{input_image.size[1]}" new_size = f"{result.size[0]}×{result.size[1]}" info = f"📐 Gốc: {orig_size} → Kết quả: {new_size}" return result, info # ============================================================ # Feature 5: Basic adjustments (no GPU needed) # ============================================================ def basic_adjust(input_image, brightness, contrast, saturation, sharpness, blur_radius): if input_image is None: raise gr.Error("⚠️ Vui lòng upload ảnh!") img = input_image.copy() if brightness != 1.0: img = ImageEnhance.Brightness(img).enhance(brightness) if contrast != 1.0: img = ImageEnhance.Contrast(img).enhance(contrast) if saturation != 1.0: img = ImageEnhance.Color(img).enhance(saturation) if sharpness != 1.0: img = ImageEnhance.Sharpness(img).enhance(sharpness) if blur_radius > 0: img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) return img # ============================================================ # Custom CSS # ============================================================ css = """ .gradio-container { max-width: 1200px !important; margin: auto !important; } .main-header { text-align: center; padding: 20px 0; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em; font-weight: bold; } .sub-header { text-align: center; color: #666; margin-bottom: 20px; } .feature-icon { font-size: 1.3em; } footer { text-align: center; padding: 20px; color: #999; } """ # ============================================================ # Gradio UI # ============================================================ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="🎨 AI Image Editor") as demo: gr.HTML("""