Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import math | |
| from math import sqrt | |
| import pandas as pd | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import argparse | |
| import matplotlib | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| from ultralytics import YOLO | |
| import matplotlib.image as mpimg | |
| from matplotlib.patches import Rectangle | |
| from segment_anything import sam_model_registry, SamPredictor | |
| from face_side_prediction import * | |
| import gradio as gr | |
| ROTATION_MARGIN = 0.5 * (sqrt(2) - 1.0) | |
def check_duplications(model, boxes):
    """Deduplicate detections so each class keeps only its largest box.

    `boxes` is an (n, 6) tensor with columns xmin, ymin, xmax, ymax, prob, cls.
    With fewer than two boxes the (CPU-moved) input is returned unchanged;
    otherwise the result is ordered by descending box area.
    """
    boxes = boxes.cpu()
    table = pd.DataFrame(boxes.numpy(), columns=['xmin', 'ymin', 'xmax', 'ymax', 'prob', 'cls'])
    if table.shape[0] < 2:
        return boxes
    # Rank by area so drop_duplicates keeps the biggest box per class.
    table['area'] = (table['xmax'] - table['xmin']) * (table['ymax'] - table['ymin'])
    table = table.sort_values(['area'], ascending=[False]).drop_duplicates(subset='cls', keep='first')
    table = table.drop(columns='area')
    return torch.tensor(table.to_numpy())
def check_missing(model, boxes, image_path, conf_threshold=0.8):
    """Re-run detection at progressively lower confidence until a box appears.

    If `boxes` already contains detections it is returned untouched. Otherwise
    the threshold is lowered in 0.1 steps, stopping once it reaches zero even
    if nothing was ever found.
    """
    while not len(boxes) and conf_threshold > 0:
        conf_threshold -= 0.1
        prediction = model.predict(source=image_path, conf=conf_threshold, verbose=False)
        boxes = list(prediction)[0].boxes.data
    return boxes
def extract_masked_image(image, mask):
    """Return a black image with only the `mask` region copied from `image`.

    The input image is not modified; pixels where `mask` is falsy stay zero.
    """
    region = mask.astype(bool)
    extracted = np.zeros_like(image)
    # Boolean indexing copies just the masked pixels onto the black canvas.
    extracted[region] = image[region]
    return extracted
def apply_mask(image, mask, alpha=0.4):
    """Blend a solid blue overlay onto `image` wherever `mask` is set.

    Args:
        image: HxWx3 uint8 image in BGR order (channel 0 = blue).
        mask: HxW array; truthy pixels are tinted.
        alpha: overlay opacity in [0, 1].

    Returns:
        A new image with the masked area blended toward blue; the caller's
        array is left untouched.
    """
    # Ensure that the mask is a binary mask
    mask = mask.astype(bool)
    # Solid-color overlay: channel 0 is blue in OpenCV's BGR ordering.
    overlay = np.zeros_like(image, dtype=np.uint8)
    overlay[..., 0] = 255  # Blue channel
    overlay[..., 1] = 0    # Green channel (should be 0 for pure blue)
    overlay[..., 2] = 0    # Red channel (should be 0 for pure blue)
    # Bug fix: work on a copy — the original aliased `image` and wrote the
    # blend through to the caller's buffer.
    combined_image = image.copy()
    combined_image[mask] = combined_image[mask] * (1 - alpha) + overlay[mask] * alpha
    return combined_image
def draw_box(image, box, score, color=(0, 255, 0), box_thickness=5):
    """Draw the (xmin, ymin, xmax, ymax) rectangle onto `image` in place.

    `score` is accepted for interface compatibility but is not rendered.
    Returns the same image object for chaining.
    """
    top_left = (box[0], box[1])
    bottom_right = (box[2], box[3])
    cv2.rectangle(image, top_left, bottom_right, color, box_thickness)
    return image
def initialize_sam_models():
    """Load the SAM-HQ ViT-H and ViT-L checkpoints onto CPU.

    Returns (predictor_h, predictor_l), each a SamPredictor wrapping the
    corresponding model.
    """
    predictors = []
    for arch, checkpoint in (("vit_h", "./pretrained_checkpoint/sam_hq_vit_h.pth"),
                             ("vit_l", "./pretrained_checkpoint/sam_hq_vit_l.pth")):
        sam = sam_model_registry[arch](checkpoint=checkpoint)
        sam.to(device="cpu")
        predictors.append(SamPredictor(sam))
    return predictors[0], predictors[1]
def initialize_yolo_model():
    """Load the YOLOv8-medium detection checkpoint and return the model."""
    return YOLO('./pretrained_checkpoint/yolov8m.pt')
def initialize_effecient_model():
    """Build an EfficientNet-B0 with a single-output regression head and load
    the trained weights onto CPU.

    Returns the torch model ready for inference.
    """
    # Bug fix: `models` was never imported at module level, so this function
    # raised NameError; import torchvision locally where it is needed.
    from torchvision import models
    model = models.efficientnet_b0(weights=None)
    # Replace the classifier with a 1-unit linear layer to match the checkpoint.
    num_features = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_features, 1)
    model.load_state_dict(torch.load('./pretrained_checkpoint/effecient_b0.pth', map_location=torch.device('cpu')))
    return model
def crop_image(image, margin=ROTATION_MARGIN):
    """Crop `image` to its non-black content and center it on a square black
    canvas with a fractional `margin` border.

    Args:
        image: HxWx3 uint8 image whose background is pure black.
        margin: border fraction relative to the longer content side; the
            default (0.5*(sqrt(2)-1)) keeps the content inside the canvas
            under any rotation.

    Returns:
        A square uint8 image of side int((1 + 2*margin) * max(w, h)).
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    # Find the coordinates of the non-zero pixels
    coords = cv2.findNonZero(thresh)
    x_min, y_min = coords.min(axis=0)[0]
    x_max, y_max = coords.max(axis=0)[0]
    # Bug fix: the max coordinates are inclusive, so content size and the crop
    # below must extend one past them (the original dropped the last row and
    # column of content).
    width = x_max - x_min + 1
    height = y_max - y_min + 1
    # Square canvas sized from the longer side plus the margin on both sides.
    size_new = int((1.0 + 2 * margin) * max(width, height))
    if width > height:
        x_offset = int(margin * width)
        y_offset = int(0.5 * ((1.0 + 2 * margin) * width - height))
    else:
        x_offset = int(0.5 * ((1.0 + 2 * margin) * height - width))
        y_offset = int(margin * height)
    # The canvas stays black. The original attempted a colored background fill
    # with `new_image[:0:size_new] = (255, 255, 0)`, but that slice is empty,
    # so the statement was a no-op and has been removed.
    new_image = np.zeros((size_new, size_new, 3), dtype=np.uint8)
    # Copy the cropped content onto the canvas at (x_offset, y_offset).
    new_image[y_offset:y_offset + height, x_offset:x_offset + width] = image[y_min:y_max + 1, x_min:x_max + 1]
    return new_image
def get_result(masks, scores, input_box, image, image_name, save_dir, effecient_model, predictor_h, predictor_l,
               threshold=0.75):
    """Turn SAM masks into the final pipeline outputs.

    For the first mask whose score exceeds `threshold`, returns a tuple
    (boxed overlay image, cropped face, aligned face, side name). A
    low-confidence mask triggers one re-segmentation with the higher-quality
    ViT-H predictor, whose result is accepted unconditionally (threshold=0).
    """
    for i, (mask, score) in enumerate(zip(masks, scores)):
        if score > threshold:
            # Keep only the masked pixels on a black background, then crop to
            # a margin-padded square.
            extracted_image = extract_masked_image(image.copy(), mask)
            extracted_image = crop_image(extracted_image)
            # Realign the face horizontally; the returned name encodes the
            # predicted side as "[L]" or "[R]".
            image_name_1, image_1 = process_single_image(effecient_model, extracted_image, image_name)
            # Visualization: mask overlay plus the detection box.
            image_with_mask = apply_mask(image.copy(), mask)
            image_2 = draw_box(image_with_mask, input_box[i], score)
            height_1, width_1 = image_1.shape[:2]
            # Resize the overlay to match the aligned image's dimensions.
            image_2 = cv2.resize(image_2, (width_1, height_1))
            side = re.search(r"\[(.*?)\]", image_name_1).group(1)
            if side == 'R':
                side_name = "Right"
            elif side == 'L':
                side_name = "Left"
            else:
                # Bug fix: an unexpected tag previously left side_name unbound
                # (UnboundLocalError on return).
                side_name = "Unknown"
            return image_2, extracted_image, image_1, side_name
        else:
            # Low confidence: retry with the higher-quality ViT-H predictor.
            predictor_h.set_image(image)
            masks_h, scores_h, _ = predictor_h.predict(
                point_coords=None,
                point_labels=None,
                box=input_box,
                multimask_output=False,
                hq_token_only=False,
                return_logits=False
            )
            # Bug fixes: the recursive result is now returned (it was
            # discarded, so the caller unpacked None), and the nonexistent
            # `save_extracted` keyword that raised NameError was removed.
            return get_result(masks_h, scores_h, input_box, image, image_name, save_dir, effecient_model,
                              predictor_h, predictor_l, threshold=0)
def process_image(image, image_name, model, effecient_model, predictor_h, predictor_l, save_dir='./output'):
    """Run the full pipeline on one BGR image: detect, segment, align, classify side.

    Returns (overlay image, cropped face, aligned face, side name), or None if
    an OSError occurs during processing.
    """
    try:
        result = model.predict(source=image, show=False, save=False, conf=0.5, verbose=False)
        boxes = list(result)[0].boxes.data
        # If nothing was detected at conf=0.5, progressively relax the threshold.
        boxes = check_missing(model, boxes, image)
        # Keep the largest box per class, then take the top box's coordinates.
        box = check_duplications(model, boxes)[0].to(torch.int32)[:4]
        input_box = box.numpy().reshape((1, -1))
        # First segmentation pass with the lighter ViT-L predictor.
        predictor_l.set_image(image)
        masks, scores, _ = predictor_l.predict(
            point_coords=None,
            point_labels=None,
            box=input_box,
            multimask_output=False,
            hq_token_only=False,
            return_logits=False
        )
        # Determine if we need to save the output or perform high-quality predictions
        img1, img2, img3, side = get_result(masks, scores, input_box, image, image_name, save_dir,
                                            effecient_model, predictor_h, predictor_l)
        return img1, img2, img3, side
    except OSError as e:
        # Bug fix: the original referenced an undefined `image_file` here,
        # turning any OSError into a NameError.
        print(f"Error processing image {image_name}: {e}")
def process_image_gradio(image):
    """Gradio callback: run the pipeline on an uploaded image.

    Returns three PIL images (overlay, crop, aligned face) plus the side
    label, or (None, None, None, "") when no image is supplied.
    """
    if image is None:
        return None, None, None, ""
    # Gradio supplies RGB; the pipeline works in OpenCV's BGR ordering.
    bgr = np.array(image)[:, :, ::-1].copy()
    image_name = "image.jpg"
    img1, img2, img3, side = process_image(bgr, image_name, yolo_model, effecient_model, predictor_h, predictor_l)
    # Flip each BGR result back to RGB and wrap as PIL images for display.
    rgb_outputs = [Image.fromarray(out[:, :, ::-1]) for out in (img1, img2, img3)]
    return rgb_outputs[0], rgb_outputs[1], rgb_outputs[2], side
# Initialize the models once at import time so every Gradio callback reuses
# the same loaded weights (SAM-HQ predictors, YOLOv8 detector, EfficientNet
# side regressor). All run on CPU.
predictor_h, predictor_l = initialize_sam_models()
yolo_model = initialize_yolo_model()
effecient_model = initialize_effecient_model()
def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Gallery-selection callback: run the pipeline on the chosen example.

    The event payload carries the selected gallery item, including the path
    of its image file on disk.
    """
    selected_path = evt.value['image']['path']
    bgr = cv2.imread(selected_path, cv2.IMREAD_COLOR)
    image_name = "image.jpg"
    img1, img2, img3, side = process_image(bgr, image_name, yolo_model, effecient_model, predictor_h, predictor_l)
    # OpenCV loads BGR; flip to RGB PIL images for Gradio display.
    rgb_images = [Image.fromarray(im[:, :, ::-1]) for im in (img1, img2, img3)]
    return rgb_images[0], rgb_images[1], rgb_images[2], side
# Example image paths shown in the Gradio examples strip at the bottom of
# the page (consumed — and rebound — by gr.Examples below).
examples = [
    "raw_images/1.jpeg",
    "raw_images/2.jpeg",
    "raw_images/3.jpeg",
    "raw_images/4.jpeg",
    "raw_images/5.jpeg",
    "raw_images/6.jpeg",
    "raw_images/7.jpeg",
    "raw_images/8.jpeg",
    "raw_images/9.jpeg",
    "raw_images/10.jpeg"
]
# Custom CSS injected into gr.Blocks: a grey centered header banner plus
# scrollable iframe content.
style = """
#header-section {
    background-color: #f0f0f0;
    padding: 20px;
    text-align: center;
    font-size: 1.5em;
    margin-bottom: 0px;
}
.light iframe {
    scrolling: yes !important; /* Enable scroll bars if content overflows */
    overflow: auto;
}
"""
# Build the Gradio UI: header banner, pipeline description, centered upload
# widget, three result panes, side-orientation text box, and a cached
# examples strip; then launch with a public share link.
with gr.Blocks(css=style) as demo:
    # Page header banner (styled via #header-section in `style`).
    gr.HTML("""
    <div id='header-section'>
        <h1>DeepTurtle</h1>
    </div>
    """)
    # Pipeline description and the two supported input options.
    gr.Markdown("""
    <div style="background-color: #f0f0f0; padding: 10px;">
    DeepTurtle is an advanced image processing pipeline that includes the following steps:
    <ul>
        <li><strong>Face Detection:</strong> Utilizes YOLOv8 for detecting turtle faces.</li>
        <li><strong>Segmentation:</strong> Employs SAM-HQ models for segmenting the detected faces.</li>
        <li><strong>Realigning:</strong> A regressor realigns segmented faces horizontally.</li>
        <li><strong>Direction Classification:</strong> Postprocessing step to classify face direction as 'left' or 'right'.</li>
    </ul>
    </div>
    Additionally, you have two options for image input:
    <ul>
        <li><strong>Upload a Turtle Image:</strong> You can upload turtle images from your local device.</li>
        <li><strong>Select from the Gallery:</strong> At the end of this page, you can select from predefined turtle images in the gallery.</li>
    </ul>
    <hr> <!-- Horizontal splitter line -->
    """)
    # Centered upload widget; the empty flanking columns act as spacers.
    with gr.Row():
        gr.Column()
        with gr.Column():
            gr.Markdown("Upload an image:")
            image_input = gr.Image(show_label=False, sources=["upload", "clipboard"])
        gr.Column()
    # Three side-by-side result panes.
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Detected & Segmented Face")  # Title for first output
            output1 = gr.Image(show_label=False, show_download_button=False)  # First output image
        with gr.Column():
            gr.Markdown("#### Cropped Turtle Face")  # Title for second output
            output2 = gr.Image(show_label=False, show_download_button=False)  # Second output image
        with gr.Column():
            gr.Markdown("#### Horizontally-Aligned Face")  # Title for third output
            output3 = gr.Image(show_label=False, show_download_button=False)  # Third output image
    # Centered text box reporting the predicted face side.
    with gr.Row():
        gr.Column()
        with gr.Column():
            gr.Markdown("#### Face Side Orientation")
            side_text = gr.Text(show_label=False)
        gr.Column()
    #gr.Markdown("#### Select one of the examples")
    #gallery = gr.Gallery(value=examples, label="Examples", show_label=False,
    #                     elem_id="gallery", columns=[5], rows=[2],preview=False,
    #                     selected_index=0,height=500,show_download_button=False,show_share_button=False)
    #gallery.select(on_select, inputs=None, outputs=[output1, output2, output3, side_text])
    # Run the pipeline whenever a new image is uploaded or pasted.
    image_input.change(process_image_gradio, inputs=image_input, outputs=[output1, output2, output3, side_text])
    with gr.Row():
        with gr.Column():
            # Pre-cached example images; selecting one runs the full pipeline.
            examples = gr.Examples(examples=examples, label="Select one of the examples below", inputs=image_input, fn=process_image_gradio, outputs=[output1, output2, output3, side_text],cache_examples=True,examples_per_page=10)
# Run the interface
demo.launch(share=True)