Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
| from tqdm import tqdm | |
| from diffusers import StableDiffusionInpaintPipeline, DDIMScheduler | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download, login | |
| import os | |
| import time | |
# --- Model setup (runs once at import time) ---
device = "cuda" if torch.cuda.is_available() else "cpu"

# Only authenticate when a token is actually configured; login(token=None)
# fails in a non-interactive environment such as a Space.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Download the SAM ViT-H checkpoint from the Hub.
# `token=` replaces the deprecated/removed `use_auth_token=` argument.
model_path = hf_hub_download(
    repo_id="Vuvo11/segment_anything_model",
    filename="sam_vit_h_4b8939.pth",
    token=hf_token,
)

sam = sam_model_registry["vit_h"](checkpoint=model_path)
sam.to(device)  # keep SAM on the same device as the diffusion pipeline
mask_generator = SamAutomaticMaskGenerator(sam)

scheduler = DDIMScheduler.from_pretrained(
    "runwayml/stable-diffusion-inpainting", subfolder="scheduler"
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    scheduler=scheduler,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # FP16 on GPU
    cache_dir="./models",
    low_cpu_mem_usage=True,
).to(device)

if torch.cuda.is_available():
    pipe.unet = torch.compile(pipe.unet)  # compile UNet for faster GPU inference
    # NOTE(review): original flat paste is ambiguous about whether this line was
    # inside the GPU guard; slicing mainly helps constrained-VRAM GPU runs.
    pipe.enable_attention_slicing()
def generate_mask(image):
    """Return a binary uint8 mask (0/255) of the largest SAM segment in *image*.

    Args:
        image: H x W x C numpy array, as delivered by ``gr.Image(type="numpy")``.

    Returns:
        H x W ``uint8`` array — 255 inside the largest segmented region,
        0 elsewhere; all zeros when SAM finds no segments.
    """
    masks = mask_generator.generate(image)
    if not masks:
        # Explicit uint8 zeros so both branches return the same dtype
        # (zeros_like would inherit the image dtype), and shape[:2] avoids
        # assuming a 3-channel image.
        return np.zeros(image.shape[:2], dtype=np.uint8)
    # Pick the segment covering the most pixels.
    largest_mask = max(masks, key=lambda m: np.sum(m["segmentation"]))
    return (largest_mask["segmentation"] * 255).astype(np.uint8)
def inpaint(image, prompt, progress=gr.Progress()):
    """Inpaint the largest SAM-detected region of *image* guided by *prompt*.

    Args:
        image: H x W x C numpy array from the Gradio image input.
        prompt: text describing what to paint into the masked region.
        progress: Gradio progress tracker (injected by Gradio at call time).

    Returns:
        384 x 384 RGB numpy array with the generated result.
    """
    # gr.Progress expects a fraction in [0, 1]; the original 30/100 values
    # were out of range.
    progress(0.0, desc="Processing image...")
    mask = generate_mask(image)
    progress(0.3, desc="Generating inpainting...")
    original_image = Image.fromarray(image).convert("RGB")
    mask_image = Image.fromarray(mask).convert("L")
    # Resize smaller for faster processing.
    original_image = original_image.resize((384, 384))
    mask_image = mask_image.resize((384, 384))
    # Reduced inference steps to trade a little quality for speed.
    output = pipe(
        prompt=prompt,
        image=original_image,
        mask_image=mask_image,
        num_inference_steps=15,
    ).images[0]
    progress(1.0, desc="Completed!")
    return np.array(output)
# --- Gradio UI ---
# Component creation order defines the on-screen layout; the nesting below is
# reconstructed from the flattened paste (inputs side by side in one Row,
# button and output underneath) — confirm against the deployed Space.
with gr.Blocks() as interface:
    gr.Markdown("## 🎨 AI Furniture Inpainting (Optimized)")
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Image")
        prompt_input = gr.Textbox(label="Prompt (Describe what to add)")
    submit = gr.Button("Submit")
    output_image = gr.Image(label="Generated Image")
    # Wire the button to the inpainting pipeline defined above.
    submit.click(fn=inpaint, inputs=[image_input, prompt_input], outputs=output_image)

# Launch only when executed as a script (not when imported).
if __name__ == "__main__":
    interface.launch()