# Hugging Face Space (status when this page was captured: Sleeping)
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms, models
# Class Mapping for RAF-DB Dataset (7 classes)
# Keys are the integer class indices produced by the classifier head; values
# are the emotion names shown to the user.
# NOTE(review): presumably this order matches the label order used during
# training - confirm against the training script.
class_mapping = dict(enumerate((
    "Surprise",
    "Fear",
    "Disgust",
    "Happiness",
    "Sadness",
    "Anger",
    "Neutral",
)))
# Transformations for inference (same as test transform)
_inference_steps = [
    # Force a fixed 112x112 input regardless of the uploaded image size.
    transforms.Resize((112, 112)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel normalisation; the same mean/std are reused later to
    # de-normalise the image for display.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_inference_steps)
# Feature Extraction Backbone
class IR50(nn.Module):
    """Truncated ResNet-50 used as the shared feature extractor.

    Reuses the torchvision ResNet-50 stem plus ``layer1`` and ``layer2``,
    then applies a stride-2 1x1 convolution (512 -> 256 channels) followed
    by batch normalisation.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): requesting 'IMAGENET1K_V1' presumably fetches the
        # pretrained weights at construction time even when a checkpoint is
        # loaded right afterwards - confirm that this is wanted at inference.
        pretrained = models.resnet50(weights='IMAGENET1K_V1')
        # Adopt only the early stages; attribute names are kept so the
        # state-dict keys stay the same.
        for stage_name in ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2"):
            setattr(self, stage_name, getattr(pretrained, stage_name))
        self.downsample = nn.Conv2d(512, 256, 1, stride=2)
        self.bn_downsample = nn.BatchNorm2d(256, eps=1e-5)

    def forward(self, x):
        """Run ``x`` through the stem, two residual stages and the reducer."""
        pipeline = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.downsample,
            self.bn_downsample,
        )
        for stage in pipeline:
            x = stage(x)
        return x
# HLA Stream
class HLA(nn.Module):
    """Attention block: spatial gating followed by channel gating.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Bottleneck factor; both attention paths work on
            ``in_channels // reduction`` channels internally.

    Raises:
        ValueError: If ``reduction`` is smaller than 1 or the bottleneck
            would end up with zero channels.
    """

    def __init__(self, in_channels=256, reduction=4):
        super().__init__()
        if reduction < 1 or in_channels // reduction < 1:
            raise ValueError(
                f"reduction={reduction} leaves no bottleneck channels "
                f"for in_channels={in_channels}"
            )
        reduced_channels = in_channels // reduction
        self.spatial_branch1 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.spatial_branch2 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.sigmoid = nn.Sigmoid()
        self.channel_restore = nn.Conv2d(reduced_channels, in_channels, 1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, reduced_channels, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(reduced_channels, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-5)

    def forward(self, x):
        """Return the attended feature map; same shape as ``x``."""
        # Element-wise maximum of the two 1x1 projections, squashed to (0, 1).
        b1 = self.spatial_branch1(x)
        b2 = self.spatial_branch2(x)
        spatial_attn = self.sigmoid(torch.max(b1, b2))
        # The restore conv is applied after the sigmoid, so the map that
        # multiplies ``x`` is a plain conv output, not a bounded gate.
        spatial_attn = self.channel_restore(spatial_attn)
        spatial_out = x * spatial_attn
        # Per-channel gate computed from the globally pooled, spatially
        # attended features.
        channel_attn = self.channel_attention(spatial_out)
        out = spatial_out * channel_attn
        out = self.bn(out)
        return out
# ViT Stream
class ViT(nn.Module):
    """Transformer encoder over patch tokens of a 7x7 feature map.

    Args:
        in_channels: Channels of the incoming feature map.
        patch_size: Kernel size and stride of the patch-embedding conv.
        embed_dim: Token width.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per layer.
        batch_first: Passed to every ``nn.TransformerEncoderLayer``.
            ``forward`` always feeds ``(batch, tokens, embed_dim)`` tensors.
            With the default ``False`` the layers read axis 0 as the
            sequence axis, so attention runs across the samples of a batch
            (and degenerates to no mixing for a batch of one). ``True``
            makes attention run across the tokens of each image. The
            default stays ``False`` because the parameter layout is the
            same either way but the numerical behaviour is not.
            NOTE(review): presumably the published checkpoint was trained
            with the ``False`` behaviour - confirm before switching.
    """

    def __init__(self, in_channels=256, patch_size=1, embed_dim=768, num_layers=8, num_heads=12,
                 batch_first=False):  # 8 layers as in the 82.93% version
        super().__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # The positional table is sized for a 7x7 input grid; any other
        # spatial size fails at the addition in forward().
        num_patches = (7 // patch_size) * (7 // patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(
                embed_dim,
                num_heads,
                dim_feedforward=1536,
                activation="gelu",
                batch_first=batch_first,
            )
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, eps=1e-5)
        # Initialize weights
        nn.init.xavier_uniform_(self.patch_embed.weight)
        nn.init.zeros_(self.patch_embed.bias)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Return the normalised class-token embedding, ``(batch, embed_dim)``."""
        x = self.patch_embed(x)
        # (batch, embed, H, W) -> (batch, H*W, embed)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        for layer in self.transformer:
            x = layer(x)
        # Keep only the class token, which was prepended at index 0.
        x = x[:, 0]
        x = self.ln(x)
        x = self.bn(x)
        return x
# Intensity Stream
class IntensityStream(nn.Module):
    """Edge/texture stream built on Sobel gradients of the feature map.

    Args:
        in_channels: Channels of the incoming feature map.

    ``forward`` returns ``(out, grad_magnitude, variance)`` where ``out``
    concatenates a pooled 128-d texture vector with a 128-d attention
    context vector.
    """

    def __init__(self, in_channels=256):
        super().__init__()
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        # Depthwise 3x3 convs: one filter per input channel.
        self.sobel_x = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_y = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        # Copy the kernels into the existing parameters (broadcast over the
        # channel axis) instead of rebinding ``.data``, so the parameters
        # keep their own dtype and device. They remain ordinary conv weights
        # and are therefore part of the state dict.
        with torch.no_grad():
            self.sobel_x.weight.copy_(sobel_x)
            self.sobel_y.weight.copy_(sobel_y)
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.bn = nn.BatchNorm2d(128, eps=1e-5)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=1)
        # Initialize weights
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        """Compute the intensity features for feature map ``x``."""
        gx = self.sobel_x(x)
        gy = self.sobel_y(x)
        # The epsilon keeps the square root differentiable where both
        # gradients are zero.
        grad_magnitude = torch.sqrt(gx**2 + gy**2 + 1e-8)
        # Variance across channels at every spatial position, flattened to
        # (batch, H*W).
        variance = ((x - x.mean(dim=1, keepdim=True))**2).mean(dim=1).flatten(1)
        cnn_out = F.relu(self.conv(grad_magnitude))
        cnn_out = self.bn(cnn_out)
        texture_out = self.pool(cnn_out).squeeze(-1).squeeze(-1)
        # (batch, 128, H, W) -> (H*W, batch, 128): the sequence-first layout
        # that nn.MultiheadAttention uses by default.
        attn_in = cnn_out.flatten(2).permute(2, 0, 1)
        # L2-normalise every token before self-attention.
        attn_in = attn_in / (attn_in.norm(dim=-1, keepdim=True) + 1e-8)
        attn_out, _ = self.attention(attn_in, attn_in, attn_in)
        context_out = attn_out.mean(dim=0)
        out = torch.cat([texture_out, context_out], dim=1)
        return out, grad_magnitude, variance
# Full Model (Single-Label Prediction)
class TripleStreamHLAViT(nn.Module):
    """Classifier that fuses the HLA, ViT and intensity streams.

    All three streams consume the same backbone feature map; each is
    brought to 768 dimensions, the three are concatenated and reduced to
    512 before the final linear classifier.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        self.backbone = IR50()
        self.hla = HLA()
        self.vit = ViT()
        self.intensity = IntensityStream()
        self.fc_hla = nn.Linear(256 * 7 * 7, 768)
        self.fc_intensity = nn.Linear(256, 768)
        self.fusion_fc = nn.Linear(768 * 3, 512)
        self.bn_fusion = nn.BatchNorm1d(512, eps=1e-5)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(512, num_classes)
        # Initialize weights: Xavier weights and zero biases for every
        # linear layer owned directly by this module.
        for linear in (self.fc_hla, self.fc_intensity, self.fusion_fc, self.classifier):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return ``(logits, hla_out, vit_out, grad_magnitude, variance)``."""
        features = self.backbone(x)
        hla_out = self.hla(features)
        vit_out = self.vit(features)
        intensity_out, grad_magnitude, variance = self.intensity(features)
        # Bring every stream to 768 dimensions; the ViT output is used as is.
        stream_vectors = [
            self.fc_hla(hla_out.view(-1, 256 * 7 * 7)),
            vit_out,
            self.fc_intensity(intensity_out),
        ]
        fused = self.fusion_fc(torch.cat(stream_vectors, dim=1))
        fused = self.dropout(self.bn_fusion(F.relu(fused)))
        return self.classifier(fused), hla_out, vit_out, grad_magnitude, variance
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = TripleStreamHLAViT(num_classes=7).to(device)
model_path = "triple_stream_model_rafdb.pth"  # Ensure this file is in the Hugging Face Space repository
try:
    # Always remap the stored tensors onto the device selected above.
    # Leaving map_location unset on a CUDA host would restore each tensor to
    # the device it was saved from, which fails when that device index does
    # not exist on this machine.
    map_location = device
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    # Inference mode: disables dropout and makes batch norm use its running
    # statistics.
    model.eval()
    print("Model loaded successfully")
except Exception as e:
    # Report the failure, then re-raise so the app does not start with an
    # uninitialised model.
    print(f"Error loading model: {e}")
    raise
# Inference and Visualization Function
def predict_emotion(image):
    """Classify the emotion in ``image`` and render an attention heatmap.

    Args:
        image: A PIL image or a NumPy array as delivered by the Gradio
            image component.

    Returns:
        A tuple ``(label, probabilities, figure_path)``: the predicted class
        name, a dict mapping every class name to its softmax probability,
        and the path of a PNG showing the input next to the HLA heatmap.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Convert the input image (from Gradio) to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The normalisation constants and the first conv layer are 3-channel, so
    # greyscale / RGBA / palette inputs are converted to RGB first.
    image = image.convert("RGB")

    # Preprocess the image
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Run inference
    with torch.no_grad():
        outputs, hla_out, *_ = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
        pred_label = torch.argmax(probs, dim=1).item()
    pred_label_name = class_mapping[pred_label]
    probabilities = probs.cpu().numpy()[0]

    # Create probability dictionary
    prob_dict = {class_mapping[i]: float(prob) for i, prob in enumerate(probabilities)}

    # Generate HLA heatmap: channel-wise mean of the HLA output.
    heatmap = hla_out[0].mean(dim=0).detach().cpu().numpy()

    # Denormalize the image for visualization
    img = image_tensor[0].permute(1, 2, 0).detach().cpu().numpy()
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)

    # Plot the input image and heatmap
    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    try:
        axs[0].imshow(img)
        axs[0].set_title(f"Input Image\nPredicted: {pred_label_name}")
        axs[0].axis("off")
        axs[1].imshow(heatmap, cmap="jet")
        axs[1].set_title("HLA Heatmap")
        axs[1].axis("off")
        fig.tight_layout()
        # One file per request: a shared fixed filename would let overlapping
        # requests overwrite each other's picture. The file is deliberately
        # not deleted here because the caller still has to read it.
        with tempfile.NamedTemporaryFile(
            prefix="visualization_", suffix=".png", delete=False
        ) as tmp:
            fig.savefig(tmp, format="png")
            output_path = tmp.name
    finally:
        # Close this specific figure even when plotting fails, so figures do
        # not accumulate across requests.
        plt.close(fig)

    return pred_label_name, prob_dict, output_path
# Gradio Interface
# One image in; predicted label, per-class probabilities and the rendered
# figure out - in the same order as predict_emotion's return tuple.
_input_component = gr.Image(type="pil", label="Upload an Image")
_output_components = [
    gr.Textbox(label="Predicted Emotion"),
    gr.Label(label="Probabilities"),
    gr.Image(label="Input Image and HLA Heatmap"),
]
# Paths are relative to the working directory of the app.
_example_inputs = [
    ["examples/surprise.jpg"],
    ["examples/sadness.jpg"],
]
_description = (
    "Upload an image to predict the facial emotion (Surprise, Fear, Disgust, Happiness, Sadness, Anger, Neutral). "
    "This model achieves 82.93% test accuracy on the RAF-DB dataset. "
    "The HLA heatmap shows where the model focuses."
)

iface = gr.Interface(
    fn=predict_emotion,
    inputs=_input_component,
    outputs=_output_components,
    title="Facial Emotion Recognition with TripleStreamHLAViT",
    description=_description,
    examples=_example_inputs,
)

# Launch the interface
if __name__ == "__main__":
    iface.launch(share=False)