from face_detector import FaceDetector
from model_old import ResNet18
import numpy as np
import torch
from torch import nn
from PIL import Image
from util import draw_bboxes, draw_label_on_bbox
import torchvision.transforms as T
class FaceExpressionRecognizer:
    """Detects faces in an image and annotates each with its predicted expression.

    Pipeline: face detection -> crop -> grayscale/normalize -> 7-class
    expression classifier -> draw boxes and labels back onto the image.
    """

    # Pixel statistics of the FER training set (scalar mean/std for the
    # single grayscale channel) — presumably computed offline; must match
    # the values the saved classifier was trained with.
    _DATASET_MEAN = 0.5077385902404785
    _DATASET_STD = 0.255077600479126

    def __init__(self):
        self.face_detector = FaceDetector()
        self.fer_classifier = _make_fer_classifier()
        # Inference-only model: ensure dropout/batch-norm use eval behaviour.
        self.fer_classifier.eval()
        self.post_process = T.Compose([
            T.Grayscale(),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(FaceExpressionRecognizer._DATASET_MEAN,
                        FaceExpressionRecognizer._DATASET_STD),
        ])
        # Class-index -> human-readable expression name.
        self.idx_to_label = {
            0: 'angry',
            1: 'disgust',
            2: 'fear',
            3: 'happy',
            4: 'neutral',
            5: 'sad',
            6: 'surprise',
        }

    def handle_frame(self, image: Image.Image) -> Image.Image:
        """Return *image* with a red box and an expression label on each face.

        If the detector finds no faces (returns None), the input image is
        returned unchanged.
        """
        bboxes = self.face_detector.detect_bboxes(image)
        if bboxes is None:
            return image
        extracted_faces = self.face_detector.extract_faces(image, bboxes)
        extracted_faces = self.post_process(extracted_faces)
        # Pure inference: disable autograd to avoid building a graph
        # (saves memory and time per frame).
        with torch.no_grad():
            preds = self.fer_classifier(extracted_faces).argmax(dim=1)
        preds = preds.tolist()
        img_w_boxes = draw_bboxes(image.copy(), bboxes, (255, 0, 0))
        image_w_boxes_arr = np.array(img_w_boxes)
        for bbox, pred in zip(bboxes, preds):
            image_w_boxes_arr = draw_label_on_bbox(
                image_w_boxes_arr, bbox, self.idx_to_label[pred])
        return Image.fromarray(image_w_boxes_arr)
def _make_fer_classifier() -> nn.Module:
    """Build the facial-expression classifier and load its saved weights.

    Takes a 1-channel (grayscale) ResNet18, replaces its final layer with a
    fresh 7-way linear head, loads the fine-tuned weights from disk, and
    returns the model in eval mode (CPU).
    """
    model = ResNet18(1, 7)  # 1 input channel (grayscale), 7 expression classes
    fer_fc = nn.Linear(256, 7)
    # Drop the original final layer and append the FER head.
    # NOTE(review): nn.Sequential(*model.children()) discards any functional
    # ops performed inside ResNet18.forward (e.g. an explicit flatten) —
    # assumed intentional here since the saved checkpoint was produced with
    # this exact wrapping; confirm against model_old.ResNet18.
    model = nn.Sequential(*list(model.children())[:-1])
    model = nn.Sequential(*model, fer_fc)
    model.load_state_dict(
        torch.load('./saved_models/fer_model.pth',
                   map_location=torch.device('cpu')))
    # Inference-only model: switch dropout off and use batch-norm running
    # stats — without this the model silently stays in training mode.
    model.eval()
    return model