AKMESSI committed on
Commit
43ae76e
·
verified ·
1 Parent(s): 8b3ba8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -303
app.py CHANGED
@@ -1,304 +1,304 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import models, transforms
5
- from PIL import Image
6
- import numpy as np
7
- import cv2
8
-
9
- # --- 1. CONFIGURATION & STYLING ---
10
- st.set_page_config(
11
- page_title="Aesthetix AI",
12
- page_icon="✨",
13
- layout="centered",
14
- initial_sidebar_state="collapsed"
15
- )
16
-
17
- # Custom CSS for Premium White/Clean Theme
18
- st.markdown("""
19
- <style>
20
- /* App Background */
21
- .stApp {
22
- background-color: #F8F9FB;
23
- font-family: 'Helvetica Neue', sans-serif;
24
- }
25
-
26
- /* Hide Streamlit Branding */
27
- #MainMenu {visibility: hidden;}
28
- header {visibility: hidden;}
29
- footer {visibility: hidden;}
30
-
31
- /* Main Content Card Style */
32
- .block-container {
33
- padding-top: 2rem;
34
- padding-bottom: 2rem;
35
- }
36
-
37
- /* Custom Headers */
38
- h1 {
39
- color: #1A1A1A;
40
- font-weight: 700;
41
- letter-spacing: -1px;
42
- text-align: center;
43
- padding-bottom: 10px;
44
- }
45
-
46
- p {
47
- color: #666666;
48
- }
49
-
50
- /* Styled Image Containers */
51
- div[data-testid="stImage"] {
52
- border-radius: 12px;
53
- overflow: hidden;
54
- box-shadow: 0 10px 20px rgba(0,0,0,0.05);
55
- transition: transform 0.3s ease;
56
- }
57
-
58
- /* Score Card */
59
- .score-card {
60
- background-color: #FFFFFF;
61
- padding: 30px;
62
- border-radius: 20px;
63
- box-shadow: 0 4px 15px rgba(0,0,0,0.05);
64
- text-align: center;
65
- border: 1px solid #EEEEEE;
66
- margin-top: 20px;
67
- }
68
-
69
- .score-value {
70
- font-size: 5rem;
71
- font-weight: 800;
72
- margin: 0;
73
- line-height: 1;
74
- }
75
-
76
- .score-label {
77
- font-size: 1.1rem;
78
- color: #888;
79
- font-weight: 500;
80
- text-transform: uppercase;
81
- letter-spacing: 2px;
82
- }
83
-
84
- /* Button Styling */
85
- .stButton > button {
86
- background: linear-gradient(90deg, #1A1A1A 0%, #333333 100%);
87
- color: white;
88
- border: none;
89
- padding: 12px 28px;
90
- border-radius: 50px;
91
- font-weight: 600;
92
- letter-spacing: 0.5px;
93
- width: 100%;
94
- transition: all 0.3s;
95
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
96
- }
97
-
98
- .stButton > button:hover {
99
- transform: translateY(-2px);
100
- box-shadow: 0 6px 12px rgba(0,0,0,0.15);
101
- background: #000000;
102
- }
103
-
104
- /* File Uploader */
105
- .stFileUploader {
106
- padding: 20px;
107
- background-color: #FFFFFF;
108
- border-radius: 15px;
109
- border: 1px dashed #DDDDDD;
110
- }
111
- </style>
112
- """, unsafe_allow_html=True)
113
-
114
- # Header
115
- st.markdown("<h1>✨ Aesthetix AI</h1>", unsafe_allow_html=True)
116
- st.markdown("<p style='text-align: center; margin-top: -15px; margin-bottom: 30px;'>Facial Symmetry & Feature Analysis Engine</p>", unsafe_allow_html=True)
117
-
118
- # --- 2. MODEL LOADING ---
119
- @st.cache_resource
120
- def load_models():
121
- device = torch.device("cpu")
122
-
123
- # Rating Model (ResNet18)
124
- rater = models.resnet18(weights=None)
125
- num_ftrs = rater.fc.in_features
126
- rater.fc = nn.Linear(num_ftrs, 1)
127
- try:
128
- rater.load_state_dict(torch.load("best_face_rater_colab.pth", map_location=device))
129
- except FileNotFoundError:
130
- st.error("⚠️ Model file missing. Upload 'best_face_rater_colab.pth'.")
131
- return None, None
132
- rater.eval()
133
-
134
- # Segmentation Model (DeepLabV3)
135
- seg_model = models.segmentation.deeplabv3_resnet50(weights='DEFAULT')
136
- seg_model.eval()
137
-
138
- return rater, seg_model
139
-
140
- rater_model, seg_model = load_models()
141
-
142
- # --- 3. PROCESSING LOGIC ---
143
- def isolate_face_pixels(image):
144
- # Prepare for DeepLabV3
145
- seg_transform = transforms.Compose([
146
- transforms.Resize(256),
147
- transforms.CenterCrop(224),
148
- transforms.ToTensor(),
149
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
150
- ])
151
- input_tensor = seg_transform(image).unsqueeze(0)
152
-
153
- with torch.no_grad():
154
- output = seg_model(input_tensor)['out'][0]
155
-
156
- output_predictions = output.argmax(0)
157
- # Class 15 is Person
158
- mask = (output_predictions == 15).byte().numpy()
159
-
160
- image_resized = image.resize((224, 224))
161
- img_np = np.array(image_resized)
162
-
163
- # Apply Mask (Black Background)
164
- mask_3d = np.stack([mask, mask, mask], axis=2)
165
- foreground = img_np * mask_3d
166
-
167
- return Image.fromarray(foreground)
168
-
169
- def crop_to_face_strict(image_pil):
170
- img_np = np.array(image_pil)
171
- if len(img_np.shape) == 2: img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
172
-
173
- # Haar Cascade
174
- face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
175
- gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
176
- faces = face_cascade.detectMultiScale(gray, 1.1, 4)
177
-
178
- if len(faces) == 0: return image_pil, False
179
-
180
- # Largest Face
181
- x, y, w, h = max(faces, key=lambda f: f[2] * f[3])
182
-
183
- # Margin logic
184
- margin = int(h * 0.20)
185
- x = max(0, x - margin)
186
- y = max(0, y - margin)
187
- w = min(img_np.shape[1] - x, w + 2*margin)
188
- h = min(img_np.shape[0] - y, h + 2*margin)
189
-
190
- return image_pil.crop((x, y, x+w, y+h)), True
191
-
192
- # Grad-CAM Setup
193
- gradients = None
194
- activations = None
195
- def backward_hook(module, grad_input, grad_output):
196
- global gradients
197
- gradients = grad_output[0]
198
- def forward_hook(module, input, output):
199
- global activations
200
- activations = output
201
-
202
- def generate_heatmap(model, input_tensor):
203
- target_layer = model.layer4[-1]
204
- handle_f = target_layer.register_forward_hook(forward_hook)
205
- handle_b = target_layer.register_full_backward_hook(backward_hook)
206
-
207
- output = model(input_tensor)
208
- model.zero_grad()
209
- output.backward()
210
-
211
- pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
212
- for i in range(512): activations[:, i, :, :] *= pooled_gradients[i]
213
-
214
- heatmap = torch.mean(activations, dim=1).squeeze()
215
- heatmap = np.maximum(heatmap.detach().numpy(), 0)
216
- if np.max(heatmap) > 0: heatmap /= np.max(heatmap)
217
-
218
- handle_f.remove(); handle_b.remove()
219
- return heatmap
220
-
221
- def overlay_heatmap(heatmap, original_image):
222
- heatmap = cv2.resize(heatmap, (original_image.width, original_image.height))
223
- heatmap = np.uint8(255 * heatmap)
224
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
225
- img_np = np.array(original_image)
226
- img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
227
- superimposed_img = heatmap * 0.4 + img_np
228
- return Image.fromarray(cv2.cvtColor(np.uint8(superimposed_img), cv2.COLOR_BGR2RGB))
229
-
230
- # --- 4. MAIN INTERFACE ---
231
-
232
- uploaded_file = st.file_uploader("Upload a clear portrait", type=["jpg", "jpeg", "png"])
233
-
234
- if uploaded_file is not None and rater_model:
235
- image = Image.open(uploaded_file).convert('RGB')
236
-
237
- # Processing Flow
238
- with st.spinner("Isolating facial geometry..."):
239
- cropped_img, found = crop_to_face_strict(image)
240
- final_input = isolate_face_pixels(cropped_img)
241
-
242
- # UI Columns
243
- col1, col2 = st.columns(2)
244
- with col1:
245
- st.image(image, caption='Original', use_container_width=True)
246
- with col2:
247
- st.image(final_input, caption='AI Analysis View', use_container_width=True)
248
-
249
- st.write("")
250
-
251
- if st.button('Calculate Score'):
252
- progress_bar = st.progress(0)
253
-
254
- # 1. Transform
255
- transform = transforms.Compose([
256
- transforms.Resize((224, 224)),
257
- transforms.ToTensor(),
258
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
259
- ])
260
- input_tensor = transform(final_input).unsqueeze(0)
261
- input_tensor.requires_grad = True
262
-
263
- progress_bar.progress(60)
264
-
265
- # 2. Score
266
- with torch.no_grad():
267
- output = rater_model(input_tensor)
268
- score = output.item()
269
-
270
- score = max(1.0, min(5.0, score))
271
-
272
- # 3. Heatmap (Visual Reasoning)
273
- heatmap = generate_heatmap(rater_model, input_tensor)
274
- overlay = overlay_heatmap(heatmap, final_input)
275
-
276
- progress_bar.progress(100)
277
-
278
- # --- RESULTS DISPLAY ---
279
- st.markdown("<br>", unsafe_allow_html=True)
280
-
281
- # Determine Color Code
282
- if score >= 4.0: score_color = "#4CAF50" # Green
283
- elif score >= 3.0: score_color = "#FF9800" # Orange
284
- else: score_color = "#F44336" # Red
285
-
286
- # Metric Card HTML
287
- st.markdown(f"""
288
- <div class="score-card">
289
- <p class="score-label">Aesthetic Rating</p>
290
- <h1 class="score-value" style="color: {score_color};">{score:.2f}</h1>
291
- <p style="margin-top: 10px; color: #666;">out of 5.0</p>
292
- </div>
293
- """, unsafe_allow_html=True)
294
-
295
- st.write("")
296
- st.image(overlay, caption='Feature Activation Map (Visual Reasoning)', use_container_width=True)
297
-
298
- if score >= 4.0:
299
- st.success("Exceptional features detected. High symmetry and proportion.")
300
- st.balloons()
301
- elif score >= 3.0:
302
- st.info("Strong features detected. Above average structure.")
303
- else:
304
  st.warning("Average structure detected. Lighting or angle may affect result.")
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import cv2
8
+
9
# --- 1. CONFIGURATION & STYLING ---
# Page-wide Streamlit settings; st.set_page_config must be the first
# Streamlit command executed in the script.
st.set_page_config(
    page_title="Aesthetix AI",
    page_icon="✨",
    layout="centered",
    initial_sidebar_state="collapsed"
)

# Custom CSS for Premium White/Clean Theme
# NOTE(review): selectors like .stApp and div[data-testid="stImage"] target
# Streamlit's internal DOM and may break across Streamlit versions — confirm
# against the deployed Streamlit release.
st.markdown("""
<style>
/* App Background */
.stApp {
background-color: #F8F9FB;
font-family: 'Helvetica Neue', sans-serif;
}

/* Hide Streamlit Branding */
#MainMenu {visibility: hidden;}
header {visibility: hidden;}
footer {visibility: hidden;}

/* Main Content Card Style */
.block-container {
padding-top: 2rem;
padding-bottom: 2rem;
}

/* Custom Headers */
h1 {
color: #1A1A1A;
font-weight: 700;
letter-spacing: -1px;
text-align: center;
padding-bottom: 10px;
}

p {
color: #666666;
}

/* Styled Image Containers */
div[data-testid="stImage"] {
border-radius: 12px;
overflow: hidden;
box-shadow: 0 10px 20px rgba(0,0,0,0.05);
transition: transform 0.3s ease;
}

/* Score Card */
.score-card {
background-color: #FFFFFF;
padding: 30px;
border-radius: 20px;
box-shadow: 0 4px 15px rgba(0,0,0,0.05);
text-align: center;
border: 1px solid #EEEEEE;
margin-top: 20px;
}

.score-value {
font-size: 5rem;
font-weight: 800;
margin: 0;
line-height: 1;
}

.score-label {
font-size: 1.1rem;
color: #888;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 2px;
}

/* Button Styling */
.stButton > button {
background: linear-gradient(90deg, #1A1A1A 0%, #333333 100%);
color: white;
border: none;
padding: 12px 28px;
border-radius: 50px;
font-weight: 600;
letter-spacing: 0.5px;
width: 100%;
transition: all 0.3s;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}

.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 12px rgba(0,0,0,0.15);
background: #000000;
}

/* File Uploader */
.stFileUploader {
padding: 20px;
background-color: #FFFFFF;
border-radius: 15px;
border: 1px dashed #DDDDDD;
}
</style>
""", unsafe_allow_html=True)

# Header (rendered as raw HTML so the CSS above can style it).
st.markdown("<h1>✨ Aesthetix AI</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align: center; margin-top: -15px; margin-bottom: 30px;'>Facial Symmetry & Feature Analysis Engine</p>", unsafe_allow_html=True)
117
+
118
# --- 2. MODEL LOADING ---
@st.cache_resource  # load once per server process; reused across script reruns
def load_models():
    """Load the face-rating regressor and the segmentation model (CPU only).

    Returns:
        tuple: (rater, seg_model) — a ResNet18 with a 1-output regression
        head and a pretrained DeepLabV3-ResNet50, both in eval mode.
        Returns (None, None) if the rater checkpoint file is missing.
    """
    device = torch.device("cpu")

    # Rating Model (ResNet18): replace the classifier head with a single
    # regression output (the aesthetic score).
    rater = models.resnet18(weights=None)
    num_ftrs = rater.fc.in_features
    rater.fc = nn.Linear(num_ftrs, 1)
    try:
        # NOTE(review): torch.load can execute arbitrary code from an
        # untrusted checkpoint; consider weights_only=True on torch >= 2.0.
        rater.load_state_dict(torch.load("best_face_rater_colab.pth", map_location=device))
    except FileNotFoundError:
        # Surface a friendly error in the UI instead of crashing the app.
        st.error("⚠️ Model file missing. Upload 'best_face_rater_colab.pth'.")
        return None, None
    rater.eval()

    # Segmentation Model (DeepLabV3) with pretrained default weights,
    # used to separate the person from the background.
    seg_model = models.segmentation.deeplabv3_resnet50(weights='DEFAULT')
    seg_model.eval()

    return rater, seg_model

# Module-level globals consumed by the processing functions and the UI below.
rater_model, seg_model = load_models()
141
+
142
# --- 3. PROCESSING LOGIC ---
def isolate_face_pixels(image):
    """Black out everything except the 'person' pixels in *image*.

    Runs DeepLabV3 semantic segmentation and keeps only pixels predicted as
    class 15 ('person' in the Pascal VOC label set), returning a 224x224 RGB
    image with a black background.

    Args:
        image: PIL.Image, assumed RGB (callers convert with .convert('RGB')).

    Returns:
        PIL.Image (224x224) with non-person pixels set to black.
    """
    # Prepare for DeepLabV3 (ImageNet mean/std normalization).
    seg_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = seg_transform(image).unsqueeze(0)

    with torch.no_grad():
        output = seg_model(input_tensor)['out'][0]

    output_predictions = output.argmax(0)
    # Class 15 is Person
    mask = (output_predictions == 15).byte().numpy()

    # BUG FIX: the mask is computed on a Resize(256)+CenterCrop(224) view of
    # the image, but it was previously applied to image.resize((224, 224)) —
    # an aspect-distorted, differently-framed geometry — so mask and pixels
    # were spatially misaligned. Reproduce the segmentation geometry exactly.
    geometry = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ])
    img_np = np.array(geometry(image))

    # Apply Mask (Black Background): uint8 image * {0,1} mask stays uint8.
    mask_3d = np.stack([mask, mask, mask], axis=2)
    foreground = img_np * mask_3d

    return Image.fromarray(foreground)
168
+
169
def crop_to_face_strict(image_pil):
    """Crop *image_pil* to the largest detected face plus a 20% margin.

    Returns:
        tuple: (cropped PIL.Image, True) when a face is found, otherwise
        (the original image unchanged, False).
    """
    pixels = np.array(image_pil)
    if len(pixels.shape) == 2:
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)

    # Haar Cascade frontal-face detector shipped with OpenCV.
    cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
    detector = cv2.CascadeClassifier(cascade_path)
    grayscale = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    detections = detector.detectMultiScale(grayscale, 1.1, 4)

    if len(detections) == 0:
        return image_pil, False

    # Keep only the largest detection (by bounding-box area).
    x, y, w, h = max(detections, key=lambda box: box[2] * box[3])

    # Expand the box by 20% of the face height on every side,
    # clamped to the image bounds.
    pad = int(h * 0.20)
    x = max(0, x - pad)
    y = max(0, y - pad)
    w = min(pixels.shape[1] - x, w + 2 * pad)
    h = min(pixels.shape[0] - y, h + 2 * pad)

    return image_pil.crop((x, y, x + w, y + h)), True
191
+
192
# Grad-CAM Setup
# Module-level state written by the hooks below and read by
# generate_heatmap(). NOTE(review): global state is not thread-safe;
# presumably fine because Streamlit executes the script serially per
# session — confirm before reusing elsewhere.
gradients = None    # gradient w.r.t. the hooked layer's output, set by backward_hook
activations = None  # forward feature maps of the hooked layer, set by forward_hook

def backward_hook(module, grad_input, grad_output):
    # Capture the gradient flowing out of the hooked layer.
    global gradients
    gradients = grad_output[0]

def forward_hook(module, input, output):
    # Capture the hooked layer's forward output (feature maps).
    global activations
    activations = output
201
+
202
def generate_heatmap(model, input_tensor):
    """Compute a Grad-CAM heatmap for *model*'s scalar output.

    Registers forward/backward hooks on the last block of ``model.layer4``
    (ResNet layout), backpropagates the single-element output, and weights
    the captured feature maps by their spatially-pooled gradients.

    Args:
        model: ResNet-style module exposing ``.layer4``.
        input_tensor: (1, 3, H, W) float tensor; gradients must be enabled
            upstream (the caller sets requires_grad) so backward() fires
            the hooks.

    Returns:
        2D numpy array normalized to [0, 1], one value per spatial cell of
        the layer4 feature map.
    """
    target_layer = model.layer4[-1]
    handle_f = target_layer.register_forward_hook(forward_hook)
    handle_b = target_layer.register_full_backward_hook(backward_hook)

    try:
        output = model(input_tensor)
        model.zero_grad()
        output.backward()  # single-element output, so no explicit grad needed

        # Channel importance: gradient averaged over batch and space.
        pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

        # FIX: weight the feature maps by broadcasting over channels instead
        # of the old `for i in range(512)` loop, which hard-coded the
        # resnet18/34 channel count and broke for other backbones.
        weighted = activations.detach() * pooled_gradients[None, :, None, None]

        heatmap = torch.mean(weighted, dim=1).squeeze()
        heatmap = np.maximum(heatmap.numpy(), 0)  # ReLU: keep positive influence
        if np.max(heatmap) > 0:
            heatmap /= np.max(heatmap)
    finally:
        # FIX: always remove the hooks, even if forward/backward raises,
        # so stale hooks don't accumulate on the model across reruns.
        handle_f.remove()
        handle_b.remove()

    return heatmap
220
+
221
def overlay_heatmap(heatmap, original_image):
    """Blend a Grad-CAM *heatmap* (2D array in [0, 1]) onto *original_image*.

    Args:
        heatmap: 2D float numpy array, values in [0, 1].
        original_image: PIL.Image (RGB) the heatmap is overlaid on.

    Returns:
        PIL.Image (RGB) of the same size as *original_image*.
    """
    heatmap = cv2.resize(heatmap, (original_image.width, original_image.height))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # produces BGR
    img_np = np.array(original_image)
    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    # BUG FIX: heatmap * 0.4 + img_np can exceed 255; the old direct
    # np.uint8 cast wrapped around (modulo 256), producing speckled color
    # artifacts in bright regions. Clip into the byte range before casting.
    superimposed_img = np.clip(heatmap * 0.4 + img_np, 0, 255).astype(np.uint8)
    return Image.fromarray(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
229
+
230
# --- 4. MAIN INTERFACE ---

uploaded_file = st.file_uploader("Upload a clear portrait", type=["jpg", "jpeg", "png"])

# Only proceed when a file was uploaded AND the rater checkpoint loaded
# (load_models returns None for the rater when the file is missing).
if uploaded_file is not None and rater_model:
    image = Image.open(uploaded_file).convert('RGB')

    # Processing Flow: detect + crop the face, then mask out the background.
    with st.spinner("Isolating facial geometry..."):
        cropped_img, found = crop_to_face_strict(image)
        final_input = isolate_face_pixels(cropped_img)

    # UI Columns: original next to the exact image the model will score.
    col1, col2 = st.columns(2)
    with col1:
        # FIX: use_column_width is deprecated in current Streamlit; use
        # use_container_width consistently (the heatmap image below
        # already used it — mixed usage breaks on one version or the other).
        st.image(image, caption='Original', use_container_width=True)
    with col2:
        st.image(final_input, caption='AI Analysis View', use_container_width=True)

    st.write("")

    if st.button('Calculate Score'):
        progress_bar = st.progress(0)

        # 1. Transform — same ImageNet normalization the rater expects.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        input_tensor = transform(final_input).unsqueeze(0)
        # Gradients are required later by the Grad-CAM backward pass.
        input_tensor.requires_grad = True

        progress_bar.progress(60)

        # 2. Score (no_grad: plain inference; Grad-CAM runs its own pass).
        with torch.no_grad():
            output = rater_model(input_tensor)
            score = output.item()

        # Clamp the raw regression output into the 1.0–5.0 display range.
        score = max(1.0, min(5.0, score))

        # 3. Heatmap (Visual Reasoning)
        heatmap = generate_heatmap(rater_model, input_tensor)
        overlay = overlay_heatmap(heatmap, final_input)

        progress_bar.progress(100)

        # --- RESULTS DISPLAY ---
        st.markdown("<br>", unsafe_allow_html=True)

        # Determine Color Code for the score card.
        if score >= 4.0:
            score_color = "#4CAF50"  # Green
        elif score >= 3.0:
            score_color = "#FF9800"  # Orange
        else:
            score_color = "#F44336"  # Red

        # Metric Card HTML (styled by the .score-card CSS injected above).
        st.markdown(f"""
<div class="score-card">
<p class="score-label">Aesthetic Rating</p>
<h1 class="score-value" style="color: {score_color};">{score:.2f}</h1>
<p style="margin-top: 10px; color: #666;">out of 5.0</p>
</div>
""", unsafe_allow_html=True)

        st.write("")
        st.image(overlay, caption='Feature Activation Map (Visual Reasoning)', use_container_width=True)

        # Verdict bands mirror the color-code thresholds above.
        if score >= 4.0:
            st.success("Exceptional features detected. High symmetry and proportion.")
            st.balloons()
        elif score >= 3.0:
            st.info("Strong features detected. Above average structure.")
        else:
            st.warning("Average structure detected. Lighting or angle may affect result.")