"""Gradio demo: Segment Anything mask generation + Stable Diffusion inpainting.

The uploaded image is segmented with SAM (ViT-H); the largest segment becomes
the inpainting mask, and that region is repainted by a Stable Diffusion
inpainting pipeline according to the user's prompt.
"""

import torch
import numpy as np
import gradio as gr
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from tqdm import tqdm
from diffusers import StableDiffusionInpaintPipeline, DDIMScheduler
from PIL import Image
from huggingface_hub import hf_hub_download, login
import os
import time

device = "cuda" if torch.cuda.is_available() else "cpu"

# Working resolution (square) for the diffusion pass; must be a multiple of 8.
# Smaller than SD 1.x's native 512 to make inference faster.
INPAINT_SIZE = 384
# Fewer denoising steps trades some quality for speed.
NUM_INFERENCE_STEPS = 15
# Overridable via env var so the app can point at a mirror of the weights.
INPAINT_MODEL_ID = os.getenv("INPAINT_MODEL_ID", "runwayml/stable-diffusion-inpainting")

hf_token = os.getenv("HF_TOKEN")
# login(token=None) falls back to an interactive token prompt, which would
# block a headless server; only log in when a token is actually configured.
if hf_token:
    login(token=hf_token)

model_path = hf_hub_download(
    repo_id="Vuvo11/segment_anything_model",
    filename="sam_vit_h_4b8939.pth",
    token=hf_token,  # `use_auth_token` is deprecated in huggingface_hub
)
sam = sam_model_registry["vit_h"](checkpoint=model_path)
# The registry builds the model on CPU; move it so SAM also runs on the GPU.
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

scheduler = DDIMScheduler.from_pretrained(
    INPAINT_MODEL_ID, subfolder="scheduler", cache_dir="./models"
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    INPAINT_MODEL_ID,
    scheduler=scheduler,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # FP16 on GPU
    cache_dir="./models",
    low_cpu_mem_usage=True,
).to(device)

if torch.cuda.is_available():
    pipe.unet = torch.compile(pipe.unet)  # Speed optimization when running on GPU

pipe.enable_attention_slicing()


def generate_mask(image):
    """Return a uint8 mask (0/255) of the largest SAM segment in ``image``.

    Args:
        image: numpy image array as delivered by ``gr.Image(type="numpy")``.

    Returns:
        A 2-D uint8 array matching the image's height and width, where 255
        marks pixels to inpaint. All zeros when SAM finds no segment.
    """
    masks = mask_generator.generate(image)
    if not masks:
        return np.zeros(image.shape[:2], dtype=np.uint8)
    # SAM already reports each mask's pixel count as "area"; no need to re-sum.
    largest_mask = max(masks, key=lambda m: m["area"])
    return largest_mask["segmentation"].astype(np.uint8) * 255


def inpaint(image, prompt, progress=gr.Progress()):
    """Repaint the largest SAM segment of ``image`` according to ``prompt``.

    Args:
        image: numpy image from the upload widget, or None if nothing was uploaded.
        prompt: text description passed to the diffusion pipeline.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The inpainted image as a numpy array, resized back to the uploaded
        image's original width and height.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # gr.Progress takes a completion fraction in [0, 1], not a percentage.
    progress(0, "Processing image...")
    mask = generate_mask(image)

    progress(0.3, "Generating inpainting...")
    original_image = Image.fromarray(image).convert("RGB")
    original_size = original_image.size
    work_size = (INPAINT_SIZE, INPAINT_SIZE)

    # Downscale for faster processing.
    init_image = original_image.resize(work_size)
    # NEAREST keeps the mask strictly binary (no interpolated grey edges).
    mask_image = Image.fromarray(mask).convert("L").resize(work_size, Image.NEAREST)

    output = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        # Without explicit height/width the pipeline uses the UNet's native
        # size (512 for SD 1.x) and rescales the inputs, undoing the speed-up.
        height=INPAINT_SIZE,
        width=INPAINT_SIZE,
        num_inference_steps=NUM_INFERENCE_STEPS,  # Fewer inference steps
    ).images[0]

    progress(1.0, "Completed!")
    # Restore the uploaded resolution so non-square images are not distorted.
    return np.array(output.resize(original_size))


with gr.Blocks() as interface:
    gr.Markdown("## 🎨 AI Furniture Inpainting (Optimized)")
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Image")
        prompt_input = gr.Textbox(label="Prompt (Describe what to add)")
    submit = gr.Button("Submit")
    output_image = gr.Image(label="Generated Image")
    submit.click(fn=inpaint, inputs=[image_input, prompt_input], outputs=output_image)

if __name__ == "__main__":
    interface.launch()