| import cv2 | |
| import gradio as gr | |
| import torch | |
| from torchvision.transforms import v2 | |
| from model_old import ResNet18 | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = ResNet18(1, 7) | |
| model.load_state_dict(torch.load("./model.pth", map_location=device)) | |
| model.to(device) | |
| model.eval() | |
| face_cascade = cv2.CascadeClassifier("./haarcascade_frontalface_default.xml") | |
| class_list = ["angry", "disgust", "fear", "happy", "neutral", "sad", "surprise"] | |
| DATASET_MEAN = 0.5077385902404785 | |
| DATASET_STD = 0.255077600479126 | |
| preprocess = v2.Compose( | |
| [ | |
| v2.Grayscale(), | |
| v2.PILToTensor(), | |
| v2.ToDtype(torch.float32, scale=True), | |
| v2.Normalize(mean=(DATASET_MEAN,), std=(DATASET_STD,)), | |
| ] | |
| ) | |
| def get_probs(image): | |
| inp = preprocess(torch.tensor(image).permute(2, 0, 1).unsqueeze(0)) | |
| inp = inp.to(device) | |
| pred = model(inp).squeeze() | |
| probs = torch.softmax(pred, 0).cpu() | |
| return probs | |
| def draw_labels(image, cords: tuple, label: str): | |
| x, y, w, h = cords | |
| cv2.rectangle(image, (x, y), (x + w, y + h), (255, 0, 0), 2) | |
| image = cv2.putText( | |
| image, | |
| label, | |
| (x, y), | |
| cv2.FONT_HERSHEY_SIMPLEX, | |
| 1, | |
| (0, 255, 0), | |
| 2, | |
| cv2.LINE_AA, | |
| ) | |
| return image | |
| def predict(image): | |
| gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) | |
| faces = face_cascade.detectMultiScale(gray, 1.1, 4) | |
| for cords in faces: | |
| x, y, w, h = cords | |
| resized = cv2.resize(image[y : y + h, x : x + w], (48, 48), cv2.INTER_AREA) | |
| probs = get_probs(resized) | |
| label = class_list[probs.argmax(0).item()] | |
| image = draw_labels(image, cords, label) | |
| return image | |
| webcam_interface = gr.Interface( | |
| predict, | |
| inputs=gr.Image(sources=['webcam'], streaming=True, label='Input webcam'), | |
| outputs=gr.Image(label='Output video'), | |
| live=True, | |
| title='Webcam mode', | |
| description='Created by Czarna Magia AI Student Club', | |
| theme=gr.themes.Soft(), | |
| ) | |
| img_interface = gr.Interface( | |
| predict, | |
| inputs=gr.Image(sources=['webcam', 'upload', 'clipboard'], label='Input image'), | |
| outputs=gr.Image(label='Output image'), | |
| title='Image upload mode', | |
| description='Created by Czarna Magia AI Student Club', | |
| theme=gr.themes.Soft(), | |
| ) | |
| iface = gr.TabbedInterface( | |
| interface_list=[img_interface, webcam_interface], | |
| tab_names=['Image upload', 'Webcam'], | |
| title='Face Expression Recognizer', | |
| theme=gr.themes.Soft(), | |
| ) | |
| iface.launch() | |