# Page header captured together with this file -- Spaces: Build error
import os
import tempfile
import warnings

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the SAM checkpoint and its processor once, at import time, and move
# the model weights to the selected device.
model_id = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(model_id)
model = SamModel.from_pretrained(model_id)
model = model.to(device)
def get_sam_mask(image, points=None):
    """
    Segment ``image`` with SAM and return a single binary mask.

    Args:
        image: PIL image. It is converted to RGB when it is in another mode.
        points: Optional list of ``[x, y]`` coordinates used as a point
            prompt. When ``None`` the model is run without any prompt.

    Returns:
        A 2-D ``uint8`` array holding 255 inside the selected mask and 0
        elsewhere. Without a prompt, the candidate covering the largest area
        is selected. With a prompt, the candidate with the highest predicted
        IoU score is selected. If the model yields no candidate at all, a
        mask covering the full image is returned.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Both modes share one pipeline; only the prompt argument differs.
    processor_kwargs = {"images": image, "return_tensors": "pt"}
    if points is not None:
        # One list of points per image in the (single-image) batch.
        processor_kwargs["input_points"] = [points]
    inputs = processor(**processor_kwargs).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the raw predictions using the recorded original/reshaped
    # sizes, then keep the candidates of the first image and first prompt.
    candidates = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0].numpy()

    if candidates.shape[0] == 0:
        # Nothing to choose from: treat the whole frame as important.
        return np.ones((image.height, image.width), dtype=np.uint8) * 255

    if points is None:
        # No prompt: prefer the candidate that covers the most pixels.
        areas = candidates.reshape(candidates.shape[0], -1).sum(axis=1)
        best = int(np.argmax(areas))
    else:
        # Prompted: the candidates are alternative readings of the same
        # click, so take the one the model itself scores highest.
        # assumes iou_scores is indexed (image, prompt, candidate) like
        # pred_masks -- confirm against the transformers SAM docs.
        scores = outputs.iou_scores[0, 0].cpu().numpy()
        best = int(np.argmax(scores))

    return candidates[best].astype(np.uint8) * 255
def find_optimal_crop(image, mask, target_aspect_ratio):
    """
    Place the largest crop window of the requested aspect ratio on the mask content.

    The window is as large as the image allows for the given ratio, so one of
    its sides always spans the full image. Along the other axis it is centred
    on the bounding box of the non-zero mask pixels and then shifted, if
    necessary, to stay inside the image.

    Args:
        image: Source image. Not read here (all geometry comes from ``mask``);
            kept in the signature for existing callers.
        mask: 2-D array; entries greater than 0 mark important content.
        target_aspect_ratio: Desired width / height of the crop; must be > 0.

    Returns:
        ``(left, top, right, bottom)`` as plain Python ints, with
        ``right = left + crop_width`` and ``bottom = top + crop_height``.

    Raises:
        ValueError: If ``target_aspect_ratio`` is not positive.
    """
    if target_aspect_ratio <= 0:
        raise ValueError(
            f"target_aspect_ratio must be positive, got {target_aspect_ratio!r}"
        )

    h, w = mask.shape

    # Centre of the content bounding box; an empty mask falls back to the
    # centre of the whole image.
    ys, xs = np.nonzero(mask > 0)
    if ys.size == 0:
        center_x, center_y = w // 2, h // 2
    else:
        min_x, max_x = int(xs.min()), int(xs.max())
        min_y, max_y = int(ys.min()), int(ys.max())
        center_x = min_x + (max_x - min_x + 1) // 2
        center_y = min_y + (max_y - min_y + 1) // 2

    # Largest window with the requested ratio that fits inside the image.
    # The max(1, ...) guard keeps very thin images from truncating to an
    # empty crop.
    if target_aspect_ratio > w / h:
        # Target is wider than the image: keep the full width.
        target_w = w
        target_h = max(1, int(w / target_aspect_ratio))
    else:
        # Target is taller than (or equal to) the image: keep the full height.
        target_h = h
        target_w = max(1, int(h * target_aspect_ratio))

    # Centre on the content, then clamp into the image. On the full-span axis
    # the clamp always yields 0, so no separate "content does not fit"
    # handling is needed: the result is the same centred window either way.
    x = max(0, min(center_x - target_w // 2, w - target_w))
    y = max(0, min(center_y - target_h // 2, h - target_h))

    return (x, y, x + target_w, y + target_h)
def smart_crop(input_image, target_aspect_ratio, point_x=None, point_y=None):
    """
    Crop an image to ``target_aspect_ratio`` around the content SAM segments.

    Args:
        input_image: PIL image or numpy array (converted with
            ``Image.fromarray``). ``None`` short-circuits.
        target_aspect_ratio: Desired width / height of the result.
        point_x: Optional x pixel coordinate of a point prompt.
        point_y: Optional y pixel coordinate of a point prompt.

    Returns:
        ``(cropped_image, visualization_path)``: the cropped PIL image and the
        path of a PNG showing the original, the mask and the crop side by
        side. ``(None, None)`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        # Keep the two-value shape so callers can always unpack the result.
        return None, None

    pil_image = Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # A point prompt is used only when both coordinates are given and lie
    # inside the image. Row/column 0 are valid pixels, so the lower bound is
    # inclusive; anything outside the frame is ignored.
    points = None
    if point_x is not None and point_y is not None:
        if 0 <= point_x < pil_image.width and 0 <= point_y < pil_image.height:
            points = [[point_x, point_y]]

    mask = get_sam_mask(pil_image, points)
    crop_box = find_optimal_crop(pil_image, mask, target_aspect_ratio)
    cropped_img = pil_image.crop(crop_box)

    # Three-panel summary of the process: original / mask / result.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (pil_image, "Original Image", {}),
            (mask, "SAM Segmentation Mask", {"cmap": "gray"}),
            (cropped_img, f"Smart Cropped ({target_aspect_ratio:.2f})", {}),
        )
        for axis, (content, title, imshow_kwargs) in zip(axes, panels):
            axis.imshow(content, **imshow_kwargs)
            axis.set_title(title)
            axis.axis("off")
        fig.tight_layout()

        # Reserve a unique file name per call so simultaneous requests do not
        # overwrite each other's picture. The handle is closed before saving
        # so the path can be reopened for writing on every platform.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            vis_path = tmp.name
        fig.savefig(vis_path)
    finally:
        # Close exactly this figure, also when plotting or saving failed.
        plt.close(fig)

    return cropped_img, vis_path
def aspect_ratio_options(choice):
    """Translate an aspect-ratio label into its numeric width / height value.

    Labels that are not in the table fall back to 16:9.
    """
    # Each label maps to its (width, height) pair; the ratio is derived below.
    width_height = {
        "16:9 (Landscape)": (16, 9),
        "9:16 (Portrait)": (9, 16),
        "4:3 (Standard)": (4, 3),
        "3:4 (Portrait)": (3, 4),
        "1:1 (Square)": (1, 1),
        "21:9 (Ultrawide)": (21, 9),
        "2:3 (Portrait)": (2, 3),
        "3:2 (Landscape)": (3, 2),
    }
    width, height = width_height.get(choice, (16, 9))
    return width / height
def process_image(input_image, aspect_ratio_choice, point_x=None, point_y=None):
    """Resolve the chosen aspect-ratio label and run the smart crop.

    Returns the cropped image and the visualization path produced by
    ``smart_crop``, or ``(None, None)`` when no image was supplied.
    """
    if input_image is not None:
        # Turn the label into a number, then delegate the actual work.
        ratio = aspect_ratio_options(aspect_ratio_choice)
        cropped, visualization_file = smart_crop(input_image, ratio, point_x, point_y)
        return cropped, visualization_file
    return None, None
def create_app():
    """
    Build the Gradio Blocks interface of the smart cropper.

    Layout: the left column holds the image upload, the aspect-ratio dropdown
    and the process button; the right column shows the cropped result and the
    process visualization. Clicking on the uploaded image stores the clicked
    pixel, which is forwarded to ``process_image`` as a point prompt.

    Returns:
        The assembled ``gr.Blocks`` app; launching it is left to the caller.
    """
    with gr.Blocks(title="Smart Image Cropper using SAM") as app:
        gr.Markdown("# Smart Image Cropper using Segment Anything Model (SAM)")
        gr.Markdown("""
        Upload an image and choose your target aspect ratio. The app will use the Segment Anything Model (SAM)
        to identify important content and crop intelligently to preserve it.
        Optionally, you can click on the uploaded image to specify a point of interest.
        """)
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                aspect_ratio = gr.Dropdown(
                    choices=[
                        "16:9 (Landscape)",
                        "9:16 (Portrait)",
                        "4:3 (Standard)",
                        "3:4 (Portrait)",
                        "1:1 (Square)",
                        "21:9 (Ultrawide)",
                        "2:3 (Portrait)",
                        "3:2 (Landscape)"
                    ],
                    value="16:9 (Landscape)",
                    label="Target Aspect Ratio"
                )
                # Last clicked pixel as [x, y]; [None, None] means "no point prompt".
                point_coords = gr.State(value=[None, None])

                def update_coords(img, evt: gr.SelectData):
                    """Store the coordinates of the pixel the user clicked."""
                    return [evt.index[0], evt.index[1]]

                input_image.select(update_coords, inputs=[input_image], outputs=[point_coords])
                # Forget the stored click whenever the image itself changes
                # (new upload or clear); otherwise a point chosen on the
                # previous picture would be used as the prompt for the next one.
                input_image.change(lambda: [None, None], inputs=None, outputs=[point_coords])
                process_btn = gr.Button("Process Image")
            with gr.Column(scale=2):
                output_image = gr.Image(type="pil", label="Cropped Result")
                visualization = gr.Image(type="filepath", label="Process Visualization")
        # Unpack the stored [x, y] pair into the two point arguments.
        process_btn.click(
            fn=lambda img, ratio, coords: process_image(img, ratio, coords[0], coords[1]),
            inputs=[input_image, aspect_ratio, point_coords],
            outputs=[output_image, visualization]
        )
        gr.Markdown("""
        ## How It Works
        1. The Segment Anything Model (SAM) analyzes your image to identify the important content
        2. The app finds the optimal crop window that maximizes the preservation of that content
        3. The image is cropped to your desired aspect ratio while keeping the important parts
        ## Tips
        - For better results with specific subjects, click on the important object in the image
        - Try different aspect ratios to see how the model adapts the cropping
        """)
    return app
# Build the interface at import time so `demo` exists as a module attribute.
demo = create_app()

# Start the server in both situations the original distinguished -- run as a
# script (local testing) or imported (Hugging Face Spaces). Both paths made
# the same call, so a single unconditional launch is equivalent.
demo.launch()