# ai-image-editor / app.py
# Uploaded by bep40 ("Upload app.py", commit 017ef62, verified)
import os
import math
import torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from functools import lru_cache
# ============================================================
# 🎨 AI Image Editor - Powered by HuggingFace Models
# Features:
# 1. ✏️ Instruction-based Editing (InstructPix2Pix)
# 2. 🖌️ Inpainting (SDXL Inpainting)
# 3. ✂️ Background Removal (BiRefNet)
# 4. 🔍 Image Upscaling (Swin2SR)
# 5. 🎚️ Basic Adjustments (Pillow, no GPU needed)
# ============================================================
# --- Global model holders (lazy loaded) ---
# Each cache starts as None and is filled the first time its get_*() loader runs.
# NOTE(review): the loaders are called from inside @spaces.GPU handlers; on ZeroGPU hardware,
# confirm these globals really persist between requests (the Spaces docs load models at import time).
_edit_pipe = _inpaint_pipe = None  # InstructPix2Pix / SDXL inpainting pipelines
_rmbg_model = _rmbg_transform = None  # BiRefNet model + its preprocessing transform
_upscale_processor = _upscale_model = None  # Swin2SR image processor + model
def get_edit_pipe():
    """Return the shared InstructPix2Pix pipeline, building it on the first call.

    The pipeline is loaded from "timbrooks/instruct-pix2pix" in float16 with the safety
    checker disabled and moved to CUDA. Its scheduler is then replaced by an
    EulerAncestralDiscreteScheduler built from the existing scheduler config. Later calls
    return the cached module-level instance.
    """
    global _edit_pipe
    if _edit_pipe is not None:
        return _edit_pipe

    # Imported here so diffusers is only loaded once this pipeline is actually needed.
    from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionInstructPix2PixPipeline

    _edit_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix",
        safety_checker=None,
        torch_dtype=torch.float16,
    ).to("cuda")
    default_config = _edit_pipe.scheduler.config
    _edit_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(default_config)
    return _edit_pipe
def get_inpaint_pipe():
    """Return the shared SDXL inpainting pipeline, building it on the first call.

    Loads "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" through diffusers'
    AutoPipelineForInpainting. It uses the fp16 weight variant and float16 dtype and is
    moved to CUDA. Later calls return the cached module-level instance.
    """
    global _inpaint_pipe
    if _inpaint_pipe is not None:
        return _inpaint_pipe

    # Imported here so diffusers is only loaded once this pipeline is actually needed.
    from diffusers import AutoPipelineForInpainting

    _inpaint_pipe = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        variant="fp16",
        torch_dtype=torch.float16,
    ).to("cuda")
    return _inpaint_pipe
def get_rmbg_model():
    """Return ``(model, transform)`` for background removal, loading both on first use.

    The model is "ZhengPeng7/BiRefNet" (AutoModelForImageSegmentation), moved to CUDA and
    put in eval mode. It is loaded with ``trust_remote_code=True``, which runs code from
    that Hub repository. The transform resizes to 1024×1024, converts to a tensor and
    normalises with the ImageNet mean/std.
    """
    global _rmbg_model, _rmbg_transform
    if _rmbg_model is None:
        from torchvision import transforms
        from transformers import AutoModelForImageSegmentation

        segmenter = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        _rmbg_model = segmenter.to("cuda").eval()

        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        _rmbg_transform = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize(imagenet_mean, imagenet_std),
            ]
        )
    return _rmbg_model, _rmbg_transform
def get_upscale_model():
    """Return ``(processor, model)`` for Swin2SR super-resolution, loading them on first use.

    Both come from the "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr" checkpoint. The
    model is moved to CUDA and set to eval mode. Whether a load is needed is decided by
    the model cache alone.
    """
    global _upscale_processor, _upscale_model
    if _upscale_model is None:
        from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

        checkpoint = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
        _upscale_processor = AutoImageProcessor.from_pretrained(checkpoint)
        sr_model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint)
        _upscale_model = sr_model.to("cuda").eval()
    return _upscale_processor, _upscale_model
# ============================================================
# Feature 1: Instruction-based Editing
# ============================================================
def _pix2pix_working_size(width, height):
    """Return the ``(w, h)`` at which InstructPix2Pix runs for a ``width`` × ``height`` input.

    The scale factor starts at 512 / longer side. It is then nudged up so the shorter side
    lands on a multiple of 64. Finally both scaled sides are floored to multiples of 64.
    """
    shorter = min(width, height)
    scale = 512 / max(width, height)
    scale = math.ceil(shorter * scale / 64) * 64 / shorter
    return int((width * scale) // 64) * 64, int((height * scale) // 64) * 64


@spaces.GPU(duration=120)
def instruct_edit(input_image, instruction, text_cfg, image_cfg, steps, seed):
    """Apply a natural-language edit instruction to an image with InstructPix2Pix.

    Args:
        input_image: PIL image from the upload widget (``None`` when nothing was uploaded).
        instruction: English edit instruction, e.g. "make it snowy".
        text_cfg: text guidance scale (how strongly to follow the instruction).
        image_cfg: image guidance scale (how closely to stay to the input image).
        steps: number of inference steps (cast to int).
        seed: seed for the CUDA random generator (cast to int).

    Returns:
        The edited PIL image, resized back to the input's original size.

    Raises:
        gr.Error: if no image was uploaded or the instruction is blank.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh trước!")
    if not instruction or not instruction.strip():
        raise gr.Error("⚠️ Vui lòng nhập lệnh chỉnh sửa!")

    pipe = get_edit_pipe()
    work_size = _pix2pix_working_size(*input_image.size)
    # ImageOps.fit resizes and centre-crops to exactly the working size (a multiple of 64).
    model_input = ImageOps.fit(input_image, work_size, method=Image.Resampling.LANCZOS)

    rng = torch.Generator("cuda").manual_seed(int(seed))
    edited = pipe(
        instruction,
        image=model_input,
        guidance_scale=text_cfg,
        image_guidance_scale=image_cfg,
        num_inference_steps=int(steps),
        generator=rng,
    ).images[0]
    return edited.resize(input_image.size, Image.Resampling.LANCZOS)
# ============================================================
# Feature 2: Inpainting
# ============================================================
def _merge_layer_masks(layers, size):
    """Combine ImageEditor layers into one "L" mask of ``size`` by pixel-wise maximum.

    For an RGBA layer the alpha channel is the mask, because brushed pixels are opaque on a
    transparent layer. Any other mode is converted to greyscale instead. A layer whose size
    differs from ``size`` is resized with nearest-neighbour so mask edges stay hard.
    """
    width, height = size
    merged = np.zeros((height, width), dtype=np.uint8)
    for layer in layers:
        layer_mask = layer.getchannel("A") if layer.mode == "RGBA" else layer.convert("L")
        if layer_mask.size != size:
            layer_mask = layer_mask.resize(size, Image.Resampling.NEAREST)
        np.maximum(merged, np.asarray(layer_mask), out=merged)
    return Image.fromarray(merged)


@spaces.GPU(duration=120)
def inpaint(input_dict, prompt, negative_prompt, guidance_scale, steps, strength, seed):
    """Regenerate the brushed region of an image with SDXL inpainting.

    Args:
        input_dict: value of a ``gr.ImageEditor(type="pil")``. It is a dict with a
            "background" image and a list of drawn "layers".
        prompt: English description of the content to paint into the masked area.
        negative_prompt: optional negative prompt; empty means none.
        guidance_scale: classifier-free guidance scale.
        steps: number of inference steps (cast to int).
        strength: how much of the masked area is re-noised (1.0 = fully replaced).
        seed: seed for the CUDA random generator (cast to int).

    Returns:
        ``(original, result)``, both at the original resolution, for a before/after slider.

    Raises:
        gr.Error: if there is no image, no brushed area, or the prompt is blank.
    """
    # The editor can hand back a dict whose background is None (nothing uploaded yet).
    if input_dict is None or input_dict.get("background") is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh và vẽ mask!")
    init_image = input_dict["background"].convert("RGB")

    # In a layered editor, strokes may live on any layer, so all of them form the mask
    # (previously only the first layer was read and other strokes were silently ignored).
    layers = [layer for layer in (input_dict.get("layers") or []) if layer is not None]
    if not layers:
        raise gr.Error("⚠️ Vui lòng vẽ vùng cần chỉnh sửa trên ảnh!")
    mask = _merge_layer_masks(layers, init_image.size)

    # Check if mask has any painted area
    if np.asarray(mask).max() == 0:
        raise gr.Error("⚠️ Vui lòng vẽ vùng cần chỉnh sửa (brush) trên ảnh!")
    if not prompt or not prompt.strip():
        raise gr.Error("⚠️ Vui lòng nhập mô tả nội dung muốn tạo!")

    pipe = get_inpaint_pipe()
    # SDXL runs at 1024×1024; the output is resized back to the original size below.
    init_resized = init_image.resize((1024, 1024), Image.Resampling.LANCZOS)
    mask_resized = mask.resize((1024, 1024), Image.Resampling.LANCZOS)
    generator = torch.Generator("cuda").manual_seed(int(seed))
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else None,
        image=init_resized,
        mask_image=mask_resized,
        guidance_scale=guidance_scale,
        num_inference_steps=int(steps),
        strength=strength,
        generator=generator,
    ).images[0]
    output = output.resize(init_image.size, Image.Resampling.LANCZOS)
    return (init_image, output)
# ============================================================
# Feature 3: Background Removal
# ============================================================
@spaces.GPU(duration=60)
def remove_background(input_image):
    """Cut out the foreground of an image with BiRefNet and return it with transparency.

    Args:
        input_image: PIL image from the upload widget (``None`` when nothing was uploaded).

    Returns:
        An RGBA copy of the image at its original size. The alpha channel is the predicted
        foreground mask (0 = background, 255 = foreground).

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh!")
    model, transform = get_rmbg_model()

    # The transform normalises with 3-channel mean/std, so L/RGBA/P inputs must become RGB first.
    rgb_image = input_image.convert("RGB")
    input_tensor = transform(rgb_image).unsqueeze(0).to("cuda")

    with torch.no_grad():
        # The last model output is used as the mask logits; sigmoid maps them into [0, 1].
        preds = model(input_tensor)[-1].sigmoid().cpu()

    mask_array = (preds[0].squeeze().numpy() * 255).astype(np.uint8)
    mask = Image.fromarray(mask_array).resize(input_image.size, Image.Resampling.LANCZOS)

    # convert() already returns a new image, so no extra copy is needed before putalpha().
    result = rgb_image.convert("RGBA")
    result.putalpha(mask)
    return result
# ============================================================
# Feature 4: Image Upscaling
# ============================================================
@spaces.GPU(duration=120)
def upscale_image(input_image, scale_factor):
    """Upscale an image with Swin2SR, optionally reducing the result to 2x the original.

    To bound GPU memory, images whose longer side exceeds 512 px are first shrunk to 512 px.
    The model's 4x output is therefore relative to that shrunk image, and the info text says
    so when it happens. "2x" always targets exactly twice the original upload size.

    Args:
        input_image: PIL image from the upload widget (``None`` when nothing was uploaded).
        scale_factor: "2x" or "4x" from the radio button; only "2x" is special-cased.

    Returns:
        ``(result_image, info_text)``, where ``info_text`` reports original → result size.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh!")
    processor, model = get_upscale_model()

    # Limit input size to prevent OOM
    max_dim = 512
    w, h = input_image.size
    if max(w, h) > max_dim:
        ratio = max_dim / max(w, h)
        # max(1, ...) keeps extremely thin images from collapsing to a zero-pixel side.
        reduced_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
        model_input = input_image.resize(reduced_size, Image.Resampling.LANCZOS)
    else:
        model_input = input_image

    inputs = processor(model_input, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)
    output = outputs.reconstruction[0].float().cpu().clamp(0, 1).numpy()
    output = np.moveaxis(output, 0, -1)  # CHW -> HWC
    output = (output * 255.0).round().astype(np.uint8)
    result = Image.fromarray(output)

    # Swin2SRImageProcessor pads its input on the bottom/right before inference, and the model
    # upscales that padding too. Without this crop the result has extra strips along those
    # edges and the wrong size, so crop back to exactly `upscale` × the real input size.
    upscale = model.config.upscale
    result = result.crop((
        0,
        0,
        min(result.width, model_input.width * upscale),
        min(result.height, model_input.height * upscale),
    ))

    # If user wants 2x instead of 4x, resize to twice the original upload size
    if scale_factor == "2x":
        result = result.resize((w * 2, h * 2), Image.Resampling.LANCZOS)

    orig_size = f"{w}×{h}"
    new_size = f"{result.size[0]}×{result.size[1]}"
    info = f"📐 Gốc: {orig_size} → Kết quả: {new_size}"
    if model_input is not input_image:
        # Explain why a "4x" result is not 4x the uploaded size.
        info += f" (ảnh đã được thu nhỏ còn {model_input.width}×{model_input.height} trước khi phóng to)"
    return result, info
# ============================================================
# Feature 5: Basic adjustments (no GPU needed)
# ============================================================
def basic_adjust(input_image, brightness, contrast, saturation, sharpness, blur_radius):
    """Apply Pillow enhancement factors and an optional Gaussian blur to an image.

    A factor of exactly 1.0 skips that enhancer, and a blur radius of 0 or less skips the
    blur. Adjustments run in this order: brightness, contrast, saturation
    (``ImageEnhance.Color``), sharpness, then blur. No GPU is involved.

    Returns:
        A new image; ``input_image`` itself is never modified.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("⚠️ Vui lòng upload ảnh!")

    adjusted = input_image.copy()
    # (ImageEnhance class name, factor) in application order. The enhancer class is only
    # looked up when its factor would actually change the image.
    enhancement_plan = (
        ("Brightness", brightness),
        ("Contrast", contrast),
        ("Color", saturation),
        ("Sharpness", sharpness),
    )
    for enhancer_name, factor in enhancement_plan:
        if factor == 1.0:
            continue
        enhancer_cls = getattr(ImageEnhance, enhancer_name)
        adjusted = enhancer_cls(adjusted).enhance(factor)

    if blur_radius > 0:
        adjusted = adjusted.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return adjusted
# ============================================================
# Custom CSS
# ============================================================
# Stylesheet passed to gr.Blocks(css=...): centres the app at a 1200px max width and styles
# the `.main-header` / `.sub-header` / `<footer>` markup emitted by the gr.HTML blocks.
# NOTE(review): `.feature-icon` does not appear in any markup in this file as far as visible —
# confirm it is unused before removing. The bare `footer` selector presumably also matches
# Gradio's own page footer element — verify that is intended.
css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
.main-header {
text-align: center;
padding: 20px 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 2.5em;
font-weight: bold;
}
.sub-header {
text-align: center;
color: #666;
margin-bottom: 20px;
}
.feature-icon {
font-size: 1.3em;
}
footer {
text-align: center;
padding: 20px;
color: #999;
}
"""
# ============================================================
# Gradio UI
# ============================================================
# Gradio UI: one tab per feature. Each tab's button .click() wires that tab's input
# components to the matching handler defined above. Components are laid out in the order
# they are created inside each `with` context.
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="🎨 AI Image Editor") as demo:
    # Page header; the classes are styled by the `css` stylesheet.
    gr.HTML("""
<div class="main-header">🎨 AI Image Editor</div>
<div class="sub-header">
Chỉnh sửa ảnh thông minh với AI | Powered by HuggingFace 🤗
</div>
""")
    # === Tab 1: Instruction Edit ===
    with gr.Tab("✏️ Chỉnh sửa bằng lệnh"):
        gr.Markdown("""
### Mô tả thay đổi bạn muốn bằng tiếng Anh
Ví dụ: *"make it snowy"*, *"turn the sky to sunset"*, *"add sunglasses"*, *"make it a watercolor painting"*
""")
        with gr.Row():
            with gr.Column(scale=1):
                edit_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                edit_instruction = gr.Textbox(
                    label="✏️ Lệnh chỉnh sửa (tiếng Anh)",
                    placeholder="e.g. make it look like a painting...",
                    lines=2,
                )
                with gr.Accordion("⚙️ Cài đặt nâng cao", open=False):
                    edit_text_cfg = gr.Slider(
                        label="Text Guidance Scale",
                        minimum=1.0, maximum=15.0, value=7.5, step=0.5,
                        info="Mức độ tuân theo lệnh (cao = thay đổi nhiều hơn)"
                    )
                    edit_image_cfg = gr.Slider(
                        label="Image Guidance Scale",
                        minimum=0.5, maximum=3.0, value=1.5, step=0.1,
                        info="Mức độ giữ lại ảnh gốc (cao = giữ nhiều hơn)"
                    )
                    edit_steps = gr.Slider(
                        label="Số bước", minimum=10, maximum=100, value=30, step=5,
                    )
                    # instruct_edit() casts this with int(); an emptied field would arrive as None.
                    edit_seed = gr.Number(label="Seed", value=42, precision=0)
                edit_btn = gr.Button("🚀 Chỉnh sửa", variant="primary", size="lg")
            with gr.Column(scale=1):
                edit_output = gr.Image(label="🖼️ Kết quả", height=400)
        # Input order must match instruct_edit's positional parameters.
        edit_btn.click(
            fn=instruct_edit,
            inputs=[edit_input, edit_instruction, edit_text_cfg, edit_image_cfg, edit_steps, edit_seed],
            outputs=edit_output,
        )
        # No fn is given, so choosing an example only fills the image and instruction inputs.
        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/timbrooks/instruct-pix2pix/main/imgs/example.jpg", "turn him into a cyborg"],
                ["https://raw.githubusercontent.com/timbrooks/instruct-pix2pix/main/imgs/example.jpg", "make it a watercolor painting"],
            ],
            inputs=[edit_input, edit_instruction],
            label="💡 Ví dụ",
        )
    # === Tab 2: Inpainting ===
    with gr.Tab("🖌️ Inpainting (Tô vẽ lại)"):
        gr.Markdown("""
### Vẽ lên vùng cần thay đổi, sau đó mô tả nội dung mới
Dùng **brush** để tô lên vùng muốn chỉnh sửa → nhập mô tả → nhấn Inpaint
""")
        with gr.Row():
            with gr.Column(scale=1):
                # Fixed white brush on separate layers; inpaint() builds its mask from the
                # drawn layer(s) of this editor's value.
                inpaint_editor = gr.ImageEditor(
                    type="pil",
                    label="🖌️ Vẽ mask lên ảnh",
                    height=450,
                    brush=gr.Brush(
                        colors=["#FFFFFF"],
                        default_color="#FFFFFF",
                        color_mode="fixed",
                        default_size=30,
                    ),
                    eraser=gr.Eraser(default_size=30),
                    layers=True,
                )
                inpaint_prompt = gr.Textbox(
                    label="📝 Mô tả nội dung mới (tiếng Anh)",
                    placeholder="e.g. a cute cat sitting here...",
                    lines=2,
                )
                inpaint_neg = gr.Textbox(
                    label="🚫 Negative prompt (tùy chọn)",
                    placeholder="e.g. blurry, low quality, distorted...",
                    lines=1,
                )
                with gr.Accordion("⚙️ Cài đặt nâng cao", open=False):
                    inpaint_cfg = gr.Slider(
                        label="Guidance Scale", minimum=1.0, maximum=20.0, value=8.0, step=0.5,
                    )
                    inpaint_steps = gr.Slider(
                        label="Số bước", minimum=10, maximum=50, value=25, step=5,
                    )
                    inpaint_strength = gr.Slider(
                        label="Strength", minimum=0.5, maximum=1.0, value=0.99, step=0.01,
                        info="Mức độ thay đổi (1.0 = thay đổi hoàn toàn)",
                    )
                    inpaint_seed = gr.Number(label="Seed", value=42, precision=0)
                inpaint_btn = gr.Button("🎨 Inpaint", variant="primary", size="lg")
            with gr.Column(scale=1):
                # inpaint() returns an (original, result) tuple for this before/after slider.
                inpaint_output = gr.ImageSlider(
                    label="📊 So sánh Trước / Sau",
                    height=450,
                )
        inpaint_btn.click(
            fn=inpaint,
            inputs=[inpaint_editor, inpaint_prompt, inpaint_neg, inpaint_cfg, inpaint_steps, inpaint_strength, inpaint_seed],
            outputs=inpaint_output,
        )
    # === Tab 3: Background Removal ===
    with gr.Tab("✂️ Xóa nền"):
        gr.Markdown("""
### Xóa nền ảnh tự động bằng AI
Upload ảnh → nhận ảnh nền trong suốt (PNG)
""")
        with gr.Row():
            with gr.Column(scale=1):
                bg_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                bg_btn = gr.Button("✂️ Xóa nền", variant="primary", size="lg")
            with gr.Column(scale=1):
                # remove_background() returns an RGBA image whose alpha is the predicted mask.
                bg_output = gr.Image(label="🖼️ Kết quả (nền trong suốt)", height=400)
        bg_btn.click(fn=remove_background, inputs=bg_input, outputs=bg_output)
    # === Tab 4: Image Upscaling ===
    with gr.Tab("🔍 Phóng to ảnh"):
        gr.Markdown("""
### Phóng to ảnh chất lượng cao với AI
Tăng độ phân giải ảnh lên 2x hoặc 4x mà không bị mờ
""")
        with gr.Row():
            with gr.Column(scale=1):
                upscale_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                # upscale_image() special-cases the exact string "2x"; any other value keeps
                # the model's x4 output.
                upscale_factor = gr.Radio(
                    choices=["2x", "4x"], value="4x", label="📏 Mức phóng to",
                )
                upscale_btn = gr.Button("🔍 Phóng to", variant="primary", size="lg")
            with gr.Column(scale=1):
                upscale_output = gr.Image(label="🖼️ Kết quả", height=400)
                upscale_info = gr.Textbox(label="📐 Thông tin", interactive=False)
        # upscale_image() returns (image, info_text), matching these two outputs.
        upscale_btn.click(
            fn=upscale_image,
            inputs=[upscale_input, upscale_factor],
            outputs=[upscale_output, upscale_info],
        )
    # === Tab 5: Basic Adjustments ===
    with gr.Tab("🎚️ Chỉnh sửa cơ bản"):
        gr.Markdown("""
### Điều chỉnh các thông số cơ bản
Độ sáng, tương phản, bão hòa, sắc nét, làm mờ
""")
        with gr.Row():
            with gr.Column(scale=1):
                adj_input = gr.Image(type="pil", label="📷 Ảnh gốc", height=400)
                # Neutral defaults (factors 1.0, blur 0) make basic_adjust() skip every step.
                adj_brightness = gr.Slider(label="☀️ Độ sáng", minimum=0.1, maximum=3.0, value=1.0, step=0.05)
                adj_contrast = gr.Slider(label="🔲 Tương phản", minimum=0.1, maximum=3.0, value=1.0, step=0.05)
                adj_saturation = gr.Slider(label="🎨 Bão hòa", minimum=0.0, maximum=3.0, value=1.0, step=0.05)
                adj_sharpness = gr.Slider(label="🔪 Sắc nét", minimum=0.0, maximum=3.0, value=1.0, step=0.05)
                adj_blur = gr.Slider(label="💨 Làm mờ", minimum=0, maximum=10, value=0, step=0.5)
                adj_btn = gr.Button("✨ Áp dụng", variant="primary", size="lg")
            with gr.Column(scale=1):
                adj_output = gr.Image(label="🖼️ Kết quả", height=400)
        adj_btn.click(
            fn=basic_adjust,
            inputs=[adj_input, adj_brightness, adj_contrast, adj_saturation, adj_sharpness, adj_blur],
            outputs=adj_output,
        )
    # === Footer ===
    # Credits linking the model cards of the checkpoints loaded by the get_*() helpers.
    gr.HTML("""
<footer>
<hr>
<p>
🤗 <strong>AI Image Editor</strong> |
Models: <a href="https://huggingface.co/timbrooks/instruct-pix2pix" target="_blank">InstructPix2Pix</a> ·
<a href="https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1" target="_blank">SDXL Inpainting</a> ·
<a href="https://huggingface.co/ZhengPeng7/BiRefNet" target="_blank">BiRefNet</a> ·
<a href="https://huggingface.co/caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr" target="_blank">Swin2SR</a>
</p>
<p>Powered by 🤗 HuggingFace Diffusers, Transformers & Gradio</p>
</footer>
""")
if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting jobs). api_open=False stops direct REST
    # calls from bypassing it. queue() returns the same Blocks object, which is then launched.
    app = demo.queue(max_size=20, api_open=False)
    app.launch()