# infogr / data_real_world / segment.py
# (repository metadata: author yiyang, commit 6483239)
import os
import re
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamPredictor
def concat_image_variations_with_base(
    base_folder: str,
    variation_folder: str,
    output_folder: str,
    image_size: int = 512,
    stroke_width: int = 6
):
    """
    Concatenate each base image with its variations into a single row image.

    Variation files are expected to be named ``<base_id>_<idx>_<suffix>.png``.
    Outline colors:
      - base image: black
      - suffix 0 -> green, 1 -> blue, 2 -> red (anything else: black)

    Parameters:
        base_folder: Directory holding base images (matched by ID prefix).
        variation_folder: Directory holding the variation images.
        output_folder: Directory where ``<base_id>_concat.png`` files are written.
        image_size: Side length in pixels each tile is resized to.
        stroke_width: Width of the colored outline drawn on each tile.
    """
    os.makedirs(output_folder, exist_ok=True)
    suffix_to_color = {
        '0': 'green',
        '1': 'blue',
        '2': 'red'
    }
    # Group variation images by their base ID (first numeric field).
    # Pattern compiled once instead of re-parsed per file.
    pattern = re.compile(r"(\d+)_\d+_(\d)\.png")
    groups = defaultdict(list)
    for fname in sorted(os.listdir(variation_folder)):
        if fname.endswith('.png'):
            match = pattern.match(fname)
            if match:
                groups[match.group(1)].append(fname)
    for base_id, variations in groups.items():
        # Locate the base image whose filename starts with the ID.
        base_candidates = [f for f in os.listdir(base_folder) if f.startswith(base_id)]
        if not base_candidates:
            print(f"Base image not found for ID {base_id}")
            continue
        images = []
        base_img_path = os.path.join(base_folder, base_candidates[0])
        base_img = Image.open(base_img_path).convert("RGBA").resize((image_size, image_size))
        draw = ImageDraw.Draw(base_img)
        draw.rectangle([0, 0, image_size - 1, image_size - 1], outline="black", width=stroke_width)
        images.append(base_img)
        # Add variation images, ordered by their middle index.
        for var in sorted(variations, key=lambda x: int(x.split('_')[1])):
            path = os.path.join(variation_folder, var)
            img = Image.open(path).convert("RGBA").resize((image_size, image_size))
            draw = ImageDraw.Draw(img)
            suffix = var.split('_')[-1].split('.')[0]
            color = suffix_to_color.get(suffix, "black")
            draw.rectangle([0, 0, image_size - 1, image_size - 1], outline=color, width=stroke_width)
            images.append(img)
        # Paste all tiles side by side into one row.
        total_width = image_size * len(images)
        concat_img = Image.new("RGBA", (total_width, image_size))
        for i, img in enumerate(images):
            concat_img.paste(img, (i * image_size, 0))
        output_path = os.path.join(output_folder, f"{base_id}_concat.png")
        concat_img.save(output_path)
        print(f"Saved: {output_path}")
# Load the SAM model
def load_sam_model(model_type="vit_h", checkpoint_path="sam_vit_h_4b8939.pth"):
    """Build a SamPredictor from a checkpoint, placed on GPU when one is available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    model.to(device)
    return SamPredictor(model)
# Draw bounding box and label
def draw_box(img, box, label=None, color="green", output_path=None):
    """Draw a rectangle (and optional label text) on img; save it if a path is given."""
    canvas = ImageDraw.Draw(img)
    canvas.rectangle(box, outline=color, width=3)
    if label:
        # Offset the label slightly inside the top-left corner of the box.
        x0, y0 = box[0], box[1]
        canvas.text((x0 + 5, y0 + 5), label, fill=color)
    if output_path:
        img.save(output_path)
    return img
def yolo_to_xyxy(boxes, image_width, image_height):
    """
    Convert YOLO-format boxes (label cx cy w h) to absolute xyxy format.

    Parameters:
        boxes (list of list): Each item is [label, cx, cy, w, h] in relative coords.
        image_width (int): Width of the image in pixels.
        image_height (int): Height of the image in pixels.

    Returns:
        List of [label, x1, y1, x2, y2] in pixel coords.
    """
    converted = []
    for label, cx, cy, w, h in boxes:
        # Scale the relative center/size up to pixel units.
        abs_cx, abs_cy = cx * image_width, cy * image_height
        abs_w, abs_h = w * image_width, h * image_height
        converted.append([
            int(label),
            int(abs_cx - abs_w / 2),
            int(abs_cy - abs_h / 2),
            int(abs_cx + abs_w / 2),
            int(abs_cy + abs_h / 2),
        ])
    return converted
# Main logic
def segment(image_np, box_coords, predictor):
    """
    Run SAM on a single bounding box and white out everything outside the mask.

    Parameters:
        image_np: numpy image array (already registered via predictor.set_image).
        box_coords: (x1, y1, x2, y2) pixel box for the target object.
        predictor: SamPredictor instance.

    Returns:
        PIL Image with the background replaced by white.
    """
    # SAM expects the box as a numpy array shaped (1, 4) in xyxy order.
    box_array = np.array([box_coords])
    masks, _, _ = predictor.predict(box=box_array, multimask_output=False)
    primary_mask = masks[0]
    # Copy the image and paint white wherever the mask is off.
    cutout = image_np.copy()
    cutout[~primary_mask] = [255, 255, 255]
    return Image.fromarray(cutout)
# result_img = draw_box(result_img, box_coords, label="object", color="green", output_path="annotated_sam.jpg")
# print("✅ Image saved as 'annotated_sam.jpg'")
# ============================ single image ============================
# image_path = "A_images_resized/0010.png" # Replace with your image
# checkpoint_path = "sam_vit_h_4b8939.pth" # Replace with your model checkpoint
# box_coords = (100, 150, 300, 350) # Replace with your target box (x1, y1, x2, y2)
# # Load model
# predictor = load_sam_model(checkpoint_path=checkpoint_path)
# # load image
# image_pil = Image.open(image_path).convert("RGB")
# image_np = np.array(image_pil)
# predictor.set_image(image_np)
# result_img = segment(image_np, box_coords, predictor)
# ============================ multiple image ============================
image_folder_path = "A_images_resized"
checkpoint_path = "sam_vit_h_4b8939.pth"  # Replace with your model checkpoint
output_dir = "layer_image"
# Create the output directory up front so result_img.save() cannot fail.
os.makedirs(output_dir, exist_ok=True)
predictor = load_sam_model(checkpoint_path=checkpoint_path)
print("SAM predictor loaded")
for img_name in sorted(os.listdir(image_folder_path)):
    # Skip non-PNG entries: they have no matching label file and would
    # defeat the .removesuffix('.png') stem extraction below.
    if not img_name.endswith(".png"):
        continue
    stem = img_name.removesuffix(".png")
    # Load the image and register it with the predictor once per file.
    image_pil = Image.open(os.path.join(image_folder_path, img_name)).convert("RGB")
    image_np = np.array(image_pil)
    predictor.set_image(image_np)
    print(f"Processing {img_name}")
    # Load the matching YOLO label file (one "label cx cy w h" box per line).
    with open(f"A_labels_resized/{stem}.txt", "r") as f:
        boxes = [list(map(float, line.strip().split())) for line in f]
    # NOTE(review): assumes the resized images are 1024x1024 — confirm upstream.
    box_coords = yolo_to_xyxy(boxes, 1024, 1024)
    # Segment each box and save one cut-out layer per object.
    for idx, (label, x1, y1, x2, y2) in enumerate(box_coords):
        result_img = segment(image_np, (x1, y1, x2, y2), predictor)
        result_img.save(f"{output_dir}/{stem}_{idx}_{label}.png")
# === view both original and layered data ===
# concat_image_variations_with_base(
# base_folder="A_images_resized",
# variation_folder="layer_image",
# output_folder="view_image",
# image_size= 512,
# stroke_width= 6
# )