"""Segment objects out of images with SAM using YOLO-format boxes.

For every image in ``A_images_resized/`` this script reads the matching
YOLO label file from ``A_labels_resized/``, runs the Segment Anything
Model (SAM) once per box, and saves each segmented object on a white
background into ``layer_image/``.  A helper is also provided to
concatenate a base image with its generated variations for review.
"""

import os
import re  # BUG FIX: re.match is used below but was never imported
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamPredictor


def concat_image_variations_with_base(
    base_folder: str,
    variation_folder: str,
    output_folder: str,
    image_size: int = 512,
    stroke_width: int = 6,
):
    """Concatenate each base image with its variations into one row.

    Variation files are expected to be named ``<id>_<idx>_<suffix>.png``.
    Outline colors: base image black; suffix ``0`` green, ``1`` blue,
    ``2`` red (anything else falls back to black).

    Args:
        base_folder: Folder containing the original (base) images.
        variation_folder: Folder containing the variation PNGs.
        output_folder: Destination for ``<id>_concat.png`` strips.
        image_size: Side length each tile is resized to, in pixels.
        stroke_width: Width of the colored outline drawn on each tile.
    """
    os.makedirs(output_folder, exist_ok=True)

    suffix_to_color = {
        '0': 'green',
        '1': 'blue',
        '2': 'red',
    }

    # Group variation images by their base-image ID (first numeric group).
    groups = defaultdict(list)
    for fname in sorted(os.listdir(variation_folder)):
        if fname.endswith('.png'):
            match = re.match(r"(\d+)_\d+_(\d)\.png", fname)
            if match:
                base_id = match.group(1)
                groups[base_id].append(fname)

    for base_id, variations in groups.items():
        images = []

        # Load the base image (any file in base_folder starting with the ID).
        base_candidates = [f for f in os.listdir(base_folder) if f.startswith(base_id)]
        if not base_candidates:
            print(f"Base image not found for ID {base_id}")
            continue
        base_img_path = os.path.join(base_folder, base_candidates[0])
        base_img = Image.open(base_img_path).convert("RGBA").resize((image_size, image_size))
        draw = ImageDraw.Draw(base_img)
        draw.rectangle([0, 0, image_size - 1, image_size - 1], outline="black", width=stroke_width)
        images.append(base_img)

        # Add variation tiles, ordered by their middle index (<id>_<idx>_...).
        for var in sorted(variations, key=lambda x: int(x.split('_')[1])):
            path = os.path.join(variation_folder, var)
            img = Image.open(path).convert("RGBA").resize((image_size, image_size))
            draw = ImageDraw.Draw(img)
            suffix = var.split('_')[-1].split('.')[0]
            color = suffix_to_color.get(suffix, "black")
            draw.rectangle([0, 0, image_size - 1, image_size - 1], outline=color, width=stroke_width)
            images.append(img)

        # Paste all tiles side by side into one strip and save it.
        total_width = image_size * len(images)
        concat_img = Image.new("RGBA", (total_width, image_size))
        for i, img in enumerate(images):
            concat_img.paste(img, (i * image_size, 0))

        output_path = os.path.join(output_folder, f"{base_id}_concat.png")
        concat_img.save(output_path)
        print(f"Saved: {output_path}")


def load_sam_model(model_type="vit_h", checkpoint_path="sam_vit_h_4b8939.pth"):
    """Load a SAM checkpoint and return a ready-to-use ``SamPredictor``.

    Args:
        model_type: Key into ``sam_model_registry`` (e.g. ``"vit_h"``).
        checkpoint_path: Path to the SAM weights file.

    Returns:
        A ``SamPredictor`` on CUDA when available, otherwise CPU.
    """
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam.to("cuda" if torch.cuda.is_available() else "cpu")
    predictor = SamPredictor(sam)
    return predictor


def draw_box(img, box, label=None, color="green", output_path=None):
    """Draw a bounding box (and optional label) on a PIL image in place.

    Args:
        img: PIL image to draw on (mutated in place).
        box: ``(x1, y1, x2, y2)`` pixel coordinates.
        label: Optional text drawn just inside the top-left corner.
        color: Outline and text color.
        output_path: When given, the annotated image is also saved here.

    Returns:
        The same (mutated) image, for convenience.
    """
    draw = ImageDraw.Draw(img)
    draw.rectangle(box, outline=color, width=3)
    if label:
        draw.text((box[0] + 5, box[1] + 5), label, fill=color)
    if output_path:
        img.save(output_path)
    return img


def yolo_to_xyxy(boxes, image_width, image_height):
    """Convert YOLO format boxes (label cx cy w h) to absolute xyxy format.

    Parameters:
        boxes (list of list): Each item is [label, cx, cy, w, h] in relative coords.
        image_width (int): Width of the image in pixels.
        image_height (int): Height of the image in pixels.

    Returns:
        List of [label, x1, y1, x2, y2] in pixel coords.
    """
    xyxy_boxes = []
    for box in boxes:
        label, cx, cy, w, h = box
        # Scale relative center/size to absolute pixels.
        cx *= image_width
        cy *= image_height
        w *= image_width
        h *= image_height
        x1 = int(cx - w / 2)
        y1 = int(cy - h / 2)
        x2 = int(cx + w / 2)
        y2 = int(cy + h / 2)
        xyxy_boxes.append([int(label), x1, y1, x2, y2])
    return xyxy_boxes


def segment(image_np, box_coords, predictor):
    """Run SAM on one box and return the masked object on white background.

    ``predictor.set_image`` must already have been called with ``image_np``.

    Args:
        image_np: HxWx3 uint8 RGB array the predictor was set with.
        box_coords: ``(x1, y1, x2, y2)`` pixel box prompting the mask.
        predictor: A prepared ``SamPredictor``.

    Returns:
        PIL image with everything outside the predicted mask painted white.
    """
    # SAM expects the box prompt as a numpy array in [x1, y1, x2, y2] format.
    input_box = np.array([box_coords])

    # Single-mask mode: take the one returned mask; scores/logits unused.
    masks, _, _ = predictor.predict(box=input_box, multimask_output=False)
    mask = masks[0]

    # White background wherever the mask is off.
    masked_image = image_np.copy()
    masked_image[~mask] = [255, 255, 255]

    return Image.fromarray(masked_image)


# ============================ multiple images ============================
image_folder_path = "A_images_resized"
checkpoint_path = "sam_vit_h_4b8939.pth"  # Replace with your model checkpoint
output_folder = "layer_image"
os.makedirs(output_folder, exist_ok=True)  # FIX: saving crashed when folder was missing

predictor = load_sam_model(checkpoint_path=checkpoint_path)
print("SAM model loaded")

for img_name in os.listdir(image_folder_path):
    if not img_name.endswith(".png"):
        continue  # skip stray non-image files

    # Load the image once and register it with the predictor.
    image_pil = Image.open(os.path.join(image_folder_path, img_name)).convert("RGB")
    image_np = np.array(image_pil)
    predictor.set_image(image_np)
    print(f"Processing {img_name}")

    # Load the matching YOLO label file (same stem, .txt extension).
    with open(f"A_labels_resized/{img_name.removesuffix('.png')}.txt", "r") as f:
        lines = f.readlines()
    boxes = [list(map(float, line.strip().split())) for line in lines]

    # FIX: use the actual image dimensions instead of a hard-coded 1024x1024.
    img_h, img_w = image_np.shape[:2]
    box_coords = yolo_to_xyxy(boxes, img_w, img_h)

    for idx, box_coord in enumerate(box_coords):
        label, x1, y1, x2, y2 = box_coord
        result_img = segment(image_np, (x1, y1, x2, y2), predictor)
        result_img.save(
            f"{output_folder}/{img_name.removesuffix('.png')}_{idx}_{label}.png"
        )

# === view both original and layered data ===
# concat_image_variations_with_base(
#     base_folder="A_images_resized",
#     variation_folder="layer_image",
#     output_folder="view_image",
#     image_size=512,
#     stroke_width=6,
# )