Update app.py

app.py CHANGED
@@ -3,53 +3,71 @@ import numpy as np
 import cv2
 from PIL import Image
 import torch
-from transformers import
+from transformers import SamModel, SamProcessor
+import os
 
 # Set up device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
-# Load
-print("Loading
-
-
+# Load SAM model for segmentation
+print("Loading SAM model...")
+sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
+sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
 
-def
-    """
-    # Convert to
-    if
-        image_pil = Image.fromarray(image)
-    else:
+def get_sam_masks(image):
+    """Get segmentation masks using SAM model"""
+    # Convert to numpy if needed
+    if isinstance(image, Image.Image):
         image_pil = image
-
-
-
-
-
-
-
+        image_np = np.array(image)
+    else:
+        image_np = image
+        image_pil = Image.fromarray(image_np)
+
+    h, w = image_np.shape[:2]
+
+    # Create a grid of points to sample the image
+    x_points = np.linspace(w//4, 3*w//4, 5, dtype=int)
+    y_points = np.linspace(h//4, 3*h//4, 5, dtype=int)
+    grid_points = []
+    for y in y_points:
+        for x in x_points:
+            grid_points.append([x, y])
+    points = [grid_points]
+
+    # Process image through SAM
+    inputs = sam_processor(
+        images=image_pil,
+        input_points=points,
+        return_tensors="pt"
+    ).to(device)
+
+    # Generate masks
     with torch.no_grad():
-        outputs =
-
-
-
-
+        outputs = sam_model(**inputs)
+    masks = sam_processor.image_processor.post_process_masks(
+        outputs.pred_masks.cpu(),
+        inputs["original_sizes"].cpu(),
+        inputs["reshaped_input_sizes"].cpu()
+    )
+
+    # Combine all masks to create importance map
+    importance_map = np.zeros((h, w), dtype=np.float32)
+    individual_masks = []
 
-
-
+    for i in range(len(masks[0])):
+        mask = masks[0][i].numpy().astype(np.float32)
+        individual_masks.append(mask)
+        importance_map += mask
 
-    #
-
-
-        detected_boxes.append({
-            'box': box,
-            'score': score.item(),
-            'label': model.config.id2label[label.item()]
-        })
+    # Normalize to 0-1
+    if importance_map.max() > 0:
+        importance_map = importance_map / importance_map.max()
 
-    return detected_boxes
+    return importance_map, individual_masks
 
-def find_optimal_crop(image, target_ratio, objects):
+def find_optimal_crop(image, target_ratio, importance_map):
     """Find the optimal crop area that preserves important content while matching target ratio"""
     # Get image dimensions
     if not isinstance(image, np.ndarray):
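A note on the mask handling in this hunk: `post_process_masks` returns one tensor per image shaped (prompt batch, masks per prompt, height, width), and SAM proposes three candidate masks per prompt by default, so `masks[0][i]` is a (3, H, W) stack rather than a single (h, w) mask, and `importance_map += mask` would fail to broadcast. A minimal sketch of the accumulation loop that collapses the candidates first, assuming the default multimask output; the `any` reduction is just one reasonable choice:

    for i in range(len(masks[0])):
        # masks[0][i] is a (masks_per_prompt, H, W) boolean stack from SAM
        mask_stack = masks[0][i].numpy()
        # Collapse the candidate masks into a single (H, W) float mask
        mask = mask_stack.any(axis=0).astype(np.float32)
        individual_masks.append(mask)
        importance_map += mask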
@@ -59,34 +77,6 @@ def find_optimal_crop(image, target_ratio, objects):
     current_ratio = w / h
     target_ratio_value = eval(target_ratio.replace(':', '/'))
 
-    # If no objects detected, use center crop
-    if not objects:
-        if current_ratio > target_ratio_value:
-            # Need to crop width
-            new_width = int(h * target_ratio_value)
-            left = (w - new_width) // 2
-            right = left + new_width
-            return (left, 0, right, h)
-        else:
-            # Need to crop height
-            new_height = int(w / target_ratio_value)
-            top = (h - new_height) // 2
-            bottom = top + new_height
-            return (0, top, w, bottom)
-
-    # Create a combined importance map from all detected objects
-    importance_map = np.zeros((h, w), dtype=np.float32)
-
-    # Add all objects to the importance map
-    for obj in objects:
-        x1, y1, x2, y2 = obj['box']
-        # Ensure box is within image boundaries
-        x1, y1 = max(0, x1), max(0, y1)
-        x2, y2 = min(w-1, x2), min(h-1, y2)
-
-        # Add object to importance map with its confidence score
-        importance_map[y1:y2, x1:x2] = max(importance_map[y1:y2, x1:x2], obj['score'])
-
     # If current ratio is wider than target, we need to crop width
     if current_ratio > target_ratio_value:
         new_width = int(h * target_ratio_value)
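One caution on a context line this hunk leaves untouched: `target_ratio_value = eval(target_ratio.replace(':', '/'))`. This works for the dropdown's fixed choices, but `eval` on user-visible input is fragile and unsafe if the ratio ever becomes free-form. A split-based parse avoids it; `parse_ratio` is an illustrative name, not part of the commit:

    def parse_ratio(ratio_str):
        # "16:9" -> 16/9 without eval(); raises ValueError on malformed input
        w_str, h_str = ratio_str.split(':')
        return float(w_str) / float(h_str)

    target_ratio_value = parse_ratio(target_ratio)

Note also that the deletion above removes the old center-crop fallback for the no-detections case; in the SAM path an all-zero importance map now flows through the same optimization code instead of short-circuiting to a centered crop.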
@@ -144,16 +134,16 @@ def apply_crop(image, crop_box):
 
 def adjust_aspect_ratio(image, target_ratio):
     """Main function to adjust aspect ratio through intelligent cropping"""
-    #
-
+    # Get segmentation masks and importance map
+    importance_map, _ = get_sam_masks(image)
 
     # Find optimal crop box
-    crop_box = find_optimal_crop(image, target_ratio,
+    crop_box = find_optimal_crop(image, target_ratio, importance_map)
 
     # Apply the crop
     result = apply_crop(image, crop_box)
 
-    return result
+    return result, importance_map
 
 def process_image(input_image, target_ratio="16:9"):
     """Process function for Gradio interface"""
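Since `adjust_aspect_ratio` now returns a pair, any caller written against the old single-value contract breaks; the next hunk updates `process_image`, the only caller this diff shows. Usage under the new contract, for illustration:

    # Both the cropped image and the SAM importance map come back
    result, importance_map = adjust_aspect_ratio(image, "16:9")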
@@ -165,7 +155,7 @@ def process_image(input_image, target_ratio="16:9"):
         image = input_image
 
         # Adjust aspect ratio
-        result = adjust_aspect_ratio(image, target_ratio)
+        result, importance_map = adjust_aspect_ratio(image, target_ratio)
 
         # Convert result to appropriate format
         if isinstance(result, np.ndarray):
@@ -173,15 +163,26 @@ def process_image(input_image, target_ratio="16:9"):
         else:
             result_pil = result
 
-        return result_pil
+        # Visualize importance map for debugging
+        if isinstance(importance_map, np.ndarray):
+            # Convert to heatmap
+            heatmap = (importance_map * 255).astype(np.uint8)
+            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+
+            # Convert to PIL
+            heatmap_pil = Image.fromarray(heatmap)
+
+            return [result_pil, heatmap_pil]
+
+        return [result_pil, None]
 
     except Exception as e:
         print(f"Error processing image: {e}")
-        return None
+        return [None, None]
 
 # Create the Gradio interface
-with gr.Blocks(title="Smart Crop Aspect Ratio Adjuster") as demo:
-    gr.Markdown("# Smart Crop Aspect Ratio Adjuster")
+with gr.Blocks(title="SAM-Based Smart Crop Aspect Ratio Adjuster") as demo:
+    gr.Markdown("# SAM-Based Smart Crop Aspect Ratio Adjuster")
     gr.Markdown("Upload an image, choose your target aspect ratio, and the AI will intelligently crop it to preserve important content.")
 
     with gr.Row():
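One likely wrinkle in the new debug view: `cv2.applyColorMap` returns a BGR array, while `Image.fromarray` treats three-channel input as RGB, so the heatmap's reds and blues render swapped. A small fix, reusing the variables from the hunk above:

    heatmap = (importance_map * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR; convert before handing the array to PIL
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap_pil = Image.fromarray(heatmap)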
@@ -199,23 +200,24 @@
 
         with gr.Column():
             output_image = gr.Image(label="Processed Image")
+            importance_map_vis = gr.Image(label="Importance Map (Debug View)")
 
     submit_btn.click(
         process_image,
         inputs=[input_image, aspect_ratio],
-        outputs=output_image
+        outputs=[output_image, importance_map_vis]
     )
 
     gr.Markdown("""
    ## How it works
-    1. **
-    2. **Importance Mapping**:
-    3. **Smart Cropping**:
+    1. **Segmentation**: Uses Meta's Segment Anything Model (SAM) to identify important regions in your image
+    2. **Importance Mapping**: Creates a heatmap of important areas based on segmentation masks
+    3. **Smart Cropping**: Finds the optimal crop window that preserves the most important content
 
    ## Tips
-    - For best results, ensure important subjects are visible
+    - For best results, ensure important subjects are clearly visible in the image
+    - The importance map shows what the AI considers important (red/yellow = important, blue = less important)
     - Try different aspect ratios to see what works best with your image
-    - The model works best with clear, well-lit images with distinct objects
     """)
 
 # Launch the app
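With `outputs=[output_image, importance_map_vis]`, Gradio expects every return path of `process_image` to yield two values, which the updated returns ([result_pil, heatmap_pil], [result_pil, None], and [None, None]) all satisfy. A quick smoke test, not part of the commit:

    # Every return path must match the two declared outputs
    out = process_image(np.zeros((480, 640, 3), dtype=np.uint8), "1:1")
    assert isinstance(out, list) and len(out) == 2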