import os import re import cv2 import torch import numpy as np import math from math import sqrt import pandas as pd from PIL import Image from tqdm import tqdm import argparse import matplotlib import warnings warnings.filterwarnings("ignore") from ultralytics import YOLO import matplotlib.image as mpimg from matplotlib.patches import Rectangle from segment_anything import sam_model_registry, SamPredictor from face_side_prediction import * import gradio as gr ROTATION_MARGIN = 0.5 * (sqrt(2) - 1.0) def check_duplications(model, boxes): boxes = boxes.cpu() df = pd.DataFrame(boxes, columns=['xmin', 'ymin', 'xmax', 'ymax', 'prob', 'cls']) if df.shape[0] >= 2: df['area'] = (df['xmax'] - df['xmin']) * (df['ymax'] - df['ymin']) df = df.sort_values(['area'], ascending=[False]) df = df.drop_duplicates(subset='cls', keep='first') del df['area'] return torch.tensor(df.to_numpy()) return boxes def check_missing(model, boxes, image_path, conf_threshold=0.8): empty_flag = len(boxes) while empty_flag == 0 and conf_threshold > 0: conf_threshold -= 0.1 result = model.predict(source=image_path, conf=conf_threshold, verbose=False) boxes = list(result)[0].boxes.data empty_flag = len(boxes) return boxes def extract_masked_image(image, mask): # Ensure that the mask is boolean mask = mask.astype(bool) # Create an array of zeros with the same shape as the image background = np.zeros_like(image) # Copy the masked area from the original image onto the background background[mask] = image[mask] return background def apply_mask(image, mask, alpha=0.4): """Apply a mask to the image with the given color and alpha transparency.""" # Ensure that the mask is a binary mask mask = mask.astype(bool) # Create an overlay with the blue color overlay = np.zeros_like(image, dtype=np.uint8) overlay[..., 0] = 255 # Blue channel overlay[..., 1] = 0 # Green channel (should be 0 for pure blue) overlay[..., 2] = 0 # Red channel (should be 0 for pure blue) # Apply the overlay wherever the mask is true combined_image = image combined_image[mask] = combined_image[mask] * (1 - alpha) + overlay[mask] * alpha return combined_image def draw_box(image, box, score, color=(0, 255, 0), box_thickness=5): # Draw the bounding box on the image cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), color, box_thickness) return image def initialize_sam_models(): sam_checkpoint_h = "./pretrained_checkpoint/sam_hq_vit_h.pth" sam_h = sam_model_registry["vit_h"](checkpoint=sam_checkpoint_h) sam_h.to(device="cpu") predictor_h = SamPredictor(sam_h) sam_checkpoint_l = "./pretrained_checkpoint/sam_hq_vit_l.pth" sam_l = sam_model_registry["vit_l"](checkpoint=sam_checkpoint_l) sam_l.to(device="cpu") predictor_l = SamPredictor(sam_l) return predictor_h, predictor_l def initialize_yolo_model(): model = YOLO('./pretrained_checkpoint/yolov8m.pt') return model def initialize_effecient_model(): model = models.efficientnet_b0(weights=None) num_features = model.classifier[1].in_features model.classifier[1] = torch.nn.Linear(num_features, 1) model.load_state_dict(torch.load('./pretrained_checkpoint/effecient_b0.pth', map_location=torch.device('cpu'))) return model def crop_image(image, margin=ROTATION_MARGIN): gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY) # Find the coordinates of the non-zero pixels coords = cv2.findNonZero(thresh) x_min, y_min = coords.min(axis=0)[0] x_max, y_max = coords.max(axis=0)[0] # Keep aspect ratio and make square. width = x_max - x_min height = y_max - y_min size_new = int((1.0 + 2 * margin) * max(width, height)) if width > height: x_offset = int(margin * width) y_offset = int(0.5 * ((1.0 + 2 * margin) * width - height)) else: x_offset = int(0.5 * ((1.0 + 2 * margin) * height - width)) y_offset = int(margin * height) new_image = np.zeros((size_new, size_new, 3), dtype=np.uint8) new_image[:0:size_new] = (255, 255, 0) # Green in BGR format # Copy the cropped image to the new image at position x_offset, y_offset. new_image[y_offset:y_offset + height, x_offset:x_offset + width] = image[y_min:y_max, x_min:x_max] return new_image def get_result(masks, scores, input_box, image, image_name, save_dir, effecient_model, predictor_h, predictor_l, threshold=0.75): for i, (mask, score) in enumerate(zip(masks, scores)): if score > threshold: # Save only the masked area with black background extracted_image = extract_masked_image(image.copy(), mask) # Crop to the non-black pixels only and keep aspect ratio. extracted_image = crop_image(extracted_image) image_name_1, image_1 = process_single_image(effecient_model, extracted_image, image_name) # Save image with mask and bounding box image_with_mask = apply_mask(image.copy(), mask) image_2 = draw_box(image_with_mask, input_box[i], score) height_1, width_1 = image_1.shape[:2] # Resize image_2 to match the dimensions of image_1 image_2 = cv2.resize(image_2, (width_1, height_1)) side = re.search(r"\[(.*?)\]", image_name_1).group(1) if side == 'R': side_name = "Right" elif side == 'L': side_name = "Left" return image_2, extracted_image, image_1, side_name else: predictor_h.set_image(image) masks, scores, _ = predictor_h.predict( point_coords=None, point_labels=None, box=input_box, multimask_output=False, hq_token_only=False, return_logits=False ) # Recursive call with predictor_h to process low confidence masks get_result(masks, scores, input_box, image, image_name, save_dir, effecient_model, predictor_h, predictor_l, threshold=0, save_extracted=save_extracted) def process_image(image, image_name, model, effecient_model, predictor_h, predictor_l, save_dir='./output'): try: result = model.predict(source=image, show=False, save=False, conf=0.5, verbose=False) boxes = list(result)[0].boxes.data boxes = check_missing(model, boxes, image) box = check_duplications(model, boxes)[0].to(torch.int32)[:4] input_box = box.numpy().reshape((1, -1)) predictor_l.set_image(image) masks, scores, _ = predictor_l.predict( point_coords=None, point_labels=None, box=input_box, multimask_output=False, hq_token_only=False, return_logits=False ) # Determine if we need to save the output or perform high-quality predictions img1, img2, img3, side = get_result(masks, scores, input_box, image, image_name, save_dir, effecient_model, predictor_h, predictor_l) return img1, img2, img3, side except OSError as e: print(f"Error processing image {image_file}: {e}") def process_image_gradio(image): if image is None: return None, None, None, "" image = np.array(image) image = image[:, :, ::-1].copy() # Convert RGB to BGR image_name = "image.jpg" img1, img2, img3, side = process_image(image, image_name, yolo_model, effecient_model, predictor_h, predictor_l) img1 = Image.fromarray(img1[:, :, ::-1]) img2 = Image.fromarray(img2[:, :, ::-1]) img3 = Image.fromarray(img3[:, :, ::-1]) return img1, img2, img3, side # Initialize the models predictor_h, predictor_l = initialize_sam_models() yolo_model = initialize_yolo_model() effecient_model = initialize_effecient_model() def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData image_path = evt.value['image']['path'] image = cv2.imread(image_path, cv2.IMREAD_COLOR) image_name = "image.jpg" img1, img2, img3, side = process_image(image, image_name, yolo_model, effecient_model, predictor_h, predictor_l) img1 = Image.fromarray(img1[:, :, ::-1]) img2 = Image.fromarray(img2[:, :, ::-1]) img3 = Image.fromarray(img3[:, :, ::-1]) return img1, img2, img3, side examples = [ "raw_images/1.jpeg", "raw_images/2.jpeg", "raw_images/3.jpeg", "raw_images/4.jpeg", "raw_images/5.jpeg", "raw_images/6.jpeg", "raw_images/7.jpeg", "raw_images/8.jpeg", "raw_images/9.jpeg", "raw_images/10.jpeg" ] style = """ #header-section { background-color: #f0f0f0; padding: 20px; text-align: center; font-size: 1.5em; margin-bottom: 0px; } .light iframe { scrolling: yes !important; /* Enable scroll bars if content overflows */ overflow: auto; } """ with gr.Blocks(css=style) as demo: gr.HTML("""