Spaces:
Sleeping
Sleeping
| import cv2 | |
| from PIL import Image | |
| import streamlit as st | |
| import tempfile | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from torchvision.models import resnet50 | |
| from mtcnn import MTCNN | |
| from skimage.feature import hog | |
| import joblib | |
| import numpy as np | |
class VGGFaceEmbedding(nn.Module):
    """Convolutional feature extractor producing one flat vector per image.

    Despite the class name, the backbone built here is torchvision's
    ``resnet50`` with its last two child modules removed (in torchvision's
    ResNet these are the average-pooling and fully connected layers). The
    remaining feature map is average-pooled to 1x1 and flattened.
    """

    def __init__(self):
        super().__init__()
        backbone_layers = list(resnet50(pretrained=True).children())
        # Attribute names are part of the checkpoint layout (state_dict keys),
        # so they must stay ``base_model`` / ``pooling`` / ``flatten``.
        self.base_model = nn.Sequential(*backbone_layers[:-2])
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return the pooled, flattened backbone features for batch ``x``."""
        feature_map = self.base_model(x)
        return self.flatten(self.pooling(feature_map))
class L1Dist(nn.Module):
    """Layer that returns the element-wise absolute difference of two tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|`` element-wise."""
        difference = input_embedding - validation_embedding
        return torch.abs(difference)
class SiameseNetwork(nn.Module):
    """Two-branch verifier that scores whether two images match.

    Both images go through the same embedding module; the element-wise
    distance between the two embeddings is passed through two linear layers
    and a sigmoid, giving one score in (0, 1) per pair.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict key prefixes for saved weights.
        self.embedding = VGGFaceEmbedding()
        self.distance = L1Dist()
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_image, validation_image):
        """Return the sigmoid score for the (input, validation) image pair."""
        embedded_input = self.embedding(input_image)
        embedded_validation = self.embedding(validation_image)
        hidden = self.fc1(self.distance(embedded_input, embedded_validation))
        return self.sigmoid(self.fc2(hidden))
def preprocess_image_siamese(temp_face_path):
    """Load an image file and turn it into a 224x224 RGB tensor.

    Args:
        temp_face_path: Path of the image file to open.

    Returns:
        The result of resizing the RGB image to 224x224 and applying
        ``transforms.ToTensor()``.
    """
    rgb_image = Image.open(temp_face_path).convert("RGB")
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    return pipeline(rgb_image)
def preprocess_image_svm(img):
    """Resize a BGR image to 224x224 and convert it to grayscale.

    Args:
        img: Image array in OpenCV's BGR channel order.

    Returns:
        The resized single-channel grayscale image.
    """
    resized = cv2.resize(img, (224, 224))
    return cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
def extract_hog_features(img):
    """Compute the HOG descriptor of ``img``.

    Uses 9 orientation bins, 16x16-pixel cells and 4x4-cell blocks.

    Args:
        img: Image array to describe.

    Returns:
        The feature vector returned by ``skimage.feature.hog``.
    """
    return hog(
        img,
        orientations=9,
        pixels_per_cell=(16, 16),
        cells_per_block=(4, 4),
    )
def get_face(img):
    """Crop the first face that MTCNN detects in ``img``.

    Args:
        img: Image array to search, or ``None`` (for example the result of a
            failed ``cv2.imread``), which is treated as "no face".

    Returns:
        The slice ``img[y1:y1 + h, x1:x1 + w]`` for the first detection's
        bounding box, or ``None`` when ``img`` is ``None`` or no face is found.
    """
    if img is None:
        return None

    # Building an MTCNN detector is costly, so keep one on the function object
    # and reuse it for later calls instead of constructing it every time.
    detector = getattr(get_face, "_detector", None)
    if detector is None:
        detector = MTCNN()
        get_face._detector = detector

    # NOTE(review): callers pass cv2.imread output (BGR); the mtcnn examples
    # feed RGB images. Left unchanged so detections stay the same; confirm.
    faces = detector.detect_faces(img)
    if not faces:
        return None

    x1, y1, w, h = faces[0]['box']
    # The box origin can come back negative; abs() keeps the slice start
    # non-negative so it is not interpreted as an index from the end.
    x1, y1 = abs(x1), abs(y1)
    return img[y1:y1 + h, x1:x1 + w]
def _load_bgr_image(uploaded, path):
    """Write an uploaded file's bytes to ``path`` and decode it with OpenCV."""
    # getvalue() returns the whole buffer regardless of the read position;
    # fall back to read() for file objects that do not provide it.
    if hasattr(uploaded, "getvalue"):
        payload = uploaded.getvalue()
    else:
        payload = uploaded.read()
    with open(path, "wb") as handle:
        handle.write(payload)
    return cv2.imread(path)


def _detect_face(bgr_image):
    """Return the cropped face from ``bgr_image``, or None if there is none."""
    if bgr_image is None:
        return None
    crop = get_face(bgr_image)
    # An empty crop cannot be written or displayed, so treat it as "no face".
    if crop is None or crop.size == 0:
        return None
    return crop


def _report_match(is_match):
    """Write the verification verdict to the Streamlit page."""
    st.write("Match" if is_match else "Not Match")


def _verify_siamese(face, face_path, person, validation_image, threshold, workdir):
    """Compare the captured face against the validation image with the Siamese net."""
    if validation_image is None:
        st.warning("Upload a validation image to use the Siamese model.")
        return

    reference = _load_bgr_image(validation_image, f"{workdir}/validation.jpg")
    reference_face = _detect_face(reference)
    if reference_face is None:
        st.warning("No face detected in the validation image.")
        return

    # The crops come from cv2.imread, so tell Streamlit they are BGR.
    st.image([face, reference_face], caption=["Face 1", "Face 2"], width=200,
             channels="BGR")

    reference_face_path = f"{workdir}/validation_face.jpg"
    cv2.imwrite(reference_face_path, reference_face)

    siamese = SiameseNetwork()
    # The model is evaluated on CPU here, so map the checkpoint to CPU even
    # if it was saved from a GPU.
    siamese.load_state_dict(
        torch.load(f'siamese_{person.lower()}.pth', map_location="cpu")
    )
    siamese.eval()

    # Both network inputs are the cropped faces, not the full uploads.
    face_batch = preprocess_image_siamese(face_path).unsqueeze(0)
    reference_batch = preprocess_image_siamese(reference_face_path).unsqueeze(0)
    with torch.no_grad():
        probability = siamese(face_batch, reference_batch).item()

    cutoff = 0.5 if threshold is None else threshold
    _report_match(probability > cutoff)


def _verify_hog_svm(face_path, person):
    """Classify the captured face with the person's PCA + SVM model."""
    # joblib.load unpickles arbitrary objects: only load model files that
    # ship with the application, never user-supplied ones.
    with open(f'./svm_{person.lower()}.pkl', 'rb') as f:
        svm = joblib.load(f)
    with open(f'./pca_{person.lower()}.pkl', 'rb') as f:
        pca = joblib.load(f)

    gray_face = preprocess_image_svm(cv2.imread(face_path))
    st.image(gray_face, caption="Face 1", width=200)

    features = extract_hog_features(gray_face)
    prediction = svm.predict(pca.transform([features]))
    _report_match(prediction[0] == 1)


def verify(image, model, person, validation_image=None, threshold=None):
    """Verify a captured face for ``person`` and write the verdict to the page.

    Args:
        image: File-like object holding the captured image bytes.
        model: ``"Siamese"`` or ``"HOG-SVM"``.
        person: Name used (lower-cased) to locate the model files
            ``siamese_<person>.pth`` or ``svm_<person>.pkl`` / ``pca_<person>.pkl``.
        validation_image: File-like object with the reference image; needed
            only for the Siamese model.
        threshold: Siamese score above which the pair counts as a match;
            ``None`` falls back to 0.5.

    Returns:
        None. Results and problems are reported through Streamlit.
    """
    # All intermediate files live in one directory that is removed on exit,
    # so nothing is left behind and no mktemp-style name guessing is needed.
    with tempfile.TemporaryDirectory() as workdir:
        frame = _load_bgr_image(image, f"{workdir}/input.jpg")
        face = _detect_face(frame)
        if face is None:
            st.warning("No face detected in the captured image.")
            return

        # The crop is written as JPEG and re-read downstream, as before, so
        # the models see the same kind of input they were given previously.
        face_path = f"{workdir}/face.jpg"
        cv2.imwrite(face_path, face)

        if model == "Siamese":
            _verify_siamese(face, face_path, person, validation_image,
                            threshold, workdir)
        elif model == "HOG-SVM":
            _verify_hog_svm(face_path, person)
        else:
            st.error(f"Unknown model: {model}")
def main():
    """Render the face-verification page and run a check on each capture.

    The user picks a model and a person, optionally uploads a validation
    image (Siamese model only) and takes a picture with the camera; the
    picture is then passed to ``verify``.
    """
    st.title("Face Verification")

    # Per-person threshold handed to verify() for the Siamese model.
    person_dict = {
        "Theo": 0.542,
        "Deverel": 0.5,
        "Justin": 0.5,
    }

    model = st.selectbox("Select Model", ["Siamese", "HOG-SVM"])
    person = st.selectbox("Select Person", list(person_dict))

    # Bound in every branch so later code never sees an undefined name.
    uploaded_image = None
    if model == "Siamese":
        uploaded_image = st.file_uploader(
            "Upload Validation Image (Siamese)", type=["jpg", "png"]
        )

    enable = st.checkbox("Enable camera")
    captured_image = st.camera_input("Take a picture", disabled=not enable)
    if not captured_image:
        return

    if model == "Siamese":
        # Without a validation image there is nothing to compare against.
        if uploaded_image is None:
            st.warning("Upload a validation image before taking a picture.")
            return
        verify(captured_image, model, person, uploaded_image, person_dict[person])
    elif model == "HOG-SVM":
        verify(captured_image, model, person)
# Build the page only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()