# image-embedding / app.py
# Author: DEVAN CHAUHAN
# [add] anime face detection and crop
import gradio as gr
print("Loading models...")
import cv2
import numpy as np
from PIL import Image
from rembg import remove
from sentence_transformers import SentenceTransformer
import urllib.request
import pathlib
print("Libraries loaded")
# Load CLIP Model
# Loaded once at import time; used below to embed the PIL images produced by
# the pipeline. NOTE(review): model download happens here on first run.
image_model = SentenceTransformer("clip-ViT-B-32")
print("CLIP loaded")
# Load Anime Face Cascade
def load_anime_model():
    """Download (if needed) and load the lbpcascade anime-face detector.

    Returns:
        cv2.CascadeClassifier ready for detectMultiScale().

    Raises:
        RuntimeError: if the cascade file exists but cannot be loaded
            (e.g. a corrupt download).
    """
    url = "https://raw.githubusercontent.com/nagadomi/lbpcascade_animeface/master/lbpcascade_animeface.xml"
    path = pathlib.Path("lbpcascade_animeface.xml")
    if not path.exists():
        print("Downloading anime face model...")
        # Download to a temp name in the same directory, then rename into
        # place. Writing straight to the final path would leave a partial
        # file behind on an interrupted download, which the next run would
        # "find" and never re-download.
        tmp = path.with_name(path.name + ".tmp")
        urllib.request.urlretrieve(url, tmp.as_posix())
        tmp.replace(path)
    classifier = cv2.CascadeClassifier(path.as_posix())
    # CascadeClassifier() does not raise on a bad file -- it silently
    # returns an empty classifier -- so check explicitly and fail at
    # startup instead of misbehaving at detection time.
    if classifier.empty():
        raise RuntimeError(f"Failed to load anime face cascade from {path}")
    return classifier
# Load Human Face Cascade
def load_human_model():
    """Load OpenCV's bundled default frontal-face Haar cascade."""
    cascade_file = pathlib.Path(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )
    return cv2.CascadeClassifier(cascade_file.as_posix())
# Instantiate both detectors once at import time so per-request handlers
# don't reload cascade files from disk.
anime_detector = load_anime_model()
human_detector = load_human_model()
print("Anime + Human detectors loaded")
# Embedding Function
def get_image_embedding(image):
    """Return the CLIP embedding of *image* as a JSON-safe dict.

    The numpy vector from encode() is converted to a plain list so it can
    be rendered by the gr.JSON output.
    """
    return {"embedding": image_model.encode(image).tolist()}
# Face Crop + Background Remove
def _clean_and_resize(bgr_image):
    """Convert a BGR OpenCV image to RGB PIL, strip the background with
    rembg, and resize to the 224x224 input size used for CLIP."""
    pil_image = Image.fromarray(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB))
    output = remove(pil_image)
    return output.resize((224, 224))


def process_image(input_image, mode):
    """Detect a face, crop around it with margins, remove the background,
    and resize for CLIP.

    Args:
        input_image: RGB PIL image from the Gradio input, or None if the
            user pressed Run without uploading anything.
        mode: "Anime" selects the anime cascade; any other value uses the
            human frontal-face cascade.

    Returns:
        (status_message, processed PIL image or None). When no face is
        detected the whole image is background-removed instead.
    """
    # Guard: without this, np.array(None) below would crash the handler
    # instead of reporting a readable status to the UI.
    if input_image is None:
        return "No image provided", None
    img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Choose detector
    detector = anime_detector if mode == "Anime" else human_detector
    faces = detector.detectMultiScale(
        gray,
        scaleFactor=1.1,
        minNeighbors=5,
        minSize=(24, 24)
    )
    if len(faces) == 0:
        # No face found: fall back to whole-image background removal.
        print("direct to background removal")
        # NOTE: this status string must stay byte-identical to the
        # comparison in run_pipeline().
        return "Success βœ…", _clean_and_resize(img)
    # First detection only (cascade order; not necessarily the largest face).
    x, y, w, h = faces[0]
    height, width, _ = img.shape
    # Expand bounding box: extra headroom above the face, smaller margins
    # on the sides and below, clamped to the image bounds.
    top_expand = 0.5
    side_expand = 0.3
    bottom_expand = 0.2
    x1 = int(max(0, x - w * side_expand))
    x2 = int(min(width, x + w + w * side_expand))
    y1 = int(max(0, y - h * top_expand))
    y2 = int(min(height, y + h + h * bottom_expand))
    cropped = img[y1:y2, x1:x2]
    return "Success βœ…", _clean_and_resize(cropped)
# Gradio UI
with gr.Blocks() as demo:
    with gr.Tab("Full Pipeline"):
        # Mode feeds straight through to process_image().
        mode_selector = gr.Dropdown(
            choices=["Anime", "Human"],
            value="Anime",
            label="Detection Mode"
        )
        img_input = gr.Image(type="pil")
        status = gr.Text()
        img_output = gr.Image()
        embedding_output = gr.JSON()
        run_btn = gr.Button("Run Pipeline")

        def run_pipeline(img, mode):
            """Crop/clean the image, then embed it.

            On any non-success status, return the message with empty
            image/embedding outputs instead of raising.
            """
            status_msg, processed_img = process_image(img, mode)
            if status_msg != "Success βœ…":
                return status_msg, None, {"embedding": None}
            embedding = get_image_embedding(processed_img)
            return status_msg, processed_img, embedding

        run_btn.click(
            run_pipeline,
            inputs=[img_input, mode_selector],
            outputs=[status, img_output, embedding_output]
        )
    with gr.Tab("Embedding Only"):
        # Second tab: raw CLIP embedding with no detection/cleanup step.
        img_input2 = gr.Image(type="pil")
        embedding_output2 = gr.JSON()
        run_btn2 = gr.Button("Get Embedding")

        def get_embedding_only(img):
            """Embed the uploaded image as-is."""
            embedding = get_image_embedding(img)
            return embedding

        run_btn2.click(
            get_embedding_only,
            inputs=img_input2,
            outputs=embedding_output2
        )
print("Launching demo...")
# queue() bounds concurrent requests; rembg + CLIP are heavy per-request.
demo.queue(max_size=15).launch()