File size: 3,614 Bytes
2ed75e5
c217d80
e9c6137
 
c8d4b10
0da216f
2ed75e5
4d35d6e
2e92fc4
9df5892
0da216f
 
 
 
 
 
badde22
617b1f9
 
badde22
344d72a
 
 
 
 
ec30251
344d72a
e9c6137
 
617b1f9
0da216f
a84117a
0da216f
 
 
 
 
 
 
 
a84117a
0da216f
e1373ca
0da216f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb3348b
0da216f
 
 
 
 
 
c217d80
0da216f
c217d80
e42d0f5
c217d80
fdfeac0
c217d80
 
 
0da216f
f4357ff
c217d80
f4357ff
0da216f
 
fdfeac0
c217d80
e42d0f5
fdfeac0
c217d80
 
 
fdfeac0
c217d80
0da216f
e9c6137
0da216f
2ed75e5
0da216f
fdfeac0
 
c4649fe
2ed75e5
2c03b99
2ed75e5
e9c6137
0da216f
e9c6137
0da216f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import PIL.Image
from PIL import ImageDraw, ImageFont
import gradio as gr
import numpy as np
from celeb_indicies import *

import PIL
import os
import platform
import time
from huggingface_hub import hf_hub_download
import huggingface_hub
from diffusers import DiffusionPipeline
import torch
from torch import nn
import torchvision

# By default, dlib is compiled locally when installed via pip, which takes so long that it
# causes Hugging Face to time out during the build process.
# To avoid this, check whether we are running on Linux and, if so, install a prebuilt
# dlib wheel (CPython 3.10, x86_64 Linux) before installing face-recognition.
import subprocess
import sys

if platform.system() == "Linux":
    # Invoke pip through the running interpreter so the packages are installed into the
    # same environment that performs the import below; a bare `pip` found on PATH may
    # belong to a different Python installation.
    # check=False keeps the install best-effort, as before: a failed install is not
    # raised here (presumably it then surfaces when face_detection is imported).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "./dlib-19.24.99-cp310-cp310-linux_x86_64.whl"],
        check=False,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "face-recognition"],
        check=False,
    )

# Time the import of the project's face-detection helper module and report it in minutes.
start_time = time.time()
import face_detection
print(f"took {(time.time() - start_time) / 60} minutes to load face_recognition")


# Hugging Face Hub repository that hosts the trained classifier, and the name of the
# weights file to download from it. Change these values to switch the model in use.
model_repo_id = "CSSE416-final-project/faceRecogModel"
weight_file_id = "modelWeights101.bin"


# 1. Load the model from Hugging Face Hub
def load_model(repo_id):
    """Build the ResNet-18 classifier and load its weights from the Hub.

    Args:
        repo_id: Hugging Face Hub repository to download the weights file
            (named by the module-level ``weight_file_id``) from.

    Returns:
        A torchvision ResNet-18 whose final fully connected layer has 100
        outputs, loaded with the downloaded state dict and set to
        evaluation mode.
    """
    # Download the weights from the repository the caller asked for
    # (previously the argument was ignored in favour of the global default).
    weights_path = hf_hub_download(repo_id=repo_id, filename=weight_file_id)

    # Initialize the ResNet-18 architecture without pretrained weights. Calling
    # the constructor with no arguments does this on both the legacy
    # ``pretrained=`` API and the newer ``weights=`` API, and avoids the
    # deprecation warning that ``pretrained=False`` triggers.
    model = torchvision.models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, 100)  # Adjust for your number of classes

    # Load the model weights onto the CPU.
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # repositories, and consider weights_only=True where the installed torch
    # version supports it.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to evaluation mode
    return model


# 2. Instantiate the classifier once, at import time
model = load_model(model_repo_id)


# 3. Preprocessing pipeline applied to an image before it is fed to the model:
#    resize to 224x224, then convert to a tensor
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])


# 4. Preprocess and display the image
def process_image_str(groupImageFilePath: str):
    """Detect the faces in an image file, classify each one and annotate the image.

    Args:
        groupImageFilePath: Path of the image file to analyse.

    Returns:
        A two-element list: a ``gr.Image`` wrapping the input image with a
        numbered green rectangle drawn at each detected location, and a
        ``gr.Markdown`` table with one row per face (number, predicted name
        from ``celeb_list`` and the softmax confidence as a whole percentage).
    """
    # Normalise to 3-channel RGB so the RGB colour tuples used for drawing are
    # valid and RGBA / greyscale / palette uploads do not reach the model with
    # the wrong channel count.
    # NOTE(review): assumes getCroppedImages returns crops in the same mode as
    # the image it is given - confirm against face_detection.
    groupImage = PIL.Image.open(groupImageFilePath).convert("RGB")
    locations, images = face_detection.getCroppedImages(groupImage)
    groupImage_d = ImageDraw.ImageDraw(groupImage)

    try:
        font = ImageFont.truetype("Arial Bold.ttf", 30)
    except OSError:
        # The font file is not available on this host; fall back to Pillow's
        # built-in font rather than failing the whole request.
        font = ImageFont.load_default()

    header = "| | Name | Certainty | \n | -------- | ------- | ------- |\n"
    rows = []

    for n, (image, location) in enumerate(zip(images, locations), start=1):
        # Add a leading batch dimension so the single crop forms a batch of one
        inputTensor = transforms(image).unsqueeze(0)

        # Inference only: no gradients are needed
        with torch.no_grad():
            outputs_t = model(inputTensor)
            cert, pred_t = torch.max(torch.softmax(outputs_t, dim=1), dim=1)

        # Box the face and stamp its number just inside the box's first corner
        groupImage_d.rectangle(location, outline=(0, 255, 0), width=2)
        groupImage_d.text((location[0] + 4, location[1] + 2), f"{n}", fill=(0, 255, 0), font=font)
        rows.append(f"| {n} | {celeb_list[pred_t.item()]} | {int(cert.item() * 100)}% | \n")

    labels = header + "".join(rows)
    return [gr.Image(groupImage), gr.Markdown(labels)]


# 5. Build the Gradio UI: one uploaded image in, annotated image plus legend out
interface = gr.Interface(
    fn=process_image_str,
    # The upload is handed to the function as a file path
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=[
        gr.Image(label="Output"),
        gr.Markdown(label="Output Legend"),
    ],
    allow_flagging="never",
    title="Celebrity Face Detector",
    description="Upload a picture of a celebrity or group of celebrities to identify them (ex. Jeff Bezos)",
)

# 6. Launch the app only when this file is executed as a script
if __name__ == "__main__":
    interface.launch()