# DeepTurtle / app.py
# (Hugging Face Space entry point; viewer header lines removed so the file parses as Python.)
import os
import re
import cv2
import torch
import numpy as np
import math
from math import sqrt
import pandas as pd
from PIL import Image
from tqdm import tqdm
import argparse
import matplotlib
import warnings
warnings.filterwarnings("ignore")
from ultralytics import YOLO
import matplotlib.image as mpimg
from matplotlib.patches import Rectangle
from segment_anything import sam_model_registry, SamPredictor
from face_side_prediction import *
import gradio as gr
# Extra margin fraction per side so a square canvas still contains the crop
# after an arbitrary rotation: 0.5 * (sqrt(2) - 1) of the longer dimension.
ROTATION_MARGIN = 0.5 * (sqrt(2) - 1.0)
def check_duplications(model, boxes):
    """Collapse duplicate detections so at most one box survives per class.

    When two or more boxes are present, boxes sharing a class id are reduced
    to the one with the largest area. With fewer than two boxes the input is
    returned unchanged (moved to CPU). ``model`` is accepted for interface
    compatibility but unused.
    """
    boxes = boxes.cpu()
    frame = pd.DataFrame(boxes, columns=['xmin', 'ymin', 'xmax', 'ymax', 'prob', 'cls'])
    if len(frame) < 2:
        return boxes
    frame['area'] = (frame['xmax'] - frame['xmin']) * (frame['ymax'] - frame['ymin'])
    deduped = (frame.sort_values('area', ascending=False)
                    .drop_duplicates(subset='cls', keep='first'))
    del deduped['area']
    return torch.tensor(deduped.to_numpy())
def check_missing(model, boxes, image_path, conf_threshold=0.8):
    """Re-run detection at progressively lower confidence until a box appears.

    While ``boxes`` is empty, the confidence threshold is stepped down by 0.1
    and the detector is re-run; the search gives up once the threshold drops
    to zero or below. Returns the (possibly still empty) detections.
    """
    while not len(boxes) and conf_threshold > 0:
        conf_threshold -= 0.1
        detections = model.predict(source=image_path, conf=conf_threshold, verbose=False)
        boxes = list(detections)[0].boxes.data
    return boxes
def extract_masked_image(image, mask):
    """Return a new image keeping only the masked pixels on a black background."""
    keep = mask.astype(bool)  # tolerate 0/1 integer masks
    result = np.zeros_like(image)
    result[keep] = image[keep]
    return result
def apply_mask(image, mask, alpha=0.4):
    """Blend a solid blue (BGR) overlay into *image* where *mask* is set.

    NOTE: mutates *image* in place (matching the original behaviour) and
    returns it; the blend result is truncated back into the image's dtype.
    """
    region = mask.astype(bool)
    tint = np.zeros_like(image, dtype=np.uint8)
    tint[..., 0] = 255  # blue channel; green and red stay zero
    image[region] = image[region] * (1 - alpha) + tint[region] * alpha
    return image
def draw_box(image, box, score, color=(0, 255, 0), box_thickness=5):
    """Draw *box* (xmin, ymin, xmax, ymax) onto *image* and return it.

    ``score`` is accepted for interface compatibility but not rendered.
    """
    top_left = (box[0], box[1])
    bottom_right = (box[2], box[3])
    cv2.rectangle(image, top_left, bottom_right, color, box_thickness)
    return image
def initialize_sam_models():
    """Load the SAM-HQ ViT-H and ViT-L checkpoints on CPU as predictors.

    Returns (predictor_h, predictor_l) in that order.
    """
    checkpoints = (
        ("vit_h", "./pretrained_checkpoint/sam_hq_vit_h.pth"),
        ("vit_l", "./pretrained_checkpoint/sam_hq_vit_l.pth"),
    )
    predictors = []
    for variant, path in checkpoints:
        sam = sam_model_registry[variant](checkpoint=path)
        sam.to(device="cpu")
        predictors.append(SamPredictor(sam))
    predictor_h, predictor_l = predictors
    return predictor_h, predictor_l
def initialize_yolo_model():
    """Load the YOLOv8-m face-detection weights from the checkpoint folder."""
    return YOLO('./pretrained_checkpoint/yolov8m.pt')
def initialize_effecient_model():
    """Build an EfficientNet-B0 with a single-output head and load the
    fine-tuned checkpoint on CPU (used for face-side realignment)."""
    # NOTE(review): `models` is not imported in this module; it presumably
    # arrives via the star import from face_side_prediction
    # (i.e. torchvision.models) — confirm.
    model = models.efficientnet_b0(weights=None)
    num_features = model.classifier[1].in_features
    # Replace the 1000-class ImageNet head with a single output unit.
    model.classifier[1] = torch.nn.Linear(num_features, 1)
    model.load_state_dict(torch.load('./pretrained_checkpoint/effecient_b0.pth', map_location=torch.device('cpu')))
    return model
def crop_image(image, margin=ROTATION_MARGIN):
    """Crop *image* to its non-black content and centre it on a square canvas.

    The content's bounding box is found from the non-zero pixels; the crop is
    then placed on a black square whose side is (1 + 2*margin) times the
    longer box side, leaving headroom for later rotation/realignment.

    Parameters
    ----------
    image : np.ndarray
        BGR image with a pure-black background (as produced by
        ``extract_masked_image``).
    margin : float
        Fractional padding relative to the longer bounding-box side.

    Returns
    -------
    np.ndarray
        Square uint8 BGR image with the content centred on black.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    # Bounding box of the non-black pixels; min/max coords are INCLUSIVE.
    coords = cv2.findNonZero(thresh)
    x_min, y_min = coords.min(axis=0)[0]
    x_max, y_max = coords.max(axis=0)[0]
    # BUG FIX: inclusive max coords mean the box spans (max - min + 1) pixels;
    # the original dropped the last row and column of the content.
    width = x_max - x_min + 1
    height = y_max - y_min + 1
    size_new = int((1.0 + 2 * margin) * max(width, height))
    if width > height:
        x_offset = int(margin * width)
        y_offset = int(0.5 * ((1.0 + 2 * margin) * width - height))
    else:
        x_offset = int(0.5 * ((1.0 + 2 * margin) * height - width))
        y_offset = int(margin * height)
    new_image = np.zeros((size_new, size_new, 3), dtype=np.uint8)
    # NOTE(review): the original had `new_image[:0:size_new] = (255, 255, 0)`,
    # an empty slice that never executed (and the value is cyan in BGR, not the
    # "green" its comment claimed). The canvas has therefore always been black;
    # the dead statement is removed rather than "fixed" into a behaviour change.
    # Paste the crop at the computed offsets, centring it on the square canvas.
    new_image[y_offset:y_offset + height, x_offset:x_offset + width] = \
        image[y_min:y_max + 1, x_min:x_max + 1]
    return new_image
def get_result(masks, scores, input_box, image, image_name, save_dir, effecient_model, predictor_h, predictor_l,
               threshold=0.75):
    """Turn SAM masks into the app's three output images plus a side label.

    For the first mask whose score exceeds ``threshold``: extract the masked
    face on black, crop it, realign it via ``process_single_image``, and
    annotate the original image with the mask and bounding box. If the first
    mask scores below the threshold, re-predict once with the high-quality
    SAM predictor and accept its result (retry uses ``threshold=0``).

    Returns
    -------
    tuple | None
        (annotated image, cropped face, aligned face, side name), or None
        when ``masks`` is empty.
    """
    side_labels = {'R': "Right", 'L': "Left"}
    for i, (mask, score) in enumerate(zip(masks, scores)):
        if score > threshold:
            # Keep only the masked pixels on black, then crop to the content.
            extracted_image = extract_masked_image(image.copy(), mask)
            extracted_image = crop_image(extracted_image)
            # Realign horizontally; the returned name encodes the face side.
            image_name_1, image_1 = process_single_image(effecient_model, extracted_image, image_name)
            # Overlay mask and detection box on a fresh copy of the input.
            image_with_mask = apply_mask(image.copy(), mask)
            image_2 = draw_box(image_with_mask, input_box[i], score)
            height_1, width_1 = image_1.shape[:2]
            # Resize the annotated image to match the aligned face's size.
            image_2 = cv2.resize(image_2, (width_1, height_1))
            # The side tag is embedded in square brackets in the file name.
            side = re.search(r"\[(.*?)\]", image_name_1).group(1)
            # BUG FIX: side_name was unbound when the tag was neither R nor L.
            side_name = side_labels.get(side, "Unknown")
            return image_2, extracted_image, image_1, side_name
        else:
            # Low-confidence mask: retry with the high-quality predictor.
            predictor_h.set_image(image)
            masks, scores, _ = predictor_h.predict(
                point_coords=None,
                point_labels=None,
                box=input_box,
                multimask_output=False,
                hq_token_only=False,
                return_logits=False
            )
            # BUG FIX: the original discarded the recursive result (so the
            # caller got None) and passed an undefined `save_extracted`
            # keyword, which raised NameError. Return the retry's result.
            return get_result(masks, scores, input_box, image, image_name, save_dir,
                              effecient_model, predictor_h, predictor_l, threshold=0)
def process_image(image, image_name, model, effecient_model, predictor_h, predictor_l, save_dir='./output'):
    """Run the full pipeline on one BGR image: detect, segment, realign.

    Parameters
    ----------
    image : np.ndarray
        BGR input image.
    image_name : str
        Name used for the intermediate/aligned image.
    model : YOLO
        Face detector.
    effecient_model, predictor_h, predictor_l
        Realignment model and the high-/low-quality SAM predictors.
    save_dir : str
        Output directory handed through to ``get_result``.

    Returns
    -------
    tuple | None
        (annotated image, cropped face, aligned face, side name), or None
        when an OSError occurs.
    """
    try:
        result = model.predict(source=image, show=False, save=False, conf=0.5, verbose=False)
        boxes = list(result)[0].boxes.data
        # Retry at lower confidence if nothing was detected at 0.5.
        boxes = check_missing(model, boxes, image)
        # Collapse per-class duplicates and take the best box's xyxy coords.
        box = check_duplications(model, boxes)[0].to(torch.int32)[:4]
        input_box = box.numpy().reshape((1, -1))
        # Segment with the (faster) low-quality predictor first; get_result
        # falls back to the high-quality one for low-confidence masks.
        predictor_l.set_image(image)
        masks, scores, _ = predictor_l.predict(
            point_coords=None,
            point_labels=None,
            box=input_box,
            multimask_output=False,
            hq_token_only=False,
            return_logits=False
        )
        img1, img2, img3, side = get_result(masks, scores, input_box, image, image_name, save_dir,
                                            effecient_model, predictor_h, predictor_l)
        return img1, img2, img3, side
    except OSError as e:
        # BUG FIX: the original referenced an undefined `image_file` here,
        # raising a NameError that masked the real OSError.
        print(f"Error processing image {image_name}: {e}")
def process_image_gradio(image):
    """Gradio handler: run the pipeline on an uploaded image.

    Returns three PIL images (annotated, cropped, aligned) and the side
    label, or (None, None, None, "") when no image is provided.
    """
    if image is None:
        return None, None, None, ""
    # Gradio delivers RGB; the pipeline works in OpenCV's BGR order.
    bgr = np.array(image)[:, :, ::-1].copy()
    img1, img2, img3, side = process_image(bgr, "image.jpg", yolo_model,
                                           effecient_model, predictor_h, predictor_l)
    # Flip each result back to RGB for display.
    as_pil = [Image.fromarray(out[:, :, ::-1]) for out in (img1, img2, img3)]
    return as_pil[0], as_pil[1], as_pil[2], side
# Initialize the models once at import time so the Gradio callbacks reuse them.
predictor_h, predictor_l = initialize_sam_models()  # SAM-HQ ViT-H / ViT-L predictors
yolo_model = initialize_yolo_model()  # YOLOv8 face detector
effecient_model = initialize_effecient_model()  # EfficientNet-B0 realignment model
def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Gallery selection handler: run the pipeline on the selected example.

    (Currently unused — the gallery wiring below is commented out.)
    """
    selected_path = evt.value['image']['path']
    bgr = cv2.imread(selected_path, cv2.IMREAD_COLOR)
    img1, img2, img3, side = process_image(bgr, "image.jpg", yolo_model,
                                           effecient_model, predictor_h, predictor_l)
    # Convert BGR results back to RGB PIL images for display.
    return (Image.fromarray(img1[:, :, ::-1]),
            Image.fromarray(img2[:, :, ::-1]),
            Image.fromarray(img3[:, :, ::-1]),
            side)
# Bundled example images offered in the gr.Examples widget below.
examples = [
    "raw_images/1.jpeg",
    "raw_images/2.jpeg",
    "raw_images/3.jpeg",
    "raw_images/4.jpeg",
    "raw_images/5.jpeg",
    "raw_images/6.jpeg",
    "raw_images/7.jpeg",
    "raw_images/8.jpeg",
    "raw_images/9.jpeg",
    "raw_images/10.jpeg"
]
# Custom CSS injected into the Gradio page: grey header banner + scrollable iframe.
style = """
#header-section {
background-color: #f0f0f0;
padding: 20px;
text-align: center;
font-size: 1.5em;
margin-bottom: 0px;
}
.light iframe {
scrolling: yes !important; /* Enable scroll bars if content overflows */
overflow: auto;
}
"""
# Build the Gradio UI: header, description, upload input, three output
# panels, side-label text, and the examples strip.
with gr.Blocks(css=style) as demo:
    gr.HTML("""
<div id='header-section'>
<h1>DeepTurtle</h1>
</div>
""")
    gr.Markdown("""
<div style="background-color: #f0f0f0; padding: 10px;">
DeepTurtle is an advanced image processing pipeline that includes the following steps:
<ul>
<li><strong>Face Detection:</strong> Utilizes YOLOv8 for detecting turtle faces.</li>
<li><strong>Segmentation:</strong> Employs SAM-HQ models for segmenting the detected faces.</li>
<li><strong>Realigning:</strong> A regressor realigns segmented faces horizontally.</li>
<li><strong>Direction Classification:</strong> Postprocessing step to classify face direction as 'left' or 'right'.</li>
</ul>
</div>
Additionally, you have two options for image input:
<ul>
<li><strong>Upload a Turtle Image:</strong> You can upload turtle images from your local device.</li>
<li><strong>Select from the Gallery:</strong> At the end of this page, you can select from predefined turtle images in the gallery.</li>
</ul>
<hr> <!-- Horizontal splitter line -->
""")
    # Input row: empty spacer columns centre the upload widget.
    with gr.Row():
        gr.Column()
        with gr.Column():
            gr.Markdown("Upload an image:")
            image_input = gr.Image(show_label=False, sources=["upload", "clipboard"])
        gr.Column()
    # Output row: the three pipeline stages side by side.
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Detected & Segmented Face")  # Title for first output
            output1 = gr.Image(show_label=False, show_download_button=False)  # First output image
        with gr.Column():
            gr.Markdown("#### Cropped Turtle Face")  # Title for second output
            output2 = gr.Image(show_label=False, show_download_button=False)  # Second output image
        with gr.Column():
            gr.Markdown("#### Horizontally-Aligned Face")  # Title for third output
            output3 = gr.Image(show_label=False, show_download_button=False)  # Third output image
    # Side-label row, again centred between spacer columns.
    with gr.Row():
        gr.Column()
        with gr.Column():
            gr.Markdown("#### Face Side Orientation")
            side_text = gr.Text(show_label=False)
        gr.Column()
    #gr.Markdown("#### Select one of the examples")
    #gallery = gr.Gallery(value=examples, label="Examples", show_label=False,
    #                     elem_id="gallery", columns=[5], rows=[2],preview=False,
    #                     selected_index=0,height=500,show_download_button=False,show_share_button=False)
    #gallery.select(on_select, inputs=None, outputs=[output1, output2, output3, side_text])
    # Re-run the pipeline whenever the uploaded image changes.
    image_input.change(process_image_gradio, inputs=image_input, outputs=[output1, output2, output3, side_text])
    with gr.Row():
        with gr.Column():
            examples = gr.Examples(examples=examples, label="Select one of the examples below", inputs=image_input, fn=process_image_gradio, outputs=[output1, output2, output3, side_text],cache_examples=True,examples_per_page=10)
# Run the interface
demo.launch(share=True)