|
|
import os
import re
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamPredictor
|
|
|
|
|
def concat_image_variations_with_base(
    base_folder: str,
    variation_folder: str,
    output_folder: str,
    image_size: int = 512,
    stroke_width: int = 6
):
    """
    Concatenate each base image with its variations in a single row.

    Variation files are expected to be named ``<base_id>_<k>_<suffix>.png``
    (e.g. ``12_0_1.png``); they are grouped by ``<base_id>`` and ordered by
    ``<k>``. Each tile is resized to ``image_size`` square and outlined:

    - base image: black stroke
    - suffix _0 -> green, _1 -> blue, _2 -> red (anything else: black)

    Results are written to ``<output_folder>/<base_id>_concat.png``.

    Parameters:
        base_folder: Folder containing the original (base) images.
        variation_folder: Folder containing the variation PNGs.
        output_folder: Destination folder (created if missing).
        image_size: Side length each tile is resized to, in pixels.
        stroke_width: Width of the colored outline rectangle, in pixels.
    """
    os.makedirs(output_folder, exist_ok=True)

    suffix_to_color = {
        '0': 'green',
        '1': 'blue',
        '2': 'red'
    }

    # Group variation filenames by their leading base id.
    groups = defaultdict(list)
    for fname in sorted(os.listdir(variation_folder)):
        if fname.endswith('.png'):
            match = re.match(r"(\d+)_\d+_(\d)\.png", fname)
            if match:
                base_id = match.group(1)
                groups[base_id].append(fname)

    for base_id, variations in groups.items():
        images = []

        # Prefer an exact filename-stem match; a bare startswith() would let
        # base_id "1" incorrectly pick up "12.png". Fall back to the looser
        # prefix match to keep supporting bases named e.g. "1_base.png".
        base_candidates = [
            f for f in os.listdir(base_folder)
            if os.path.splitext(f)[0] == base_id
        ]
        if not base_candidates:
            base_candidates = [f for f in os.listdir(base_folder) if f.startswith(base_id)]

        if base_candidates:
            base_img_path = os.path.join(base_folder, base_candidates[0])
            base_img = Image.open(base_img_path).convert("RGBA").resize((image_size, image_size))
            draw = ImageDraw.Draw(base_img)
            draw.rectangle([0, 0, image_size - 1, image_size - 1], outline="black", width=stroke_width)
            images.append(base_img)
        else:
            print(f"Base image not found for ID {base_id}")
            continue

        # Variations ordered by the middle index in "<base>_<k>_<suffix>.png".
        for var in sorted(variations, key=lambda x: int(x.split('_')[1])):
            path = os.path.join(variation_folder, var)
            img = Image.open(path).convert("RGBA").resize((image_size, image_size))
            draw = ImageDraw.Draw(img)
            suffix = var.split('_')[-1].split('.')[0]
            color = suffix_to_color.get(suffix, "black")
            draw.rectangle([0, 0, image_size - 1, image_size - 1], outline=color, width=stroke_width)
            images.append(img)

        # Paste all tiles side by side into one wide canvas.
        total_width = image_size * len(images)
        concat_img = Image.new("RGBA", (total_width, image_size))
        for i, img in enumerate(images):
            concat_img.paste(img, (i * image_size, 0))

        output_path = os.path.join(output_folder, f"{base_id}_concat.png")
        concat_img.save(output_path)
        print(f"Saved: {output_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_sam_model(model_type="vit_h", checkpoint_path="sam_vit_h_4b8939.pth"):
    """Load a SAM checkpoint and return a ready-to-use SamPredictor.

    The model is moved to the GPU when CUDA is available, otherwise the CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam.to(device)
    return SamPredictor(sam)
|
|
|
|
|
|
|
|
def draw_box(img, box, label=None, color="green", output_path=None):
    """Draw a rectangle (and optional label text) onto *img* in place.

    The image is saved to *output_path* when one is given; the (mutated)
    image is returned either way.
    """
    canvas = ImageDraw.Draw(img)
    canvas.rectangle(box, outline=color, width=3)
    if label:
        # Offset the text slightly inside the top-left corner of the box.
        canvas.text((box[0] + 5, box[1] + 5), label, fill=color)
    if output_path:
        img.save(output_path)
    return img
|
|
|
|
|
def yolo_to_xyxy(boxes, image_width, image_height):
    """
    Convert YOLO format boxes (label cx cy w h) to absolute xyxy format.

    Parameters:
        boxes (list of list): Each item is [label, cx, cy, w, h] in relative coords.
        image_width (int): Width of the image in pixels.
        image_height (int): Height of the image in pixels.

    Returns:
        List of [label, x1, y1, x2, y2] in pixel coords.
    """
    converted = []
    for label, cx, cy, w, h in boxes:
        # Scale the relative center/size up to pixel coordinates.
        center_x = cx * image_width
        center_y = cy * image_height
        half_w = (w * image_width) / 2
        half_h = (h * image_height) / 2
        converted.append([
            int(label),
            int(center_x - half_w),
            int(center_y - half_h),
            int(center_x + half_w),
            int(center_y + half_h),
        ])
    return converted
|
|
|
|
|
|
|
|
|
|
|
def segment(image_np, box_coords, predictor):
    """Segment the region inside *box_coords* and white out everything else.

    Runs a single-mask SAM prediction for the given (x1, y1, x2, y2) box,
    paints all pixels outside the predicted mask white, and returns the
    result as a PIL image. NOTE(review): assumes ``predictor.set_image``
    was already called with *image_np* — confirm at call sites.
    """
    box_batch = np.array([box_coords])
    masks, _, _ = predictor.predict(box=box_batch, multimask_output=False)
    keep = masks[0]

    # Copy so the caller's array is untouched; fill the background white.
    result = image_np.copy()
    result[~keep] = [255, 255, 255]

    return Image.fromarray(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Script entry: segment every labeled object in each resized image. ---
# For each PNG in the image folder, reads its YOLO label file, converts the
# boxes to pixel coords, runs SAM per box, and saves each cut-out layer as
# layer_image/<stem>_<box-index>_<label>.png.
image_folder_path = "A_images_resized"
label_folder_path = "A_labels_resized"
output_folder_path = "layer_image"
checkpoint_path = "sam_vit_h_4b8939.pth"

predictor = load_sam_model(checkpoint_path=checkpoint_path)

# result_img.save() below fails if the folder is missing; create it up front.
os.makedirs(output_folder_path, exist_ok=True)

for img_name in sorted(os.listdir(image_folder_path)):
    # Skip non-image entries (hidden files, subfolders, stray text files).
    if not img_name.endswith(".png"):
        continue

    image_pil = Image.open(os.path.join(image_folder_path, img_name)).convert("RGB")
    image_np = np.array(image_pil)
    predictor.set_image(image_np)

    stem = img_name.removesuffix(".png")
    label_path = os.path.join(label_folder_path, f"{stem}.txt")
    if not os.path.exists(label_path):
        print(f"No label file for {img_name}, skipping")
        continue

    # One "label cx cy w h" box per line; ignore blank lines so a trailing
    # newline in the label file cannot crash the unpack.
    with open(label_path, "r") as f:
        boxes = [list(map(float, line.split())) for line in f if line.strip()]

    # Use the actual image dimensions rather than a hard-coded 1024x1024 so
    # the conversion stays correct if the resize target ever changes.
    height, width = image_np.shape[:2]
    box_coords = yolo_to_xyxy(boxes, width, height)

    for idx, (label, x1, y1, x2, y2) in enumerate(box_coords):
        result_img = segment(image_np, (x1, y1, x2, y2), predictor)
        result_img.save(os.path.join(output_folder_path, f"{stem}_{idx}_{label}.png"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|