import os
import math
import torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from functools import lru_cache

# ============================================================
# 🎨 AI Image Editor - Powered by HuggingFace Models
# Features:
# 1. ✏️ Instruction-based Editing (InstructPix2Pix)
# 2. 🖌️ Inpainting (SDXL Inpainting)
# 3. ✂️ Background Removal (BiRefNet)
# 4. 🔍 Image Upscaling (Swin2SR)
# 5. 🎚️ Basic adjustments (Pillow, no GPU needed)
# ============================================================

# --- Global model holders (lazy loaded on first use, then reused) ---
_edit_pipe = None
_inpaint_pipe = None
_rmbg_model = None
_rmbg_transform = None
_upscale_processor = None
_upscale_model = None

# Seed used when the user clears a Seed field (same as the UI default value).
_DEFAULT_SEED = 42


def get_edit_pipe():
    """Return the InstructPix2Pix pipeline, loading it onto CUDA on first call.

    Loads fp16 weights with the safety checker disabled and replaces the
    default scheduler with Euler-Ancestral.
    """
    global _edit_pipe
    if _edit_pipe is None:
        from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler

        _edit_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to("cuda")
        _edit_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            _edit_pipe.scheduler.config
        )
    return _edit_pipe


def get_inpaint_pipe():
    """Return the SDXL inpainting pipeline (fp16), loading it onto CUDA on first call."""
    global _inpaint_pipe
    if _inpaint_pipe is None:
        from diffusers import AutoPipelineForInpainting

        _inpaint_pipe = AutoPipelineForInpainting.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
    return _inpaint_pipe


def get_rmbg_model():
    """Return ``(model, transform)`` for background removal, loading on first call.

    The model is BiRefNet (remote code) in eval mode on CUDA. The transform
    resizes to 1024x1024, converts to a tensor and applies ImageNet
    mean/std normalization.
    """
    global _rmbg_model, _rmbg_transform
    if _rmbg_model is None:
        from torchvision import transforms
        from transformers import AutoModelForImageSegmentation

        _rmbg_model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        ).to("cuda").eval()
        _rmbg_transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    return _rmbg_model, _rmbg_transform


def get_upscale_model():
    """Return ``(processor, model)`` for Swin2SR x4 super-resolution, loading on first call."""
    global _upscale_processor, _upscale_model
    if _upscale_model is None:
        from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

        _upscale_processor = AutoImageProcessor.from_pretrained(
            "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
        )
        _upscale_model = Swin2SRForImageSuperResolution.from_pretrained(
            "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
        ).to("cuda").eval()
    return _upscale_processor, _upscale_model


def _make_generator(seed):
    """Return a CUDA ``torch.Generator`` seeded from a Gradio Number value.

    A cleared ``gr.Number`` field can arrive as ``None``; fall back to
    ``_DEFAULT_SEED`` instead of failing inside ``int(None)``.
    """
    if seed is None:
        seed = _DEFAULT_SEED
    return torch.Generator("cuda").manual_seed(int(seed))


def _mask_from_layers(layers, size):
    """Merge drawn ImageEditor layers into one 8-bit ("L") mask of ``size``.

    Each layer contributes its alpha channel (or its luminance when it has
    no alpha); layers are combined with a per-pixel maximum so strokes on
    any layer count. Returns ``None`` when no usable layer exists.
    """
    merged = None
    for layer in layers:
        if layer is None:
            continue
        # The alpha channel of a painted layer IS the mask.
        layer_mask = layer.getchannel("A") if layer.mode == "RGBA" else layer.convert("L")
        if layer_mask.size != size:
            layer_mask = layer_mask.resize(size, Image.Resampling.NEAREST)
        arr = np.asarray(layer_mask, dtype=np.uint8)
        merged = arr if merged is None else np.maximum(merged, arr)
    return None if merged is None else Image.fromarray(merged)


# ============================================================
# Feature 1: Instruction-based Editing
# ============================================================
@spaces.GPU(duration=120)
def instruct_edit(input_image, instruction, text_cfg, image_cfg, steps, seed):
    """Edit ``input_image`` following a free-text English ``instruction``.

    Args:
        input_image: PIL image from ``gr.Image(type="pil")``.
        instruction: edit instruction, e.g. "make it snowy".
        text_cfg: text guidance scale (higher = follow the instruction more).
        image_cfg: image guidance scale (higher = stay closer to the input).
        steps: number of inference steps.
        seed: RNG seed; ``None`` falls back to the default seed.

    Returns:
        The edited PIL image, resized back to the input's original size.

    Raises:
        gr.Error: when no image is given or the instruction is blank.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh trước!")
    if not instruction or not instruction.strip():
        raise gr.Error("⚠️ Vui lòng nhập lệnh chỉnh sửa!")

    pipe = get_edit_pipe()

    # Resize to be compatible with the model (multiples of 64): aim for a long
    # side of ~512 px, nudging the factor so the short side lands exactly on a
    # multiple of 64, then snap both sides down to multiples of 64.
    # NOTE(review): for very elongated images the long side can end up far
    # above 512 px (e.g. 10x1000 -> 64x6400); consider clamping.
    width, height = input_image.size
    factor = 512 / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    w = int((width * factor) // 64) * 64
    h = int((height * factor) // 64) * 64
    # ImageOps.fit crops away the small aspect mismatch introduced by snapping.
    input_resized = ImageOps.fit(input_image, (w, h), method=Image.Resampling.LANCZOS)

    result = pipe(
        instruction,
        image=input_resized,
        guidance_scale=text_cfg,
        image_guidance_scale=image_cfg,
        num_inference_steps=int(steps),
        generator=_make_generator(seed),
    ).images[0]

    # Resize back to the original resolution.
    return result.resize(input_image.size, Image.Resampling.LANCZOS)


# ============================================================
# Feature 2: Inpainting
# ============================================================
@spaces.GPU(duration=120)
def inpaint(input_dict, prompt, negative_prompt, guidance_scale, steps, strength, seed):
    """Regenerate the brushed region of an image from a text prompt (SDXL).

    Args:
        input_dict: ``gr.ImageEditor`` value with a ``"background"`` PIL image
            and a ``"layers"`` list of drawn PIL layers.
        prompt: description of the new content.
        negative_prompt: optional negative prompt; empty means none.
        guidance_scale, steps, strength: SDXL inpainting parameters.
        seed: RNG seed; ``None`` falls back to the default seed.

    Returns:
        ``(original, result)`` tuple for the before/after ``gr.ImageSlider``.

    Raises:
        gr.Error: when no image, no mask, an empty mask or an empty prompt is given.
    """
    # The editor may report a value whose background is None (e.g. nothing uploaded).
    if input_dict is None or input_dict.get("background") is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh và vẽ mask!")

    init_image = input_dict["background"].convert("RGB")

    # Combine every drawn layer so strokes on any layer are honoured.
    mask = _mask_from_layers(input_dict.get("layers") or [], init_image.size)
    if mask is None:
        raise gr.Error("⚠️ Vui lòng vẽ vùng cần chỉnh sửa trên ảnh!")

    # Check the mask actually has a painted area.
    if np.asarray(mask).max() == 0:
        raise gr.Error("⚠️ Vui lòng vẽ vùng cần chỉnh sửa (brush) trên ảnh!")
    if not prompt or not prompt.strip():
        raise gr.Error("⚠️ Vui lòng nhập mô tả nội dung muốn tạo!")

    pipe = get_inpaint_pipe()

    # SDXL works at 1024x1024; the result is resized back afterwards.
    init_resized = init_image.resize((1024, 1024), Image.Resampling.LANCZOS)
    mask_resized = mask.resize((1024, 1024), Image.Resampling.LANCZOS)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else None,
        image=init_resized,
        mask_image=mask_resized,
        guidance_scale=guidance_scale,
        num_inference_steps=int(steps),
        strength=strength,
        generator=_make_generator(seed),
    ).images[0]

    output = output.resize(init_image.size, Image.Resampling.LANCZOS)
    return (init_image, output)


# ============================================================
# Feature 3: Background Removal
# ============================================================
@spaces.GPU(duration=60)
def remove_background(input_image):
    """Return ``input_image`` as RGBA with the predicted background made transparent.

    Raises:
        gr.Error: when no image is given.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh!")

    model, transform = get_rmbg_model()

    # The transform normalizes exactly 3 channels, so always feed RGB.
    input_tensor = transform(input_image.convert("RGB")).unsqueeze(0).to("cuda")

    with torch.no_grad():
        # Take the model's last output (presumably its final, finest map).
        preds = model(input_tensor)[-1].sigmoid().cpu()

    # Sigmoid probabilities in [0, 1] -> 8-bit alpha mask at the original size.
    pred = preds[0].squeeze().numpy()
    mask = Image.fromarray((pred * 255).round().astype(np.uint8))
    mask = mask.resize(input_image.size, Image.Resampling.LANCZOS)

    result = input_image.copy().convert("RGBA")
    result.putalpha(mask)
    return result


# ============================================================
# Feature 4: Image Upscaling
# ============================================================
@spaces.GPU(duration=120)
def upscale_image(input_image, scale_factor):
    """Super-resolve ``input_image`` with Swin2SR.

    Inputs larger than 512 px on the long side are first downscaled to 512
    to limit GPU memory. For ``"2x"`` the model output is resized with
    LANCZOS to exactly twice the original input size.

    Returns:
        ``(result_image, info_text)`` where the info text reports both sizes.

    Raises:
        gr.Error: when no image is given.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh!")

    processor, model = get_upscale_model()

    # Limit input size to prevent OOM.
    max_dim = 512
    w, h = input_image.size
    if max(w, h) > max_dim:
        ratio = max_dim / max(w, h)
        # max(1, ...) keeps extreme aspect ratios from producing a 0-px side.
        lr_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
        lr_image = input_image.resize(lr_size, Image.Resampling.LANCZOS)
    else:
        lr_image = input_image
    lr_image = lr_image.convert("RGB")
    lr_w, lr_h = lr_image.size

    inputs = processor(lr_image, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)

    output = outputs.reconstruction.data.squeeze().float().cpu().clamp(0, 1).numpy()
    output = np.moveaxis(output, 0, -1)  # CHW -> HWC
    # Swin2SRImageProcessor pads its input before inference, so the
    # reconstruction is slightly larger than upscale x the real input.
    # Crop the padded border off to get the exact upscaled image.
    sr_scale = model.config.upscale
    output = output[: lr_h * sr_scale, : lr_w * sr_scale]
    output = (output * 255.0).round().astype(np.uint8)
    result = Image.fromarray(output)

    # If the user wants 2x instead of 4x, resize down.
    if scale_factor == "2x":
        target_w = input_image.size[0] * 2
        target_h = input_image.size[1] * 2
        result = result.resize((target_w, target_h), Image.Resampling.LANCZOS)

    orig_size = f"{input_image.size[0]}×{input_image.size[1]}"
    new_size = f"{result.size[0]}×{result.size[1]}"
    info = f"📐 Gốc: {orig_size} → Kết quả: {new_size}"
    return result, info


# ============================================================
# Feature 5: Basic adjustments (no GPU needed)
# ============================================================
def basic_adjust(input_image, brightness, contrast, saturation, sharpness, blur_radius):
    """Apply Pillow brightness/contrast/saturation/sharpness enhancers, then blur.

    An enhancer factor of 1.0 (or a blur radius of 0) leaves that step out.
    The input image is never modified; a copy is returned.

    Raises:
        gr.Error: when no image is given.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh!")

    img = input_image.copy()
    if brightness != 1.0:
        img = ImageEnhance.Brightness(img).enhance(brightness)
    if contrast != 1.0:
        img = ImageEnhance.Contrast(img).enhance(contrast)
    if saturation != 1.0:
        img = ImageEnhance.Color(img).enhance(saturation)
    if sharpness != 1.0:
        img = ImageEnhance.Sharpness(img).enhance(sharpness)
    if blur_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return img


# ============================================================
# Custom CSS
# ============================================================
css = """
.gradio-container {
    max-width: 1200px !important;
    margin: auto !important;
}
.main-header {
    text-align: center;
    padding: 20px 0;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.5em;
    font-weight: bold;
}
.sub-header {
    text-align: center;
    color: #666;
    margin-bottom: 20px;
}
.feature-icon {
    font-size: 1.3em;
}
footer {
    text-align: center;
    padding: 20px;
    color: #999;
}
"""

# ============================================================
# Gradio UI
# ============================================================
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="🎨 AI Image Editor") as demo:
    # Header: uses the .main-header / .sub-header classes defined in `css`.
    gr.HTML("""
<div class="main-header">🎨 AI Image Editor</div>
<div class="sub-header">Chỉnh sửa ảnh thông minh với AI | Powered by HuggingFace 🤗</div>
""")

    # === Tab 1: Instruction Edit ===
    with gr.Tab("✏️ Chỉnh sửa bằng lệnh"):
        gr.Markdown("""
        ### Mô tả thay đổi bạn muốn bằng tiếng Anh
        Ví dụ: *"make it snowy"*, *"turn the sky to sunset"*, *"add sunglasses"*, *"make it a watercolor painting"*
        """)
        with gr.Row():
            with gr.Column(scale=1):
                edit_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                edit_instruction = gr.Textbox(
                    label="✏️ Lệnh chỉnh sửa (tiếng Anh)",
                    placeholder="e.g. make it look like a painting...",
                    lines=2,
                )
                with gr.Accordion("⚙️ Cài đặt nâng cao", open=False):
                    edit_text_cfg = gr.Slider(
                        label="Text Guidance Scale",
                        minimum=1.0, maximum=15.0, value=7.5, step=0.5,
                        info="Mức độ tuân theo lệnh (cao = thay đổi nhiều hơn)",
                    )
                    edit_image_cfg = gr.Slider(
                        label="Image Guidance Scale",
                        minimum=0.5, maximum=3.0, value=1.5, step=0.1,
                        info="Mức độ giữ lại ảnh gốc (cao = giữ nhiều hơn)",
                    )
                    edit_steps = gr.Slider(
                        label="Số bước",
                        minimum=10, maximum=100, value=30, step=5,
                    )
                    edit_seed = gr.Number(label="Seed", value=42, precision=0)
                edit_btn = gr.Button("🚀 Chỉnh sửa", variant="primary", size="lg")
            with gr.Column(scale=1):
                edit_output = gr.Image(label="🖼️ Kết quả", height=400)

        edit_btn.click(
            fn=instruct_edit,
            inputs=[edit_input, edit_instruction, edit_text_cfg, edit_image_cfg, edit_steps, edit_seed],
            outputs=edit_output,
        )

        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/timbrooks/instruct-pix2pix/main/imgs/example.jpg", "turn him into a cyborg"],
                ["https://raw.githubusercontent.com/timbrooks/instruct-pix2pix/main/imgs/example.jpg", "make it a watercolor painting"],
            ],
            inputs=[edit_input, edit_instruction],
            label="💡 Ví dụ",
        )

    # === Tab 2: Inpainting ===
    with gr.Tab("🖌️ Inpainting (Tô vẽ lại)"):
        gr.Markdown("""
        ### Vẽ lên vùng cần thay đổi, sau đó mô tả nội dung mới
        Dùng **brush** để tô lên vùng muốn chỉnh sửa → nhập mô tả → nhấn Inpaint
        """)
        with gr.Row():
            with gr.Column(scale=1):
                inpaint_editor = gr.ImageEditor(
                    type="pil",
                    label="🖌️ Vẽ mask lên ảnh",
                    height=450,
                    brush=gr.Brush(
                        colors=["#FFFFFF"],
                        default_color="#FFFFFF",
                        color_mode="fixed",
                        default_size=30,
                    ),
                    eraser=gr.Eraser(default_size=30),
                    layers=True,
                )
                inpaint_prompt = gr.Textbox(
                    label="📝 Mô tả nội dung mới (tiếng Anh)",
                    placeholder="e.g. a cute cat sitting here...",
                    lines=2,
                )
                inpaint_neg = gr.Textbox(
                    label="🚫 Negative prompt (tùy chọn)",
                    placeholder="e.g. blurry, low quality, distorted...",
                    lines=1,
                )
                with gr.Accordion("⚙️ Cài đặt nâng cao", open=False):
                    inpaint_cfg = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0, maximum=20.0, value=8.0, step=0.5,
                    )
                    inpaint_steps = gr.Slider(
                        label="Số bước",
                        minimum=10, maximum=50, value=25, step=5,
                    )
                    inpaint_strength = gr.Slider(
                        label="Strength",
                        minimum=0.5, maximum=1.0, value=0.99, step=0.01,
                        info="Mức độ thay đổi (1.0 = thay đổi hoàn toàn)",
                    )
                    inpaint_seed = gr.Number(label="Seed", value=42, precision=0)
                inpaint_btn = gr.Button("🎨 Inpaint", variant="primary", size="lg")
            with gr.Column(scale=1):
                inpaint_output = gr.ImageSlider(
                    label="📊 So sánh Trước / Sau",
                    height=450,
                )

        inpaint_btn.click(
            fn=inpaint,
            inputs=[inpaint_editor, inpaint_prompt, inpaint_neg, inpaint_cfg,
                    inpaint_steps, inpaint_strength, inpaint_seed],
            outputs=inpaint_output,
        )

    # === Tab 3: Background Removal ===
    with gr.Tab("✂️ Xóa nền"):
        gr.Markdown("""
        ### Xóa nền ảnh tự động bằng AI
        Upload ảnh → nhận ảnh nền trong suốt (PNG)
        """)
        with gr.Row():
            with gr.Column(scale=1):
                bg_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                bg_btn = gr.Button("✂️ Xóa nền", variant="primary", size="lg")
            with gr.Column(scale=1):
                # format="png": the tab promises a transparent PNG download.
                bg_output = gr.Image(label="🖼️ Kết quả (nền trong suốt)", height=400, format="png")

        bg_btn.click(fn=remove_background, inputs=bg_input, outputs=bg_output)

    # === Tab 4: Image Upscaling ===
    with gr.Tab("🔍 Phóng to ảnh"):
        gr.Markdown("""
        ### Phóng to ảnh chất lượng cao với AI
        Tăng độ phân giải ảnh lên 2x hoặc 4x mà không bị mờ
        """)
        with gr.Row():
            with gr.Column(scale=1):
                upscale_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                upscale_factor = gr.Radio(
                    choices=["2x", "4x"],
                    value="4x",
                    label="📏 Mức phóng to",
                )
                upscale_btn = gr.Button("🔍 Phóng to", variant="primary", size="lg")
            with gr.Column(scale=1):
                upscale_output = gr.Image(label="🖼️ Kết quả", height=400)
                upscale_info = gr.Textbox(label="📐 Thông tin", interactive=False)

        upscale_btn.click(
            fn=upscale_image,
            inputs=[upscale_input, upscale_factor],
            outputs=[upscale_output, upscale_info],
        )

    # === Tab 5: Basic Adjustments ===
    with gr.Tab("🎚️ Chỉnh sửa cơ bản"):
        gr.Markdown("""
        ### Điều chỉnh các thông số cơ bản
        Độ sáng, tương phản, bão hòa, sắc nét, làm mờ
        """)
        with gr.Row():
            with gr.Column(scale=1):
                adj_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                adj_brightness = gr.Slider(label="☀️ Độ sáng", minimum=0.1, maximum=3.0, value=1.0, step=0.05)
                adj_contrast = gr.Slider(label="🔲 Tương phản", minimum=0.1, maximum=3.0, value=1.0, step=0.05)
                adj_saturation = gr.Slider(label="🎨 Bão hòa", minimum=0.0, maximum=3.0, value=1.0, step=0.05)
                adj_sharpness = gr.Slider(label="🔪 Sắc nét", minimum=0.0, maximum=3.0, value=1.0, step=0.05)
                adj_blur = gr.Slider(label="💨 Làm mờ", minimum=0, maximum=10, value=0, step=0.5)
                adj_btn = gr.Button("✨ Áp dụng", variant="primary", size="lg")
            with gr.Column(scale=1):
                adj_output = gr.Image(label="🖼️ Kết quả", height=400)

        adj_btn.click(
            fn=basic_adjust,
            inputs=[adj_input, adj_brightness, adj_contrast, adj_saturation, adj_sharpness, adj_blur],
            outputs=adj_output,
        )

    # === Footer ===
    gr.HTML(""" """)


if __name__ == "__main__":
    demo.queue(max_size=20, api_open=False).launch()