deedax commited on
Commit
6bab5d5
·
1 Parent(s): ddf2dbf

Upload 3 files

Browse files
Files changed (3) hide show
  1. demo.JPG +0 -0
  2. face_facts.py +155 -0
  3. utils.py +180 -0
demo.JPG ADDED
face_facts.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import tempfile
4
+ import time
5
+ import streamlit as st
6
+ from PIL import Image
7
+ from io import BytesIO
8
+ import torch
9
+ from ultralytics import YOLO
10
+ from utils import Ultimate_Lightning, age_lightning, gender_lightning, race_lightning
11
+ import pandas as pd
12
+ import mediapipe as mp
13
+ from torch.cuda import is_available as gpu_ready
14
+ from mediapipe.tasks import python
15
+ from mediapipe.tasks.python import vision
16
+
17
+ torch.manual_seed(42)
18
+
19
+ DEMO_IMAGE = 'demo.JPG'
20
+ GENDER_DICT = {
21
+ 1: 'Female',
22
+ 0: 'Male'
23
+ }
24
+ RACE_DICT = {
25
+ 0: 'White',
26
+ 1: 'Black',
27
+ 2: 'Asian',
28
+ 3: 'Indian',
29
+ 4: 'Others'
30
+ }
31
+ device = 'cuda' if gpu_ready() else 'cpu'
32
+
33
+ def load_face_detector():
34
+ base_options = python.BaseOptions(model_asset_path='models/detector.tflite')
35
+ options = vision.FaceDetectorOptions(base_options=base_options)
36
+ detector = vision.FaceDetector.create_from_options(options)
37
+ return detector
38
+
39
+ @st.cache_data
40
+ def load_model():
41
+ joint_model = Ultimate_Lightning()
42
+ age_model = age_lightning()
43
+ age_model.load_state_dict(torch.load('models/age_model.pth'))
44
+ race_model = race_lightning()
45
+ race_model.load_state_dict(torch.load('models/race_model.pth'))
46
+ gender_model = gender_lightning()
47
+ gender_model.load_state_dict(torch.load('models/gender_model.pth'))
48
+ joint_model.load_state_dict(torch.load('models/joint_model.pth'))
49
+
50
+ return age_model, gender_model, race_model, joint_model
51
+
52
+ st.set_page_config(page_title="Face-Facts")
53
+ st.title('Face-Facts')
54
+
55
+ app_mode = st.sidebar.selectbox('Choose Page', ['About the App', 'Run Face Facts'])
56
+
57
+ st.markdown(
58
+ """
59
+ <style>
60
+ [data-testid = 'stSidebar'][aria-expanded = 'true'] > div:first-child{
61
+ width: 350px
62
+ }
63
+ [data-testid = 'stSidebar'][aria-expanded = 'false'] > div:first-child{
64
+ width: 350px
65
+ margin-left: -350px
66
+ }
67
+ </style>
68
+ """, unsafe_allow_html = True
69
+ )
70
+
71
+ if app_mode == 'About the App':
72
+ st.markdown('')
73
+
74
+ elif app_mode == 'Run Face Facts':
75
+
76
+ age_model, gender_model, race_model, joint_model = load_model()
77
+ detector = load_face_detector()
78
+
79
+ st.sidebar.markdown('---')
80
+ use_single_model = st.sidebar.checkbox('Use single model', value = False)
81
+
82
+ kpi1, age_col, gender_col, race_col, kpi5 = st.columns(5)
83
+ with age_col:
84
+ st.markdown('**Age**')
85
+ age_text = st.markdown('0')
86
+ with gender_col:
87
+ st.markdown('**Gender**')
88
+ gender_text = st.markdown('0')
89
+ with race_col:
90
+ st.markdown('**Race**')
91
+ race_text = st.markdown('0')
92
+
93
+ img_file_buffer = st.sidebar.file_uploader('Upload an Image', type = ['jpg', 'png', 'jpeg'])
94
+ if img_file_buffer:
95
+ buffer = BytesIO(img_file_buffer.read())
96
+ data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
97
+ image_orig = cv2.imdecode(data, cv2.IMREAD_COLOR)
98
+ else:
99
+ demo_image = DEMO_IMAGE
100
+ image_orig = cv2.imread(demo_image, cv2.IMREAD_COLOR)
101
+
102
+ st.sidebar.text('Original Image')
103
+ st.sidebar.image(image_orig, channels = 'BGR')
104
+
105
+ image = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
106
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
107
+ detection_result = detector.detect(image)
108
+ if detection_result.detections != []:
109
+ res = detection_result.detections[0].bounding_box
110
+ x, y, w, h = res.origin_x, res.origin_y, res.width, res.height
111
+ image = image.numpy_view()[y:y+h, x:x+w]
112
+ image = cv2.resize(image, (200, 200))
113
+ image = torch.from_numpy(image).permute(2, 0, 1) / 255.
114
+ age_model.eval()
115
+ gender_model.eval()
116
+ race_model.eval()
117
+ joint_model.eval()
118
+ with torch.no_grad():
119
+ if use_single_model:
120
+ age_pred, gender_pred, race_pred = joint_model(image.unsqueeze(0))
121
+ else:
122
+ age_pred, gender_pred, race_pred = age_model(image.unsqueeze(0)), gender_model(image.unsqueeze(0)), race_model(image.unsqueeze(0))
123
+ age = int(age_pred.item())
124
+ gender = GENDER_DICT[int(gender_pred.item() > 0.5)]
125
+ race = RACE_DICT[race_pred.argmax(dim = 1).item()]
126
+ gender_emoji = '♂️' if gender == 'Male' else '♀️'
127
+ gender_color = 'blue' if gender == 'Male' else 'pink'
128
+ else:
129
+ age = '-'
130
+ gender = '-'
131
+ gender_emoji = '-'
132
+ gender_color = ''
133
+ race = '-'
134
+ st.error('No face detected in the image')
135
+ age_text.write(f'<h1> {age} </h1>', unsafe_allow_html = True)
136
+ gender_text.write(f"<h1 style='color: {gender_color};'>{gender_emoji}</h1>", unsafe_allow_html=True)
137
+ race_text.write(f"<h1> {race} </h1>", unsafe_allow_html = True)
138
+
139
+ st.markdown('---')
140
+
141
+ if not age == '-':
142
+ with st.expander('🔻 More Details 🔻'):
143
+ gender_decimal = gender_pred.item() if gender == 'Female' else abs(1 - gender_pred.item())
144
+ gender_precentage = f'{100 * gender_decimal:.2f}%'
145
+ st.write('---')
146
+ cols = st.columns(2)
147
+ with cols[0]:
148
+ st.markdown("<h3 style='color: gray;'>Face of Interest</h3>", unsafe_allow_html=True)
149
+ st.image(image.permute(1, 2, 0).numpy(), channels = 'RGB', use_column_width = True)
150
+ with cols[1]:
151
+ st.write(f"<h3 style = 'color: {gender_color};'> {gender}; {gender_precentage} Probability </h3>", unsafe_allow_html = True)
152
+ st.progress(gender_decimal)
153
+ st.write('---')
154
+ st.bar_chart(pd.DataFrame({'Probability': race_pred.squeeze().numpy(), 'Race': list(RACE_DICT.values())}), x = 'Race', y = 'Probability')
155
+ st.write('---')
utils.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ import torch.nn.functional as F
5
+ class CustomModelMain(nn.Module):
6
+ def __init__(self, problem_type, n_classes):
7
+ super().__init__()
8
+ if problem_type == 'Classification' and n_classes == 1:
9
+ output = nn.Sigmoid()
10
+ elif problem_type == 'Regression' and n_classes == 1:
11
+ output = nn.ReLU()
12
+ elif problem_type == 'Classification' and n_classes > 1:
13
+ output = nn.Softmax(dim = 1)
14
+ self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 32, kernel_size = 3, padding = 1)
15
+ self.pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2)
16
+ self.conv2 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, padding = 1)
17
+ self.pool2 = nn.MaxPool2d(kernel_size = 2, stride = 2)
18
+ self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1)
19
+ self.pool3 = nn.MaxPool2d(kernel_size = 2, stride = 2)
20
+ self.flatten = nn.Flatten()
21
+ self.fc1 = nn.Linear(64 * 25 * 25 , 128)
22
+ self.relu = nn.ReLU()
23
+ self.dropout = nn.Dropout(p = 0.5)
24
+ self.fc2 = nn.Linear(128, n_classes)
25
+ self.output = output
26
+ def forward(self, x):
27
+ x = self.conv1(x)
28
+ x = self.pool1(x)
29
+ x = self.relu(x)
30
+ x = self.conv2(x)
31
+ x = self.pool2(x)
32
+ x = self.relu(x)
33
+ x = self.conv3(x)
34
+ x = self.pool3(x)
35
+ x = self.relu(x)
36
+ x = self.flatten(x)
37
+ x = self.fc1(x)
38
+ x = self.relu(x)
39
+ x = self.dropout(x)
40
+ x = self.fc2(x)
41
+ x = self.output(x)
42
+ return x
43
+ class age_lightning(pl.LightningModule):
44
+ def __init__(self):
45
+ super().__init__()
46
+ self.model = CustomModelMain('Regression', 1)
47
+ def forward(self, x):
48
+ return self.model(x)
49
+ def training_step(self, batch, batch_idx):
50
+ x, y = batch
51
+ y = y[:, 0]
52
+ y_hat = self(x)
53
+ loss = F.mse_loss(y_hat, y.unsqueeze(-1).float())
54
+ acc = torch.eq((y_hat > 0.5).int().to(torch.int64), y.unsqueeze(-1).int()).all(dim=1).sum() / len(y)
55
+ self.log('train loss', loss, prog_bar = True)
56
+ return loss
57
+ def validation_step(self, batch, batch_idx):
58
+ x, y = batch
59
+ y_val = y[:, 0]
60
+ y_hat = self(x)
61
+ loss = F.mse_loss(y_hat, y_val.unsqueeze(-1).float())
62
+ self.log('valid loss', loss, prog_bar = True)
63
+ def configure_optimizers(self):
64
+ return torch.optim.Adam(self.parameters(), lr=1e-4)
65
+ class gender_lightning(pl.LightningModule):
66
+ def __init__(self):
67
+ super().__init__()
68
+ self.model = CustomModelMain('Classification', 1)
69
+ def forward(self, x):
70
+ return self.model(x)
71
+ def training_step(self, batch, batch_idx):
72
+ x, y = batch
73
+ y = y[:, 1]
74
+ y_hat = self(x)
75
+ loss = F.binary_cross_entropy(y_hat, y.unsqueeze(-1).float())
76
+ acc = torch.eq((y_hat > 0.5).int().to(torch.int64), y.unsqueeze(-1).int()).all(dim=1).sum() / len(y)
77
+ self.log('train loss', loss, prog_bar = True)
78
+ self.log('accuracy', acc, prog_bar = True)
79
+ return loss
80
+ def validation_step(self, batch, batch_idx):
81
+ x, y = batch
82
+ y_val = y[:, 1]
83
+ y_hat = self(x)
84
+ loss = F.binary_cross_entropy_with_logits(y_hat, y_val.unsqueeze(-1).float())
85
+ acc = torch.eq((y_hat > 0.5).int().to(torch.int64), y_val.unsqueeze(-1).int()).all(dim=1).sum() / len(y_val)
86
+ self.log('valid loss', loss, prog_bar = True)
87
+ self.log('val accuracy', acc, prog_bar = True)
88
+ def configure_optimizers(self):
89
+ return torch.optim.Adam(self.parameters(), lr=1e-4)
90
+ class race_lightning(pl.LightningModule):
91
+ def __init__(self):
92
+ super().__init__()
93
+ self.model = CustomModelMain('Classification', 5)
94
+ def forward(self, x):
95
+ return self.model(x)
96
+ def training_step(self, batch, batch_idx):
97
+ x, y = batch
98
+ y = y[:, 2]
99
+ y_hat = self(x)
100
+ y_oh = F.one_hot(y, num_classes = 5)
101
+ loss = F.cross_entropy(y_hat.log(), y_oh.float())
102
+ preds = y_hat.argmax(dim = 1)
103
+ acc = torch.eq(y, preds).float().mean()
104
+ self.log('train loss', loss, prog_bar = True)
105
+ self.log('accuracy', acc, prog_bar = True)
106
+ return loss
107
+ def validation_step(self, batch, batch_idx):
108
+ x, y = batch
109
+ y_val = y[:, 2]
110
+ y_hat = self(x)
111
+ y_oh = F.one_hot(y_val, num_classes = 5)
112
+ loss = F.cross_entropy(y_hat, y_oh.float())
113
+ preds = y_hat.argmax(dim = 1)
114
+ acc = torch.eq(y_val, preds).float().mean()
115
+ self.log('valid loss', loss, prog_bar = True)
116
+ self.log('val accuracy', acc, prog_bar = True)
117
+ def configure_optimizers(self):
118
+ return torch.optim.Adam(self.parameters(), lr=1e-4)
119
+ class Ultimate_Lightning(pl.LightningModule):
120
+ def __init__(self):
121
+ super().__init__()
122
+ self.age_model = CustomModelMain('Regression', 1)
123
+ self.gender_model = CustomModelMain('Classification', 1)
124
+ self.race_model = CustomModelMain('Classification', 5)
125
+ def forward(self, x):
126
+ return self.age_model(x), self.gender_model(x), self.race_model(x)
127
+ def training_step(self, batch, batch_idx):
128
+ x, y = batch
129
+ y_age, y_gender, y_race = y[:, 0], y[:, 1], y[:, 2]
130
+ y_hat_age, y_hat_gender, y_hat_race = self(x)
131
+
132
+ age_loss = F.mse_loss(y_hat_age, y_age.unsqueeze(-1).float())
133
+ age_acc = torch.eq((y_hat_age > 0.5).int().to(torch.int64), y_age.unsqueeze(-1).int()).all(dim=1).sum() / len(y_age)
134
+
135
+ gender_loss = F.binary_cross_entropy(y_hat_gender, y_gender.unsqueeze(-1).float())
136
+ gender_acc = torch.eq((y_hat_gender > 0.5).int().to(torch.int64), y_gender.unsqueeze(-1).int()).all(dim=1).sum() / len(y_gender)
137
+
138
+ y_race_oh = F.one_hot(y_race, num_classes = 5)
139
+ race_loss = F.cross_entropy(y_hat_race.log(), y_race_oh.float())
140
+ race_preds = y_hat_race.argmax(dim = 1)
141
+ race_acc = torch.eq(y_race, race_preds).float().mean()
142
+
143
+ total_loss = (0.001 * age_loss) + gender_loss + race_loss
144
+
145
+ self.log('age loss', age_loss, prog_bar = True)
146
+ self.log('gender loss', gender_loss, prog_bar = True)
147
+ self.log('race loss', race_loss, prog_bar = True)
148
+ self.log('gender acc', gender_acc, prog_bar = True)
149
+ self.log('race acc', race_acc, prog_bar = True)
150
+ self.log('total loss', total_loss, prog_bar = True)
151
+
152
+ return total_loss
153
+
154
+ def validation_step(self, batch, batch_idx):
155
+ x, y = batch
156
+ y_age, y_gender, y_race = y[:, 0], y[:, 1], y[:, 2]
157
+ y_hat_age, y_hat_gender, y_hat_race = self(x)
158
+
159
+ age_loss = F.mse_loss(y_hat_age, y_age.unsqueeze(-1).float())
160
+ age_acc = torch.eq((y_hat_age > 0.5).int().to(torch.int64), y_age.unsqueeze(-1).int()).all(dim=1).sum() / len(y_age)
161
+
162
+ gender_loss = F.binary_cross_entropy(y_hat_gender, y_gender.unsqueeze(-1).float())
163
+ gender_acc = torch.eq((y_hat_gender > 0.5).int().to(torch.int64), y_gender.unsqueeze(-1).int()).all(dim=1).sum() / len(y_gender)
164
+
165
+ y_race_oh = F.one_hot(y_race, num_classes = 5)
166
+ race_loss = F.cross_entropy(y_hat_race.log(), y_race_oh.float())
167
+ race_preds = y_hat_race.argmax(dim = 1)
168
+ race_acc = torch.eq(y_race, race_preds).float().mean()
169
+
170
+ total_loss = (0.001 * age_loss) + gender_loss + race_loss
171
+
172
+ self.log('val age loss', age_loss, prog_bar = True)
173
+
174
+ self.log('val gender acc', gender_acc, prog_bar = True)
175
+
176
+ self.log('val race acc', race_acc, prog_bar = True)
177
+
178
+ def configure_optimizers(self):
179
+ return torch.optim.Adam(self.parameters(), lr=1e-4)
180
+