tbvl22 commited on
Commit
fe81ecf
·
1 Parent(s): ad8bf42

Add application file

Browse files
Files changed (1) hide show
  1. PAR_gradio_app.py +693 -0
PAR_gradio_app.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from PIL import Image
8
+ import matplotlib
9
+ matplotlib.use('Agg')
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+ from torchvision import transforms
13
+ import gradio as gr
14
+ import logging
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # ========================================================
21
+ # MODEL ARCHITECTURE (Same as your training code)
22
+ # ========================================================
23
+
24
+ class EnhancedDifferentiableHistogram(nn.Module):
25
+ """Improved differentiable histogram with KDE-based binning"""
26
+ def __init__(self, bins=16, channels=3, min_val=0.0, max_val=1.0, bandwidth=0.05):
27
+ super().__init__()
28
+ self.bins = bins
29
+ self.channels = channels
30
+ self.min_val = min_val
31
+ self.max_val = max_val
32
+ self.bandwidth = bandwidth
33
+ self.bin_width = (max_val - min_val) / bins
34
+ self.bin_centers = nn.Parameter(
35
+ torch.linspace(min_val + self.bin_width/2, max_val - self.bin_width/2, bins),
36
+ requires_grad=False
37
+ )
38
+
39
+ def forward(self, x):
40
+ batch_size = x.size(0)
41
+ histograms = []
42
+ for c in range(self.channels):
43
+ channel_data = x[:, c].view(batch_size, -1, 1)
44
+ diff = (channel_data - self.bin_centers.view(1, 1, -1)) / self.bandwidth
45
+ kernel = torch.sigmoid(diff + 0.5) - torch.sigmoid(diff - 0.5)
46
+ hist = kernel.sum(dim=1)
47
+ hist = hist / (hist.sum(dim=1, keepdim=True) + 1e-6)
48
+ histograms.append(hist)
49
+ return torch.stack(histograms, dim=1)
50
+
51
+
52
+ class ColorConsistencyModule(nn.Module):
53
+ """Enhanced CSCCM with histogram losses"""
54
+ def __init__(self, feature_size, num_color_classes, hist_bins=16):
55
+ super().__init__()
56
+ self.hist_bins = hist_bins
57
+ self.hist_layer = EnhancedDifferentiableHistogram(bins=hist_bins)
58
+ self.hist_embed = nn.Sequential(
59
+ nn.Linear(3 * hist_bins, 128),
60
+ nn.ReLU(),
61
+ nn.Linear(128, 64)
62
+ )
63
+ self.top_fusion = nn.Linear(feature_size + 64, feature_size)
64
+ self.mid_fusion = nn.Linear(feature_size + 64, feature_size)
65
+ self.bottom_fusion = nn.Linear(feature_size + 64, feature_size)
66
+ self.upper_color_refine = nn.Sequential(
67
+ nn.Linear(feature_size, feature_size//2),
68
+ nn.ReLU(),
69
+ nn.Linear(feature_size//2, num_color_classes)
70
+ )
71
+ self.lower_color_refine = nn.Sequential(
72
+ nn.Linear(feature_size, feature_size//2),
73
+ nn.ReLU(),
74
+ nn.Linear(feature_size//2, num_color_classes)
75
+ )
76
+
77
+ def forward(self, top_feat, mid_feat, bot_feat, full_image):
78
+ hist = self.hist_layer(full_image)
79
+ hist_embed = self.hist_embed(hist.view(hist.size(0), -1))
80
+ top_fused = F.relu(self.top_fusion(torch.cat([top_feat, hist_embed], dim=1)))
81
+ mid_fused = F.relu(self.mid_fusion(torch.cat([mid_feat, hist_embed], dim=1)))
82
+ bot_fused = F.relu(self.bottom_fusion(torch.cat([bot_feat, hist_embed], dim=1)))
83
+ upper_color_refined = self.upper_color_refine(mid_fused)
84
+ lower_color_refined = self.lower_color_refine(bot_fused)
85
+ return top_fused, mid_fused, bot_fused, upper_color_refined, lower_color_refined, hist
86
+
87
+
88
+ class Bottleneck(nn.Module):
89
+ """Bottleneck block for ResNet-50"""
90
+ expansion = 4
91
+
92
+ def __init__(self, in_channels, out_channels, stride=1, downsample=None):
93
+ super().__init__()
94
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
95
+ self.bn1 = nn.BatchNorm2d(out_channels)
96
+ self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, 1, bias=False)
97
+ self.bn2 = nn.BatchNorm2d(out_channels)
98
+ self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, 1, bias=False)
99
+ self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)
100
+ self.relu = nn.ReLU(inplace=True)
101
+ self.downsample = downsample
102
+
103
+ def forward(self, x):
104
+ identity = x
105
+ out = self.conv1(x)
106
+ out = self.bn1(out)
107
+ out = self.relu(out)
108
+ out = self.conv2(out)
109
+ out = self.bn2(out)
110
+ out = self.relu(out)
111
+ out = self.conv3(out)
112
+ out = self.bn3(out)
113
+ if self.downsample:
114
+ identity = self.downsample(x)
115
+ out += identity
116
+ return self.relu(out)
117
+
118
+
119
+ class ChannelAttention(nn.Module):
120
+ """Channel Attention Module (CBAM)"""
121
+ def __init__(self, in_channels, reduction=16):
122
+ super().__init__()
123
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
124
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
125
+ self.fc = nn.Sequential(
126
+ nn.Linear(in_channels, in_channels // reduction),
127
+ nn.ReLU(inplace=True),
128
+ nn.Linear(in_channels // reduction, in_channels),
129
+ nn.Sigmoid()
130
+ )
131
+
132
+ def forward(self, x):
133
+ b, c, _, _ = x.size()
134
+ avg_out = self.fc(self.avg_pool(x).view(b, c))
135
+ max_out = self.fc(self.max_pool(x).view(b, c))
136
+ out = avg_out + max_out
137
+ return torch.sigmoid(out).view(b, c, 1, 1) * x
138
+
139
+
140
+ class SpatialAttention(nn.Module):
141
+ """Spatial Attention Module (CBAM)"""
142
+ def __init__(self, kernel_size=7):
143
+ super().__init__()
144
+ self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
145
+ self.sigmoid = nn.Sigmoid()
146
+
147
+ def forward(self, x):
148
+ avg_out = torch.mean(x, dim=1, keepdim=True)
149
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
150
+ combined = torch.cat([avg_out, max_out], dim=1)
151
+ attention = self.conv(combined)
152
+ return self.sigmoid(attention) * x
153
+
154
+
155
+ class CustomResNet(nn.Module):
156
+ """Enhanced ResNet-50"""
157
+ def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], in_channels=3):
158
+ super().__init__()
159
+ self.in_channels = 64
160
+ self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
161
+ self.bn1 = nn.BatchNorm2d(64)
162
+ self.relu = nn.ReLU(inplace=True)
163
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
164
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
165
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
166
+ self.attn2 = ChannelAttention(128 * block.expansion)
167
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
168
+ self.attn3 = SpatialAttention()
169
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
170
+
171
+ def _make_layer(self, block, out_channels, blocks, stride=1):
172
+ downsample = None
173
+ if stride != 1 or self.in_channels != out_channels * block.expansion:
174
+ downsample = nn.Sequential(
175
+ nn.Conv2d(self.in_channels, out_channels * block.expansion,
176
+ kernel_size=1, stride=stride, bias=False),
177
+ nn.BatchNorm2d(out_channels * block.expansion)
178
+ )
179
+ layers = []
180
+ layers.append(block(self.in_channels, out_channels, stride, downsample))
181
+ self.in_channels = out_channels * block.expansion
182
+ for _ in range(1, blocks):
183
+ layers.append(block(self.in_channels, out_channels))
184
+ return nn.Sequential(*layers)
185
+
186
+ def forward(self, x):
187
+ x = self.conv1(x)
188
+ x = self.bn1(x)
189
+ x = self.relu(x)
190
+ x = self.maxpool(x)
191
+ x = self.layer1(x)
192
+ x = self.layer2(x)
193
+ x = self.attn2(x)
194
+ x = self.layer3(x)
195
+ x = self.attn3(x)
196
+ x = self.layer4(x)
197
+ return x
198
+
199
+
200
+ class PARModel(nn.Module):
201
+ """Enhanced Pedestrian Attribute Recognition Model"""
202
+ def __init__(self, num_color_classes=11):
203
+ super().__init__()
204
+ self.top_cnn = CustomResNet(block=Bottleneck, layers=[3, 4, 6, 3])
205
+ self.middle_cnn = CustomResNet(block=Bottleneck, layers=[3, 4, 6, 3])
206
+ self.bottom_cnn = CustomResNet(block=Bottleneck, layers=[3, 4, 6, 3])
207
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
208
+ feature_size = 512 * Bottleneck.expansion
209
+ self.dropout = nn.Dropout(0.5)
210
+ self.gender_weights = nn.Parameter(torch.ones(3))
211
+ self.bag_weights = nn.Parameter(torch.ones(2))
212
+ self.color_consistency = ColorConsistencyModule(feature_size, num_color_classes)
213
+
214
+ # Fast path layers
215
+ self.hat_layer_fast = nn.Linear(feature_size, 1)
216
+ self.gender_top_layer_fast = nn.Linear(feature_size, 1)
217
+ self.upper_color_layer_fast = nn.Sequential(
218
+ nn.Linear(feature_size, 512),
219
+ nn.ReLU(),
220
+ nn.Dropout(0.4),
221
+ nn.Linear(512, num_color_classes)
222
+ )
223
+ self.bag_mid_layer_fast = nn.Linear(feature_size, 1)
224
+ self.gender_mid_layer_fast = nn.Linear(feature_size, 1)
225
+ self.lower_color_layer_fast = nn.Sequential(
226
+ nn.Linear(feature_size, 512),
227
+ nn.ReLU(),
228
+ nn.Dropout(0.4),
229
+ nn.Linear(512, num_color_classes)
230
+ )
231
+ self.bag_bot_layer_fast = nn.Linear(feature_size, 1)
232
+ self.gender_bot_layer_fast = nn.Linear(feature_size, 1)
233
+
234
+ # Shared refinement
235
+ self.shared_binary_refine_base = nn.Sequential(
236
+ nn.Linear(feature_size, 256),
237
+ nn.ReLU()
238
+ )
239
+ self.shared_binary_refine_hat = nn.Linear(256, 1)
240
+ self.shared_binary_refine_bag_mid = nn.Linear(256, 1)
241
+ self.shared_binary_refine_bag_bot = nn.Linear(256, 1)
242
+ self.shared_binary_refine_gender_top = nn.Linear(256, 1)
243
+ self.shared_binary_refine_gender_mid = nn.Linear(256, 1)
244
+ self.shared_binary_refine_gender_bot = nn.Linear(256, 1)
245
+
246
+ def forward(self, top, middle, bottom, full_image):
247
+ top_feat = self.top_cnn(top)
248
+ mid_feat = self.middle_cnn(middle)
249
+ bot_feat = self.bottom_cnn(bottom)
250
+
251
+ top_feat = self.pool(top_feat).view(top.size(0), -1)
252
+ mid_feat = self.pool(mid_feat).view(middle.size(0), -1)
253
+ bot_feat = self.pool(bot_feat).view(bottom.size(0), -1)
254
+
255
+ (top_feat, mid_feat, bot_feat,
256
+ upper_color_refined, lower_color_refined,
257
+ full_hist) = self.color_consistency(
258
+ top_feat, mid_feat, bot_feat, full_image
259
+ )
260
+
261
+ top_feat = self.dropout(top_feat)
262
+ mid_feat = self.dropout(mid_feat)
263
+ bot_feat = self.dropout(bot_feat)
264
+
265
+ outputs = {'full_hist': full_hist}
266
+
267
+ # TOP STREAM
268
+ hat_fast = self.hat_layer_fast(top_feat).squeeze(1)
269
+ gender_top_fast = self.gender_top_layer_fast(top_feat).squeeze(1)
270
+ top_base = self.shared_binary_refine_base(top_feat)
271
+ hat_refine = self.shared_binary_refine_hat(top_base).squeeze(1)
272
+ gender_top_refine = self.shared_binary_refine_gender_top(top_base).squeeze(1)
273
+ hat_pred = hat_fast + hat_refine
274
+ gender_top = gender_top_fast + gender_top_refine
275
+ outputs['hat'] = hat_pred
276
+ outputs['gender_top'] = gender_top
277
+
278
+ # MIDDLE STREAM
279
+ bag_mid_fast = self.bag_mid_layer_fast(mid_feat).squeeze(1)
280
+ upper_color_fast = self.upper_color_layer_fast(mid_feat)
281
+ gender_mid_fast = self.gender_mid_layer_fast(mid_feat).squeeze(1)
282
+ mid_base = self.shared_binary_refine_base(mid_feat)
283
+ bag_mid_refine = self.shared_binary_refine_bag_mid(mid_base).squeeze(1)
284
+ gender_mid_refine = self.shared_binary_refine_gender_mid(mid_base).squeeze(1)
285
+ bag_mid_pred = bag_mid_fast + bag_mid_refine
286
+ upper_color = upper_color_fast + upper_color_refined
287
+ gender_mid = gender_mid_fast + gender_mid_refine
288
+ outputs['bag_mid'] = bag_mid_pred
289
+ outputs['upper_color'] = upper_color
290
+ outputs['gender_mid'] = gender_mid
291
+
292
+ # BOTTOM STREAM
293
+ bag_bot_fast = self.bag_bot_layer_fast(bot_feat).squeeze(1)
294
+ lower_color_fast = self.lower_color_layer_fast(bot_feat)
295
+ gender_bot_fast = self.gender_bot_layer_fast(bot_feat).squeeze(1)
296
+ bot_base = self.shared_binary_refine_base(bot_feat)
297
+ bag_bot_refine = self.shared_binary_refine_bag_bot(bot_base).squeeze(1)
298
+ gender_bot_refine = self.shared_binary_refine_gender_bot(bot_base).squeeze(1)
299
+ bag_bot_pred = bag_bot_fast + bag_bot_refine
300
+ lower_color = lower_color_fast + lower_color_refined
301
+ gender_bot = gender_bot_fast + gender_bot_refine
302
+ outputs['bag_bot'] = bag_bot_pred
303
+ outputs['lower_color'] = lower_color
304
+ outputs['gender_bot'] = gender_bot
305
+
306
+ # Combine predictions
307
+ gender_weights = torch.softmax(self.gender_weights, dim=0)
308
+ gender = (outputs['gender_top'] * gender_weights[0] +
309
+ outputs['gender_mid'] * gender_weights[1] +
310
+ outputs['gender_bot'] * gender_weights[2])
311
+
312
+ bag_weights = torch.softmax(self.bag_weights, dim=0)
313
+ bag = (outputs['bag_mid'] * bag_weights[0] +
314
+ outputs['bag_bot'] * bag_weights[1])
315
+
316
+ return (
317
+ outputs['hat'],
318
+ outputs['upper_color'],
319
+ outputs['lower_color'],
320
+ gender,
321
+ bag,
322
+ outputs['gender_top'],
323
+ outputs['gender_mid'],
324
+ outputs['gender_bot'],
325
+ outputs['bag_mid'],
326
+ outputs['bag_bot'],
327
+ outputs['full_hist']
328
+ )
329
+
330
+
331
+ # ========================================================
332
+ # CONFIGURATION
333
+ # ========================================================
334
+
335
+ CHECKPOINT_PATH = "checkpoint.pth"
336
+ IMG_SIZE = (224, 224)
337
+ ATTRIBUTE_THRESHOLDS = {'hat': 0.5, 'gender': 0.5, 'bag': 0.5}
338
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
339
+
340
+ COLOR_MAP = {
341
+ 1: "Black", 2: "Blue", 3: "Brown", 4: "Gray", 5: "Green",
342
+ 6: "Orange", 7: "Pink", 8: "Purple", 9: "Red", 10: "White", 11: "Yellow"
343
+ }
344
+
345
+ # Define transforms
346
+ val_transform = transforms.Compose([
347
+ transforms.Resize(IMG_SIZE),
348
+ transforms.ToTensor(),
349
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
350
+ ])
351
+
352
+ # Global model variable
353
+ model = None
354
+
355
+ # Create examples directory
356
+ EXAMPLES_DIR = "examples"
357
+ os.makedirs(EXAMPLES_DIR, exist_ok=True)
358
+
359
+
360
+ # ========================================================
361
+ # HELPER FUNCTIONS
362
+ # ========================================================
363
+
364
+ def load_model():
365
+ """Load the trained model"""
366
+ global model
367
+ try:
368
+ model = PARModel().to(DEVICE)
369
+ if os.path.exists(CHECKPOINT_PATH):
370
+ checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
371
+ model_state_dict = model.state_dict()
372
+ pretrained_dict = {
373
+ k: v for k, v in checkpoint['model_state_dict'].items()
374
+ if k in model_state_dict and v.size() == model_state_dict[k].size()
375
+ }
376
+ model_state_dict.update(pretrained_dict)
377
+ model.load_state_dict(model_state_dict)
378
+ model.eval()
379
+ logger.info("Model loaded successfully!")
380
+ return True
381
+ else:
382
+ logger.error(f"Checkpoint file not found: {CHECKPOINT_PATH}")
383
+ return False
384
+ except Exception as e:
385
+ logger.error(f"Error loading model: {str(e)}")
386
+ return False
387
+
388
+
389
+ def create_visualization(orig_img, predictions):
390
+ """Create enhanced visualization with predictions overlaid on image - COMPACT VERSION"""
391
+ try:
392
+ # Get original image dimensions
393
+ width, height = orig_img.size
394
+ aspect_ratio = height / width
395
+
396
+ # Create smaller figure for better fit - REDUCED SIZE
397
+ fig_width = 6 # Reduced from 8
398
+ fig_height = fig_width * aspect_ratio
399
+
400
+ # Limit maximum height to prevent overflow
401
+ if fig_height > 10:
402
+ fig_height = 10
403
+ fig_width = fig_height / aspect_ratio
404
+
405
+ fig, ax = plt.subplots(figsize=(fig_width, fig_height), dpi=80) # Reduced DPI
406
+ ax.imshow(orig_img)
407
+
408
+ # Add region boundaries with thinner lines
409
+ top_rect = patches.Rectangle(
410
+ (0, 0), width, height*0.2,
411
+ linewidth=1.5, edgecolor='#00f5ff', facecolor='none', alpha=0.8
412
+ )
413
+ mid_rect = patches.Rectangle(
414
+ (0, height*0.2), width, height*0.4,
415
+ linewidth=1.5, edgecolor='#39ff14', facecolor='none', alpha=0.8
416
+ )
417
+ bot_rect = patches.Rectangle(
418
+ (0, height*0.6), width, height*0.4,
419
+ linewidth=1.5, edgecolor='#ff006e', facecolor='none', alpha=0.8
420
+ )
421
+
422
+ ax.add_patch(top_rect)
423
+ ax.add_patch(mid_rect)
424
+ ax.add_patch(bot_rect)
425
+
426
+ # Smaller text for predictions
427
+ text_lines = [
428
+ f"Hat: {predictions['hat']['label']} ({predictions['hat']['confidence']:.1%})",
429
+ f"Gender: {predictions['gender']['label']} ({predictions['gender']['confidence']:.1%})",
430
+ f"Bag: {predictions['bag']['label']} ({predictions['bag']['confidence']:.1%})",
431
+ f"Upper: {predictions['upper_color']['label']}",
432
+ f"Lower: {predictions['lower_color']['label']}"
433
+ ]
434
+
435
+ ax.text(
436
+ 0.02, 0.02,
437
+ "\n".join(text_lines),
438
+ transform=ax.transAxes,
439
+ fontsize=9, # Reduced from 11
440
+ fontweight='bold',
441
+ verticalalignment='bottom',
442
+ bbox=dict(
443
+ boxstyle="round,pad=0.3",
444
+ facecolor='black',
445
+ edgecolor='#ff006e',
446
+ alpha=0.9
447
+ ),
448
+ color='white'
449
+ )
450
+
451
+ # Smaller region labels
452
+ region_labels = [
453
+ (0.98, 0.9, "Top\n(Hat)", '#00f5ff'),
454
+ (0.98, 0.5, "Middle\n(Color/Bag)", '#39ff14'),
455
+ (0.98, 0.2, "Bottom\n(Color)", '#ff006e')
456
+ ]
457
+
458
+ for x, y, label, color in region_labels:
459
+ ax.text(
460
+ x, y,
461
+ label,
462
+ transform=ax.transAxes,
463
+ fontsize=7, # Reduced from 9
464
+ fontweight='bold',
465
+ horizontalalignment='right',
466
+ verticalalignment='center',
467
+ bbox=dict(
468
+ boxstyle="round,pad=0.2",
469
+ facecolor='black',
470
+ alpha=0.8,
471
+ edgecolor=color
472
+ ),
473
+ color=color
474
+ )
475
+
476
+ ax.axis('off')
477
+ plt.tight_layout(pad=0)
478
+
479
+ # Convert to image with lower DPI
480
+ buf = io.BytesIO()
481
+ plt.savefig(buf, format='png', bbox_inches='tight', dpi=80, facecolor='black', pad_inches=0.05)
482
+ buf.seek(0)
483
+ result_img = Image.open(buf).copy()
484
+ plt.close(fig)
485
+
486
+ return result_img
487
+ except Exception as e:
488
+ logger.error(f"Error creating visualization: {str(e)}")
489
+ return None
490
+
491
+
492
+ def predict(image):
493
+ """Process image and return predictions with visualization"""
494
+ try:
495
+ if image is None:
496
+ return None, "Please upload an image!"
497
+
498
+ # Convert to PIL Image if needed
499
+ if not isinstance(image, Image.Image):
500
+ orig_img = Image.fromarray(image).convert('RGB')
501
+ else:
502
+ orig_img = image.convert('RGB')
503
+
504
+ # Transform image
505
+ img_tensor = val_transform(orig_img)
506
+
507
+ # Split into parts
508
+ H = img_tensor.shape[1]
509
+ top = img_tensor[:, :int(H*0.2), :]
510
+ middle = img_tensor[:, int(H*0.2):int(H*0.6), :]
511
+ bottom = img_tensor[:, int(H*0.6):, :]
512
+ full_image = img_tensor
513
+
514
+ # Add batch dimension and move to device
515
+ top = top.unsqueeze(0).to(DEVICE)
516
+ middle = middle.unsqueeze(0).to(DEVICE)
517
+ bottom = bottom.unsqueeze(0).to(DEVICE)
518
+ full_image = full_image.unsqueeze(0).to(DEVICE)
519
+
520
+ # Run model
521
+ with torch.no_grad():
522
+ (hat_pred, upper_color_pred, lower_color_pred,
523
+ gender_pred, bag_pred, _, _, _, _, _, _) = model(
524
+ top, middle, bottom, full_image
525
+ )
526
+
527
+ # Process predictions
528
+ hat_prob = torch.sigmoid(hat_pred).item()
529
+ hat_class = int(hat_prob > ATTRIBUTE_THRESHOLDS['hat'])
530
+ hat_label = "Yes" if hat_class == 1 else "No"
531
+
532
+ upper_color_class = upper_color_pred.argmax(1).item() + 1
533
+ upper_color_name = COLOR_MAP.get(upper_color_class, f"Unknown({upper_color_class})")
534
+
535
+ lower_color_class = lower_color_pred.argmax(1).item() + 1
536
+ lower_color_name = COLOR_MAP.get(lower_color_class, f"Unknown({lower_color_class})")
537
+
538
+ gender_prob = torch.sigmoid(gender_pred).item()
539
+ gender_class = int(gender_prob > ATTRIBUTE_THRESHOLDS['gender'])
540
+ gender_label = "Female" if gender_class == 1 else "Male"
541
+
542
+ bag_prob = torch.sigmoid(bag_pred).item()
543
+ bag_class = int(bag_prob > ATTRIBUTE_THRESHOLDS['bag'])
544
+ bag_label = "Yes" if bag_class == 1 else "No"
545
+
546
+ predictions = {
547
+ 'hat': {'label': hat_label, 'confidence': hat_prob},
548
+ 'gender': {'label': gender_label, 'confidence': gender_prob},
549
+ 'bag': {'label': bag_label, 'confidence': bag_prob},
550
+ 'upper_color': {'label': upper_color_name, 'class': upper_color_class},
551
+ 'lower_color': {'label': lower_color_name, 'class': lower_color_class}
552
+ }
553
+
554
+ # Create visualization
555
+ result_img = create_visualization(orig_img, predictions)
556
+
557
+ # Create text output
558
+ output_text = f"""
559
+ ## Pedestrian Attribute Recognition Results
560
+
561
+ ### Binary Attributes
562
+ - **Hat**: {hat_label} (Confidence: {hat_prob:.2%})
563
+ - **Gender**: {gender_label} (Confidence: {gender_prob:.2%})
564
+ - **Bag**: {bag_label} (Confidence: {bag_prob:.2%})
565
+
566
+ ### Color Attributes
567
+ - **Upper Body Color**: {upper_color_name}
568
+ - **Lower Body Color**: {lower_color_name}
569
+
570
+ ### Model Information
571
+ - Device: {DEVICE}
572
+ - Image Size: {IMG_SIZE}
573
+ """
574
+
575
+ return result_img, output_text
576
+
577
+ except Exception as e:
578
+ logger.error(f"Error processing image: {str(e)}")
579
+ return None, f"Error: {str(e)}"
580
+
581
+
582
+ def get_example_images():
583
+ """Get list of example images from examples directory"""
584
+ example_images = []
585
+ if os.path.exists(EXAMPLES_DIR):
586
+ for file in os.listdir(EXAMPLES_DIR):
587
+ if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
588
+ example_images.append(os.path.join(EXAMPLES_DIR, file))
589
+ return example_images if example_images else None
590
+
591
+
592
+ # ========================================================
593
+ # GRADIO INTERFACE
594
+ # ========================================================
595
+
596
+ # Load model on startup
597
+ logger.info("Starting Pedestrian Attribute Recognition App...")
598
+ logger.info(f"Using device: {DEVICE}")
599
+ if not load_model():
600
+ logger.error("Failed to load model. Please check the checkpoint path.")
601
+ raise Exception(f"Model checkpoint not found at: {CHECKPOINT_PATH}")
602
+
603
+ # Get example images
604
+ example_images = get_example_images()
605
+
606
+ # Create Gradio interface
607
+ with gr.Blocks(title="Pedestrian Attribute Recognition", theme=gr.themes.Soft()) as demo:
608
+ gr.Markdown(
609
+ """
610
+ # Pedestrian Attribute Recognition System
611
+
612
+ Upload an image of a pedestrian to analyze their attributes including:
613
+ - **Hat Detection** - Whether the person is wearing a hat
614
+ - **Gender Classification** - Male or Female
615
+ - **Bag Detection** - Whether the person is carrying a bag
616
+ - **Upper Body Color** - Color of upper clothing
617
+ - **Lower Body Color** - Color of lower clothing
618
+
619
+ The model uses a custom ResNet-50 architecture with attention mechanisms and color consistency modules.
620
+ """
621
+ )
622
+
623
+ with gr.Row():
624
+ with gr.Column(scale=1):
625
+ input_image = gr.Image(
626
+ label="Upload Pedestrian Image",
627
+ type="pil"
628
+ )
629
+ predict_btn = gr.Button("Analyze Attributes", variant="primary", size="lg")
630
+
631
+ # Add examples if available
632
+ if example_images:
633
+ gr.Examples(
634
+ examples=[[img] for img in example_images],
635
+ inputs=input_image,
636
+ label="Example Images"
637
+ )
638
+ else:
639
+ gr.Markdown(
640
+ """
641
+ **To add example images:**
642
+ 1. Create a folder named `examples` in the same directory as this script
643
+ 2. Add pedestrian images to the `examples` folder
644
+ 3. Restart the app
645
+ """
646
+ )
647
+
648
+ with gr.Column(scale=1):
649
+ output_image = gr.Image(
650
+ label="Annotated Result",
651
+ type="pil"
652
+ )
653
+ output_text = gr.Markdown(label="Predictions")
654
+
655
+ gr.Markdown(
656
+ """
657
+ ### About the Model
658
+
659
+ This system uses an enhanced Pedestrian Attribute Recognition (PAR) model with:
660
+ - **Three-stream ResNet-50** architecture for different body regions
661
+ - **CBAM Attention** mechanisms for improved feature extraction
662
+ - **Color Consistency Module** with differentiable histograms
663
+ - **Multi-task Learning** for simultaneous attribute prediction
664
+
665
+ **Regions Analyzed:**
666
+ - Top (0-20%): Hat detection
667
+ - Middle (20-60%): Upper color, gender, bag
668
+ - Bottom (60-100%): Lower color
669
+ """
670
+ )
671
+
672
+ # Connect the button
673
+ predict_btn.click(
674
+ fn=predict,
675
+ inputs=input_image,
676
+ outputs=[output_image, output_text]
677
+ )
678
+
679
+ # Also trigger on image upload
680
+ input_image.change(
681
+ fn=predict,
682
+ inputs=input_image,
683
+ outputs=[output_image, output_text]
684
+ )
685
+
686
+ # Launch the app
687
+ if __name__ == "__main__":
688
+ demo.launch(
689
+ server_name="0.0.0.0",
690
+ server_port=7860,
691
+ share=False,
692
+ show_error=True
693
+ )