# Last commit: "Update app.py" by bytchew (2c03b99, verified)
import PIL.Image
from PIL import ImageDraw, ImageFont
import gradio as gr
import numpy as np
from celeb_indicies import *
import PIL
import os
import platform
import time
from huggingface_hub import hf_hub_download
import huggingface_hub
from diffusers import DiffusionPipeline
import torch
from torch import nn
import torchvision
# By default, pip compiles dlib from source when installing it, which takes so long that
# the Hugging Face build times out.
# To avoid this, when running on Linux we install a prebuilt dlib wheel for x86_64 Linux instead.
if platform.system() == 'Linux':
    # Local imports keep this one-off bootstrap self-contained.
    import subprocess
    import sys

    # Order matters: the prebuilt dlib wheel must be in place before
    # face-recognition is installed, otherwise pip builds dlib from source.
    for _package in ("./dlib-19.24.99-cp310-cp310-linux_x86_64.whl", "face-recognition"):
        # Use *this* interpreter's pip so the packages land in the environment
        # that is actually running the app (a bare `pip` on PATH may not).
        _status = subprocess.run([sys.executable, "-m", "pip", "install", _package], check=False)
        if _status.returncode != 0:
            # Best effort: report the failure instead of hiding it, and let the
            # import below surface any real problem.
            print(f"warning: 'pip install {_package}' exited with status {_status.returncode}")

# Time the import of the project's face_detection module (presumably slow
# because it loads face_recognition / dlib). perf_counter is monotonic, so the
# measurement is unaffected by wall-clock adjustments.
start_time = time.perf_counter()
import face_detection
print(f"took {(time.perf_counter() - start_time) / 60} minutes to load face_recognition")
# Hugging Face Hub source of the classifier weights. Edit these two values to
# serve a different model repository, or a different weights file inside it.
model_repo_id: str = "CSSE416-final-project/faceRecogModel"
weight_file_id: str = "modelWeights101.bin"
# 1. Load the model from Hugging Face Hub
def load_model(repo_id, filename=None, num_classes=100):
    """Download fine-tuned ResNet-18 weights from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository that holds the weights file.
        filename: Name of the weights file inside ``repo_id``. Defaults to the
            module-level ``weight_file_id``.
        num_classes: Output size of the replacement final layer. It must match
            the checkpoint; the shipped weights use 100.

    Returns:
        The ResNet-18 with the checkpoint loaded on CPU, in evaluation mode.
    """
    if filename is None:
        filename = weight_file_id
    # Honour the caller's repo_id (rather than a module global) so other
    # repositories can actually be loaded through this function.
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Start from an uninitialised architecture; load_state_dict below is strict,
    # so every parameter and buffer is overwritten by the checkpoint.
    model = torchvision.models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # The file is only ever used as a state_dict, so refuse to unpickle
    # arbitrary Python objects from the downloaded file.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout / use running batch-norm statistics
    return model
# 2. Load model
model = load_model(repo_id=model_repo_id)
# 3. Define how to transform image: each face crop is resized to 224x224 and
# then converted by ToTensor into a channels-first float tensor in [0, 1].
_preprocessing = [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
]
transforms = torchvision.transforms.Compose(_preprocessing)
# 4. Preprocess and display the image
def _load_label_font(size=30):
    """Return the bold TrueType font used to number faces, or Pillow's default font.

    ``ImageFont.truetype`` raises OSError when "Arial Bold.ttf" cannot be found
    (it is not a standard font on Linux), so fall back instead of failing the
    whole request.
    """
    try:
        return ImageFont.truetype("Arial Bold.ttf", size)
    except OSError:
        return ImageFont.load_default()


def process_image_str(groupImageFilePath: str):
    """Detect faces in an uploaded photo, classify each one, and annotate the photo.

    Args:
        groupImageFilePath: Path to the uploaded image file, or ``None`` when
            the form is submitted without an image.

    Returns:
        ``[gr.Image, gr.Markdown]``: the photo with a green, numbered box drawn
        around each detected face, and a Markdown table mapping each number to
        the predicted name and its softmax probability (truncated to a whole
        percent).

    Raises:
        gr.Error: if no image was supplied.
    """
    if groupImageFilePath is None:
        raise gr.Error("Please upload an image first.")

    # Decode into an in-memory RGB copy. This releases the file handle, and it
    # guarantees 3-channel pixels for the tuple colours drawn below and
    # (presumably, via the crops) for the model. Uploads may be RGBA, palette
    # or greyscale images.
    with PIL.Image.open(groupImageFilePath) as source:
        groupImage = source.convert("RGB")

    locations, images = face_detection.getCroppedImages(groupImage)
    draw = ImageDraw.Draw(groupImage)
    font = _load_label_font(30)
    green = (0, 255, 0)

    rows = ["| | Name | Certainty | \n | -------- | ------- | ------- |\n"]
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for n, (image, location) in enumerate(zip(images, locations), start=1):
            input_tensor = transforms(image).unsqueeze(0)  # add a batch dimension of 1
            probabilities = torch.softmax(model(input_tensor), dim=1)
            cert, pred = torch.max(probabilities, dim=1)
            # Number each face on the photo so it can be matched to its table row.
            draw.rectangle(location, outline=green, width=2)
            draw.text((location[0] + 4, location[1] + 2), f"{n}", fill=green, font=font)
            rows.append(f"| {n} | {celeb_list[pred.item()]} | {int(cert.item() * 100)}% | \n")

    if len(rows) == 1:
        # Only the header is present: tell the user rather than show an empty table.
        rows.append("\nNo faces were detected in this image.\n")
    return [gr.Image(groupImage), gr.Markdown("".join(rows))]
# 5. Create the Gradio interface
# Create the upload component before the output components, matching the
# order in which the Interface lists them.
_upload_input = gr.Image(type='filepath', label="Input Image")
_result_outputs = [
    gr.Image(label="Output"),            # photo with numbered face boxes
    gr.Markdown(label="Output Legend"),  # number -> name / certainty table
]
interface = gr.Interface(
    fn=process_image_str,
    inputs=_upload_input,
    outputs=_result_outputs,
    title="Celebrity Face Detector",
    description="Upload a picture of a celebrity or group of celebrities to identify them (ex. Jeff Bezos)",
    allow_flagging='never',  # no "Flag" button in the UI
)

# 6. Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()