# PAR_demo / app.py
# (repository-page residue preserved as a comment: uploaded by "tbvl22",
#  commit 4f2e2c7 "Add application file")
import os
import io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import transforms
import gradio as gr
import logging
# Configure logging; this module-level logger is used throughout the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ========================================================
# MODEL ARCHITECTURE (Same as your training code)
# ========================================================
class EnhancedDifferentiableHistogram(nn.Module):
    """Soft, differentiable per-channel color histogram.

    Every pixel contributes to every bin through a difference of sigmoids
    centered on the bin, a KDE-style soft assignment that stays
    differentiable with respect to the input image.
    """

    def __init__(self, bins=16, channels=3, min_val=0.0, max_val=1.0, bandwidth=0.05):
        super().__init__()
        self.bins = bins
        self.channels = channels
        self.min_val = min_val
        self.max_val = max_val
        self.bandwidth = bandwidth
        self.bin_width = (max_val - min_val) / bins
        # Kept as a frozen nn.Parameter (not a buffer) so the state_dict key
        # matches checkpoints produced by the training code.
        self.bin_centers = nn.Parameter(
            torch.linspace(min_val + self.bin_width / 2, max_val - self.bin_width / 2, bins),
            requires_grad=False,
        )

    def forward(self, x):
        """Map a (B, C, H, W) image batch to (B, C, bins) normalized histograms."""
        batch = x.size(0)
        per_channel = []
        for ch in range(self.channels):
            # (B, H*W, 1) pixel values against (1, 1, bins) centers.
            flat = x[:, ch].reshape(batch, -1, 1)
            scaled = (flat - self.bin_centers.view(1, 1, -1)) / self.bandwidth
            weight = torch.sigmoid(scaled + 0.5) - torch.sigmoid(scaled - 0.5)
            counts = weight.sum(dim=1)
            # Normalize to a probability-like vector; epsilon guards empty mass.
            per_channel.append(counts / (counts.sum(dim=1, keepdim=True) + 1e-6))
        return torch.stack(per_channel, dim=1)
class ColorConsistencyModule(nn.Module):
    """Enhanced CSCCM: fuses a global color-histogram embedding into each
    stream's features and emits refined upper/lower color logits.
    """

    def __init__(self, feature_size, num_color_classes, hist_bins=16):
        super().__init__()
        self.hist_bins = hist_bins
        self.hist_layer = EnhancedDifferentiableHistogram(bins=hist_bins)
        # 3 channels x hist_bins soft counts -> 64-d embedding.
        self.hist_embed = nn.Sequential(
            nn.Linear(3 * hist_bins, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )
        # One fusion head per stream; creation order matches the checkpoint.
        self.top_fusion = nn.Linear(feature_size + 64, feature_size)
        self.mid_fusion = nn.Linear(feature_size + 64, feature_size)
        self.bottom_fusion = nn.Linear(feature_size + 64, feature_size)
        self.upper_color_refine = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ReLU(),
            nn.Linear(feature_size // 2, num_color_classes),
        )
        self.lower_color_refine = nn.Sequential(
            nn.Linear(feature_size, feature_size // 2),
            nn.ReLU(),
            nn.Linear(feature_size // 2, num_color_classes),
        )

    def forward(self, top_feat, mid_feat, bot_feat, full_image):
        """Return fused (top, mid, bot) features, refined upper/lower color
        logits, and the raw per-channel histogram of `full_image`."""
        hist = self.hist_layer(full_image)
        embedding = self.hist_embed(hist.view(hist.size(0), -1))

        def fuse(layer, feat):
            # Concatenate the stream feature with the histogram embedding.
            return F.relu(layer(torch.cat([feat, embedding], dim=1)))

        fused_top = fuse(self.top_fusion, top_feat)
        fused_mid = fuse(self.mid_fusion, mid_feat)
        fused_bot = fuse(self.bottom_fusion, bot_feat)
        return (
            fused_top,
            fused_mid,
            fused_bot,
            self.upper_color_refine(fused_mid),
            self.lower_color_refine(fused_bot),
            hist,
        )
class Bottleneck(nn.Module):
    """Bottleneck residual block for ResNet-50 (1x1 -> 3x3 -> 1x1)."""

    expansion = 4  # output channels = out_channels * expansion

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Only the 3x3 conv carries the stride (standard ResNet-50 layout).
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection matching the identity path to the output shape.
        self.downsample = downsample

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + shortcut
        return self.relu(out)
class ChannelAttention(nn.Module):
    """Channel Attention Module (CBAM-style squeeze over spatial dims)."""

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Bottleneck MLP shared by the avg- and max-pooled descriptors.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.shape[:2]
        gate = self.fc(self.avg_pool(x).view(batch, channels)) \
             + self.fc(self.max_pool(x).view(batch, channels))
        # NOTE(review): self.fc already ends in Sigmoid, so this second
        # sigmoid double-squashes the gate into roughly (0.5, 0.88). Kept
        # as-is to stay compatible with weights trained with this behavior.
        return torch.sigmoid(gate).view(batch, channels, 1, 1) * x
class SpatialAttention(nn.Module):
    """Spatial Attention Module (CBAM-style pooling over the channel dim)."""

    def __init__(self, kernel_size=7):
        super().__init__()
        # 2 input channels: channel-wise mean map and channel-wise max map.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([mean_map, max_map], dim=1)
        gate = self.sigmoid(self.conv(stacked))
        return gate * x
class CustomResNet(nn.Module):
    """Enhanced ResNet-50 backbone with CBAM attention after stages 2 and 3."""

    def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], in_channels=3):
        # NOTE: mutable default `layers` is never modified, so sharing is safe;
        # kept to preserve the original signature.
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Submodule creation order matches the training code so state_dict
        # keys line up with the checkpoint.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.attn2 = ChannelAttention(128 * block.expansion)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.attn3 = SpatialAttention()
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Stack `blocks` bottlenecks; only the first may stride/project."""
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        stage = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion
        stage.extend(block(self.in_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.attn2(self.layer2(x))
        x = self.attn3(self.layer3(x))
        return self.layer4(x)
class PARModel(nn.Module):
    """Enhanced Pedestrian Attribute Recognition Model.

    Three independent ResNet-50 streams encode the top / middle / bottom
    crops of a pedestrian image. The ColorConsistencyModule fuses a global
    color-histogram embedding into each stream's pooled features; each
    attribute logit is then the sum of a "fast" linear head and a shared
    refinement head, and the per-stream gender/bag logits are blended with
    learned softmax weights. Submodule names must stay in sync with the
    training checkpoint's state_dict keys — do not rename attributes.
    """
    def __init__(self, num_color_classes=11):
        super().__init__()
        # One backbone per body region (no weight sharing across streams).
        self.top_cnn = CustomResNet(block=Bottleneck, layers=[3, 4, 6, 3])
        self.middle_cnn = CustomResNet(block=Bottleneck, layers=[3, 4, 6, 3])
        self.bottom_cnn = CustomResNet(block=Bottleneck, layers=[3, 4, 6, 3])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        feature_size = 512 * Bottleneck.expansion  # 2048 for ResNet-50
        self.dropout = nn.Dropout(0.5)
        # Learned stream-mixing logits, softmax-normalized in forward().
        self.gender_weights = nn.Parameter(torch.ones(3))
        self.bag_weights = nn.Parameter(torch.ones(2))
        self.color_consistency = ColorConsistencyModule(feature_size, num_color_classes)
        # Fast path layers: direct heads on each stream's fused features.
        self.hat_layer_fast = nn.Linear(feature_size, 1)
        self.gender_top_layer_fast = nn.Linear(feature_size, 1)
        self.upper_color_layer_fast = nn.Sequential(
            nn.Linear(feature_size, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_color_classes)
        )
        self.bag_mid_layer_fast = nn.Linear(feature_size, 1)
        self.gender_mid_layer_fast = nn.Linear(feature_size, 1)
        self.lower_color_layer_fast = nn.Sequential(
            nn.Linear(feature_size, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_color_classes)
        )
        self.bag_bot_layer_fast = nn.Linear(feature_size, 1)
        self.gender_bot_layer_fast = nn.Linear(feature_size, 1)
        # Shared refinement: one common trunk, one 1-logit head per binary task.
        self.shared_binary_refine_base = nn.Sequential(
            nn.Linear(feature_size, 256),
            nn.ReLU()
        )
        self.shared_binary_refine_hat = nn.Linear(256, 1)
        self.shared_binary_refine_bag_mid = nn.Linear(256, 1)
        self.shared_binary_refine_bag_bot = nn.Linear(256, 1)
        self.shared_binary_refine_gender_top = nn.Linear(256, 1)
        self.shared_binary_refine_gender_mid = nn.Linear(256, 1)
        self.shared_binary_refine_gender_bot = nn.Linear(256, 1)

    def forward(self, top, middle, bottom, full_image):
        """Return an 11-tuple of raw logits plus the full-image histogram.

        Args:
            top, middle, bottom: batched image crops fed to the three streams.
            full_image: batched whole image used only for the histogram.

        Returns (all logits, pre-sigmoid/softmax):
            hat, upper_color, lower_color, gender (fused), bag (fused),
            gender_top, gender_mid, gender_bot, bag_mid, bag_bot, full_hist.
        """
        top_feat = self.top_cnn(top)
        mid_feat = self.middle_cnn(middle)
        bot_feat = self.bottom_cnn(bottom)
        # Global-average-pool each stream to a flat feature vector.
        top_feat = self.pool(top_feat).view(top.size(0), -1)
        mid_feat = self.pool(mid_feat).view(middle.size(0), -1)
        bot_feat = self.pool(bot_feat).view(bottom.size(0), -1)
        (top_feat, mid_feat, bot_feat,
         upper_color_refined, lower_color_refined,
         full_hist) = self.color_consistency(
            top_feat, mid_feat, bot_feat, full_image
        )
        top_feat = self.dropout(top_feat)
        mid_feat = self.dropout(mid_feat)
        bot_feat = self.dropout(bot_feat)
        outputs = {'full_hist': full_hist}
        # TOP STREAM: hat + gender evidence.
        hat_fast = self.hat_layer_fast(top_feat).squeeze(1)
        gender_top_fast = self.gender_top_layer_fast(top_feat).squeeze(1)
        top_base = self.shared_binary_refine_base(top_feat)
        hat_refine = self.shared_binary_refine_hat(top_base).squeeze(1)
        gender_top_refine = self.shared_binary_refine_gender_top(top_base).squeeze(1)
        # Final logit = fast head + refinement head (residual-style sum).
        hat_pred = hat_fast + hat_refine
        gender_top = gender_top_fast + gender_top_refine
        outputs['hat'] = hat_pred
        outputs['gender_top'] = gender_top
        # MIDDLE STREAM: bag, upper-body color, gender evidence.
        bag_mid_fast = self.bag_mid_layer_fast(mid_feat).squeeze(1)
        upper_color_fast = self.upper_color_layer_fast(mid_feat)
        gender_mid_fast = self.gender_mid_layer_fast(mid_feat).squeeze(1)
        mid_base = self.shared_binary_refine_base(mid_feat)
        bag_mid_refine = self.shared_binary_refine_bag_mid(mid_base).squeeze(1)
        gender_mid_refine = self.shared_binary_refine_gender_mid(mid_base).squeeze(1)
        bag_mid_pred = bag_mid_fast + bag_mid_refine
        upper_color = upper_color_fast + upper_color_refined
        gender_mid = gender_mid_fast + gender_mid_refine
        outputs['bag_mid'] = bag_mid_pred
        outputs['upper_color'] = upper_color
        outputs['gender_mid'] = gender_mid
        # BOTTOM STREAM: bag, lower-body color, gender evidence.
        bag_bot_fast = self.bag_bot_layer_fast(bot_feat).squeeze(1)
        lower_color_fast = self.lower_color_layer_fast(bot_feat)
        gender_bot_fast = self.gender_bot_layer_fast(bot_feat).squeeze(1)
        bot_base = self.shared_binary_refine_base(bot_feat)
        bag_bot_refine = self.shared_binary_refine_bag_bot(bot_base).squeeze(1)
        gender_bot_refine = self.shared_binary_refine_gender_bot(bot_base).squeeze(1)
        bag_bot_pred = bag_bot_fast + bag_bot_refine
        lower_color = lower_color_fast + lower_color_refined
        gender_bot = gender_bot_fast + gender_bot_refine
        outputs['bag_bot'] = bag_bot_pred
        outputs['lower_color'] = lower_color
        outputs['gender_bot'] = gender_bot
        # Combine per-stream predictions with learned softmax weights.
        gender_weights = torch.softmax(self.gender_weights, dim=0)
        gender = (outputs['gender_top'] * gender_weights[0] +
                  outputs['gender_mid'] * gender_weights[1] +
                  outputs['gender_bot'] * gender_weights[2])
        bag_weights = torch.softmax(self.bag_weights, dim=0)
        bag = (outputs['bag_mid'] * bag_weights[0] +
               outputs['bag_bot'] * bag_weights[1])
        return (
            outputs['hat'],
            outputs['upper_color'],
            outputs['lower_color'],
            gender,
            bag,
            outputs['gender_top'],
            outputs['gender_mid'],
            outputs['gender_bot'],
            outputs['bag_mid'],
            outputs['bag_bot'],
            outputs['full_hist']
        )
# ========================================================
# CONFIGURATION
# ========================================================
# Path to the trained model weights, relative to the working directory.
CHECKPOINT_PATH = "checkpoint.pth"
# Input resolution the image is resized to before inference.
IMG_SIZE = (224, 224)
# Decision thresholds applied to the sigmoid outputs of the binary heads.
ATTRIBUTE_THRESHOLDS = {'hat': 0.5, 'gender': 0.5, 'bag': 0.5}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1-based class index (argmax + 1 in predict()) -> human-readable color name.
COLOR_MAP = {
    1: "Black", 2: "Blue", 3: "Brown", 4: "Gray", 5: "Green",
    6: "Orange", 7: "Pink", 8: "Purple", 9: "Red", 10: "White", 11: "Yellow"
}
# Define transforms (ImageNet normalization statistics).
val_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Global model variable, populated by load_model() at startup.
model = None
# Create examples directory used by get_example_images().
EXAMPLES_DIR = "examples"
os.makedirs(EXAMPLES_DIR, exist_ok=True)
# ========================================================
# HELPER FUNCTIONS
# ========================================================
def load_model():
    """Build PARModel and load weights from CHECKPOINT_PATH into the global `model`.

    Checkpoint tensors whose key or shape does not match the current
    architecture are skipped (partial loading), so minor architecture drift
    does not abort startup; the number of skipped tensors is logged.

    Returns:
        bool: True on success, False if the checkpoint is missing or loading fails.
    """
    global model
    try:
        model = PARModel().to(DEVICE)
        if not os.path.exists(CHECKPOINT_PATH):
            logger.error(f"Checkpoint file not found: {CHECKPOINT_PATH}")
            return False
        # NOTE: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        model_state_dict = model.state_dict()
        # Keep only tensors that exist in this model with matching shapes.
        pretrained_dict = {
            k: v for k, v in checkpoint['model_state_dict'].items()
            if k in model_state_dict and v.size() == model_state_dict[k].size()
        }
        skipped = len(checkpoint['model_state_dict']) - len(pretrained_dict)
        if skipped:
            logger.warning(f"Skipped {skipped} checkpoint tensors with missing or mismatched keys")
        model_state_dict.update(pretrained_dict)
        model.load_state_dict(model_state_dict)
        model.eval()
        logger.info("Model loaded successfully!")
        return True
    except Exception:
        # logger.exception keeps the traceback, which the old message-only log lost.
        logger.exception("Error loading model")
        return False
def create_visualization(orig_img, predictions):
    """Render `orig_img` with region boxes and a prediction-summary overlay.

    Returns a PIL image, or None if rendering fails.
    """
    try:
        width, height = orig_img.size
        aspect = height / width

        # Compact figure: 6" wide, capped at 10" tall so tall crops don't overflow.
        fig_w = 6
        fig_h = fig_w * aspect
        if fig_h > 10:
            fig_h = 10
            fig_w = fig_h / aspect

        fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=80)
        ax.imshow(orig_img)

        # Region boundary boxes: (y0 fraction, height fraction, edge color).
        for y_frac, h_frac, edge in (
            (0.0, 0.2, '#00f5ff'),   # top
            (0.2, 0.4, '#39ff14'),   # middle
            (0.6, 0.4, '#ff006e'),   # bottom
        ):
            ax.add_patch(patches.Rectangle(
                (0, height * y_frac), width, height * h_frac,
                linewidth=1.5, edgecolor=edge, facecolor='none', alpha=0.8
            ))

        # Prediction summary box, bottom-left corner.
        summary = [
            f"Hat: {predictions['hat']['label']} ({predictions['hat']['confidence']:.1%})",
            f"Gender: {predictions['gender']['label']} ({predictions['gender']['confidence']:.1%})",
            f"Bag: {predictions['bag']['label']} ({predictions['bag']['confidence']:.1%})",
            f"Upper: {predictions['upper_color']['label']}",
            f"Lower: {predictions['lower_color']['label']}",
        ]
        ax.text(
            0.02, 0.02, "\n".join(summary),
            transform=ax.transAxes,
            fontsize=9,
            fontweight='bold',
            verticalalignment='bottom',
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor='black',
                edgecolor='#ff006e',
                alpha=0.9
            ),
            color='white',
        )

        # Per-region labels along the right edge (axes coordinates).
        for x_pos, y_pos, caption, tone in (
            (0.98, 0.9, "Top\n(Hat)", '#00f5ff'),
            (0.98, 0.5, "Middle\n(Color/Bag)", '#39ff14'),
            (0.98, 0.2, "Bottom\n(Color)", '#ff006e'),
        ):
            ax.text(
                x_pos, y_pos, caption,
                transform=ax.transAxes,
                fontsize=7,
                fontweight='bold',
                horizontalalignment='right',
                verticalalignment='center',
                bbox=dict(
                    boxstyle="round,pad=0.2",
                    facecolor='black',
                    alpha=0.8,
                    edgecolor=tone
                ),
                color=tone,
            )

        ax.axis('off')
        plt.tight_layout(pad=0)

        # Rasterize into a PIL image; .copy() detaches it from the buffer.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight', dpi=80,
                    facecolor='black', pad_inches=0.05)
        buf.seek(0)
        result_img = Image.open(buf).copy()
        plt.close(fig)
        return result_img
    except Exception as e:
        logger.error(f"Error creating visualization: {str(e)}")
        return None
def predict(image):
    """Run the PAR model on `image`.

    Returns (annotated PIL image or None, markdown result text), or an
    error message in the text slot when something goes wrong.
    """
    try:
        if image is None:
            return None, "Please upload an image!"

        # Normalize the input to an RGB PIL image (gradio may hand us an array).
        if isinstance(image, Image.Image):
            orig_img = image.convert('RGB')
        else:
            orig_img = Image.fromarray(image).convert('RGB')

        img_tensor = val_transform(orig_img)

        # Split vertically: top 0-20%, middle 20-60%, bottom 60-100%.
        H = img_tensor.shape[1]
        cut1, cut2 = int(H * 0.2), int(H * 0.6)
        parts = (
            img_tensor[:, :cut1, :],
            img_tensor[:, cut1:cut2, :],
            img_tensor[:, cut2:, :],
            img_tensor,
        )
        top, middle, bottom, full_image = (p.unsqueeze(0).to(DEVICE) for p in parts)

        with torch.no_grad():
            (hat_pred, upper_color_pred, lower_color_pred,
             gender_pred, bag_pred, *_rest) = model(top, middle, bottom, full_image)

        def decide_binary(logit, key, positive, negative):
            # Threshold the sigmoid probability of a single-logit head.
            prob = torch.sigmoid(logit).item()
            return prob, (positive if prob > ATTRIBUTE_THRESHOLDS[key] else negative)

        def decide_color(logits):
            # COLOR_MAP is 1-based; argmax is 0-based.
            cls = logits.argmax(1).item() + 1
            return cls, COLOR_MAP.get(cls, f"Unknown({cls})")

        hat_prob, hat_label = decide_binary(hat_pred, 'hat', "Yes", "No")
        gender_prob, gender_label = decide_binary(gender_pred, 'gender', "Female", "Male")
        bag_prob, bag_label = decide_binary(bag_pred, 'bag', "Yes", "No")
        upper_color_class, upper_color_name = decide_color(upper_color_pred)
        lower_color_class, lower_color_name = decide_color(lower_color_pred)

        predictions = {
            'hat': {'label': hat_label, 'confidence': hat_prob},
            'gender': {'label': gender_label, 'confidence': gender_prob},
            'bag': {'label': bag_label, 'confidence': bag_prob},
            'upper_color': {'label': upper_color_name, 'class': upper_color_class},
            'lower_color': {'label': lower_color_name, 'class': lower_color_class}
        }

        result_img = create_visualization(orig_img, predictions)

        output_text = f"""
## Pedestrian Attribute Recognition Results
### Binary Attributes
- **Hat**: {hat_label} (Confidence: {hat_prob:.2%})
- **Gender**: {gender_label} (Confidence: {gender_prob:.2%})
- **Bag**: {bag_label} (Confidence: {bag_prob:.2%})
### Color Attributes
- **Upper Body Color**: {upper_color_name}
- **Lower Body Color**: {lower_color_name}
### Model Information
- Device: {DEVICE}
- Image Size: {IMG_SIZE}
"""
        return result_img, output_text
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        return None, f"Error: {str(e)}"
def get_example_images():
    """Return paths of images found in EXAMPLES_DIR, or None when there are none.

    (Returning None instead of [] is preserved for the existing truthiness
    checks at the call site.)
    """
    if not os.path.exists(EXAMPLES_DIR):
        return None
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    found = [
        os.path.join(EXAMPLES_DIR, name)
        for name in os.listdir(EXAMPLES_DIR)
        if name.lower().endswith(image_exts)
    ]
    return found or None
# ========================================================
# GRADIO INTERFACE
# ========================================================
# Load model on startup
logger.info("Starting Pedestrian Attribute Recognition App...")
logger.info(f"Using device: {DEVICE}")
if not load_model():
logger.error("Failed to load model. Please check the checkpoint path.")
raise Exception(f"Model checkpoint not found at: {CHECKPOINT_PATH}")
# Get example images
example_images = get_example_images()
# Create Gradio interface
with gr.Blocks(title="Pedestrian Attribute Recognition", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# Pedestrian Attribute Recognition System
Upload an image of a pedestrian to analyze their attributes including:
- **Hat Detection** - Whether the person is wearing a hat
- **Gender Classification** - Male or Female
- **Bag Detection** - Whether the person is carrying a bag
- **Upper Body Color** - Color of upper clothing
- **Lower Body Color** - Color of lower clothing
The model uses a custom ResNet-50 architecture with attention mechanisms and color consistency modules.
"""
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(
label="Upload Pedestrian Image",
type="pil"
)
predict_btn = gr.Button("Analyze Attributes", variant="primary", size="lg")
# Add examples if available
if example_images:
gr.Examples(
examples=[[img] for img in example_images],
inputs=input_image,
label="Example Images"
)
else:
gr.Markdown(
"""
**To add example images:**
1. Create a folder named `examples` in the same directory as this script
2. Add pedestrian images to the `examples` folder
3. Restart the app
"""
)
with gr.Column(scale=1):
output_image = gr.Image(
label="Annotated Result",
type="pil"
)
output_text = gr.Markdown(label="Predictions")
gr.Markdown(
"""
### About the Model
This system uses an enhanced Pedestrian Attribute Recognition (PAR) model with:
- **Three-stream ResNet-50** architecture for different body regions
- **CBAM Attention** mechanisms for improved feature extraction
- **Color Consistency Module** with differentiable histograms
- **Multi-task Learning** for simultaneous attribute prediction
**Regions Analyzed:**
- Top (0-20%): Hat detection
- Middle (20-60%): Upper color, gender, bag
- Bottom (60-100%): Lower color
"""
)
# Connect the button
predict_btn.click(
fn=predict,
inputs=input_image,
outputs=[output_image, output_text]
)
# Also trigger on image upload
input_image.change(
fn=predict,
inputs=input_image,
outputs=[output_image, output_text]
)
# Launch the app
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)