SakibRumu commited on
Commit
fbe096f
·
verified ·
1 Parent(s): 90ad8e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -91
app.py CHANGED
@@ -1,114 +1,273 @@
 
1
  import torch
2
  import torch.nn as nn
3
- import gradio as gr
4
- from torchvision import models, transforms
 
5
  from PIL import Image
6
- from transformers import ViTModel
 
7
 
8
- # Define HybridCNNTransformer Model
9
- class HybridCNNTransformer(nn.Module):
10
- def __init__(self, num_classes=7):
11
- super(HybridCNNTransformer, self).__init__()
12
-
13
- # CNN Feature Extractor (ResNet50)
14
- self.cnn = models.resnet50(pretrained=True)
15
- self.cnn = nn.Sequential(*list(self.cnn.children())[:-2]) # Remove FC layers
 
 
16
 
17
- # Reduce channels (2048 64)
18
- self.channel_reduction = nn.Conv2d(in_channels=2048, out_channels=64, kernel_size=1)
 
 
 
 
19
 
20
- # Convert to 3 channels for ViT
21
- self.to_rgb = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Vision Transformer
24
- self.transformer = ViTModel.from_pretrained("google/vit-base-patch16-224")
 
 
 
 
 
 
 
 
25
 
26
- # Fully Connected Layers (Classifier Head)
27
- self.fc = nn.Sequential(
28
- nn.Linear(768, 512),
 
 
 
 
 
 
 
 
 
29
  nn.ReLU(),
30
- nn.Dropout(0.3),
31
- nn.Linear(512, num_classes)
32
  )
 
33
 
34
  def forward(self, x):
35
- cnn_features = self.cnn(x)
36
- reduced_features = self.channel_reduction(cnn_features)
37
- rgb_features = self.to_rgb(reduced_features)
38
- resized_features = nn.functional.interpolate(rgb_features, size=(224, 224), mode="bilinear", align_corners=False)
39
-
40
- transformer_output = self.transformer(pixel_values=resized_features).last_hidden_state[:, 0, :]
41
- output = self.fc(transformer_output)
42
- return output
43
-
44
- # Load Model
45
- model = HybridCNNTransformer(num_classes=7)
46
- state_dict = torch.load("transformer_emotion_recognition_model.pth", map_location=torch.device('cpu'))
47
- model.load_state_dict(state_dict, strict=False)
48
- model.eval()
49
-
50
- # Define Preprocessing Transform
51
- transform = transforms.Compose([
52
- transforms.Resize((224, 224)),
53
- transforms.ToTensor(),
54
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Define Prediction Function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def predict_emotion(image):
59
- image = transform(image).unsqueeze(0) # Add batch dimension
 
 
 
 
 
 
 
60
  with torch.no_grad():
61
- output = model(image)
62
- probabilities = torch.nn.functional.softmax(output, dim=1)
63
- confidence, predicted_class = torch.max(probabilities, 1)
 
 
 
 
 
 
 
 
64
 
65
- class_labels = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
66
- predicted_emotion = class_labels[predicted_class.item()]
67
- return predicted_emotion, f"{confidence.item() * 100:.2f}%"
68
-
69
- # Custom CSS for UI Styling
70
- css = """
71
- body {
72
- background-color: #1e1e1e;
73
- color: white;
74
- font-family: Arial, sans-serif;
75
- padding: 20px;
76
- }
77
- #component-1 {
78
- background-color: rgba(255, 255, 255, 0.7);
79
- padding: 20px;
80
- border-radius: 10px;
81
- box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
82
- }
83
- #component-2 {
84
- color: black;
85
- font-weight: bold;
86
- }
87
- #title {
88
- color: white;
89
- font-size: 36px;
90
- font-weight: bold;
91
- text-align: center;
92
- }
93
- #description {
94
- color: white;
95
- font-size: 16px;
96
- text-align: center;
97
- margin-bottom: 20px;
98
- }
99
- """
100
 
101
  # Gradio Interface
102
  iface = gr.Interface(
103
  fn=predict_emotion,
104
- inputs=gr.Image(type="pil"),
105
- outputs=[gr.Textbox(label="Predicted Emotion"), gr.Textbox(label="Confidence")],
106
- live=True,
107
- title="Emotion Classification",
108
- description="Upload an image to predict the emotion expressed in the image using a fine-tuned ResNet50 + Vision Transformer model.",
109
- css=css
 
 
 
 
 
 
 
110
  )
111
 
112
- # Launch the app
113
  if __name__ == "__main__":
114
- iface.launch()
 
1
+ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms, models
6
+ import numpy as np
7
  from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ import os
10
 
11
# Maps RAF-DB class indices to human-readable emotion labels (7 classes).
class_mapping = dict(enumerate([
    "Surprise",   # 0
    "Fear",       # 1
    "Disgust",    # 2
    "Happiness",  # 3
    "Sadness",    # 4
    "Anger",      # 5
    "Neutral",    # 6
]))
21
 
22
# Inference-time preprocessing, identical to the test-set transform used
# during training: resize to the model's 112x112 input, convert to a
# tensor, and normalize with ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
28
 
29
# Feature-extraction backbone: the stem plus the first two stages of an
# ImageNet-pretrained ResNet-50, followed by a strided 1x1 projection down
# to 256 channels. For a 112x112 input the output is a 256x7x7 feature map.
class IR50(nn.Module):
    def __init__(self):
        super(IR50, self).__init__()
        resnet = models.resnet50(weights='IMAGENET1K_V1')
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.downsample = nn.Conv2d(512, 256, 1, stride=2)
        self.bn_downsample = nn.BatchNorm2d(256, eps=1e-5)
        # Mirror the training setup: stem (conv1/bn1) and layer1 stay
        # trainable (irrelevant for inference, kept for parity).
        for module in (self.conv1, self.bn1, self.layer1):
            for param in module.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Run the input through every backbone stage in order."""
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.downsample,
                      self.bn_downsample):
            x = stage(x)
        return x
60
 
61
# HLA stream: a spatial-attention gate (max of two 1x1 projections) followed
# by a squeeze-and-excitation style channel gate, with a final batch norm.
class HLA(nn.Module):
    def __init__(self, in_channels=256, reduction=4):
        super(HLA, self).__init__()
        reduced_channels = in_channels // reduction
        self.spatial_branch1 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.spatial_branch2 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.sigmoid = nn.Sigmoid()
        self.channel_restore = nn.Conv2d(reduced_channels, in_channels, 1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False),
            nn.Sigmoid()
        )
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-5)

    def forward(self, x):
        # Spatial gate: element-wise max of the two reduced projections,
        # squashed to (0, 1), then restored to the input channel count.
        gate = self.sigmoid(torch.max(self.spatial_branch1(x),
                                      self.spatial_branch2(x)))
        spatially_gated = x * self.channel_restore(gate)
        # Channel gate (squeeze-and-excitation) applied on top of the
        # spatially-gated map.
        channel_gated = spatially_gated * self.channel_attention(spatially_gated)
        return self.bn(channel_gated)
89
+
90
# ViT stream: treats each spatial position of the 7x7 feature map as a patch
# token, prepends a learnable CLS token plus positional embeddings, runs a
# stack of transformer encoder layers, and returns the normalized CLS vector.
class ViT(nn.Module):
    def __init__(self, in_channels=256, patch_size=1, embed_dim=768, num_layers=12, num_heads=12):
        super(ViT, self).__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Hard-coded for a 7x7 input feature map.
        num_patches = (7 // patch_size) * (7 // patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        # NOTE(review): TransformerEncoderLayer defaults to batch_first=False,
        # while the tokens built in forward() are laid out (batch, seq, dim),
        # so attention mixes along the batch axis. The checkpoint was trained
        # with this exact layout, so it is deliberately preserved — confirm
        # before "fixing".
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=1536, activation="gelu")
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, eps=1e-5)

        # Xavier for the patch projection, small-normal for the tokens.
        nn.init.xavier_uniform_(self.patch_embed.weight)
        nn.init.zeros_(self.patch_embed.bias)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # (B, C, 7, 7) -> (B, 49, embed_dim) token sequence.
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        for encoder_layer in self.transformer:
            tokens = encoder_layer(tokens)
        # Keep only the CLS token, then LayerNorm + BatchNorm.
        return self.bn(self.ln(tokens[:, 0]))
123
+
124
# Intensity stream: Sobel gradient magnitude feeds a small CNN (pooled as a
# texture descriptor) and a single-head self-attention over spatial positions
# (averaged as a context descriptor); both 128-dim halves are concatenated.
class IntensityStream(nn.Module):
    def __init__(self, in_channels=256):
        super(IntensityStream, self).__init__()
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        # Depthwise (groups=in_channels) convs holding fixed Sobel kernels.
        self.sobel_x = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_y = nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels)
        self.sobel_x.weight.data = sobel_x.repeat(in_channels, 1, 1, 1)
        self.sobel_y.weight.data = sobel_y.repeat(in_channels, 1, 1, 1)
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.bn = nn.BatchNorm2d(128, eps=1e-5)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=1)

        # Xavier for the learnable conv, zero bias.
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        # Gradient magnitude from the fixed Sobel filters (eps avoids sqrt(0)
        # gradients/NaNs at perfectly flat regions).
        grad_magnitude = torch.sqrt(self.sobel_x(x) ** 2 + self.sobel_y(x) ** 2 + 1e-8)
        # Per-pixel across-channel variance, flattened per sample; returned
        # to the caller but not consumed by this stream itself.
        variance = ((x - x.mean(dim=1, keepdim=True)) ** 2).mean(dim=1).flatten(1)
        feature_map = self.bn(F.relu(self.conv(grad_magnitude)))
        texture = self.pool(feature_map).squeeze(-1).squeeze(-1)
        # Self-attention over spatial positions in (HW, B, 128) layout with
        # L2-normalized tokens.
        tokens = feature_map.flatten(2).permute(2, 0, 1)
        tokens = tokens / (tokens.norm(dim=-1, keepdim=True) + 1e-8)
        attended, _ = self.attention(tokens, tokens, tokens)
        context = attended.mean(dim=0)
        return torch.cat([texture, context], dim=1), grad_magnitude, variance
157
+
158
# Fusion model: IR50 backbone features feed three parallel streams (HLA,
# ViT, intensity); their 768-dim projections are concatenated, fused through
# one hidden layer, and classified into `num_classes` emotions.
class TripleStreamHLAViT(nn.Module):
    def __init__(self, num_classes=7):
        super(TripleStreamHLAViT, self).__init__()
        self.backbone = IR50()
        self.hla = HLA()
        self.vit = ViT()
        self.intensity = IntensityStream()
        self.fc_hla = nn.Linear(256 * 7 * 7, 768)
        self.fc_intensity = nn.Linear(256, 768)
        self.fusion_fc = nn.Linear(768 * 3, 512)
        self.bn_fusion = nn.BatchNorm1d(512, eps=1e-5)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(512, num_classes)

        # Xavier init for every linear layer, zero biases.
        for linear in (self.fc_hla, self.fc_intensity, self.fusion_fc, self.classifier):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        features = self.backbone(x)  # (B, 256, 7, 7) for 112x112 input
        hla_out = self.hla(features)
        vit_out = self.vit(features)
        intensity_out, grad_magnitude, variance = self.intensity(features)
        # Project the HLA map and intensity descriptor to 768 dims and fuse
        # with the ViT CLS embedding.
        hla_vec = self.fc_hla(hla_out.view(-1, 256 * 7 * 7))
        intensity_vec = self.fc_intensity(intensity_out)
        fused = torch.cat([hla_vec, vit_out, intensity_vec], dim=1)
        fused = self.dropout(self.bn_fusion(F.relu(self.fusion_fc(fused))))
        # Aux outputs (hla_out, vit_out, grad_magnitude, variance) are
        # exposed for visualization/inspection alongside the logits.
        return self.classifier(fused), hla_out, vit_out, grad_magnitude, variance
196
+
197
# Load the trained model onto the best available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TripleStreamHLAViT(num_classes=7).to(device)
model_path = "triple_stream_model_rafdb.pth"  # Must ship with the Space repo
try:
    # map_location lets a GPU-trained checkpoint load on a CPU-only host
    # (without it, torch.load raises when CUDA tensors hit a CPU box);
    # weights_only=True avoids executing arbitrary pickled code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # Disable dropout / batch-norm updates for inference
    print("Model loaded successfully")
except Exception as e:
    # Surface the failure in the Space logs, then abort startup — serving
    # an unloaded model would silently return garbage predictions.
    print(f"Error loading model: {e}")
    raise
208
+
209
# Inference and visualization entry point used by the Gradio interface.
def predict_emotion(image):
    """Classify the emotion in `image` and render an HLA attention heatmap.

    Args:
        image: PIL.Image or numpy array (Gradio may pass either).

    Returns:
        Tuple of (predicted label, {label: probability}, path to a PNG
        showing the input image next to the HLA heatmap).
    """
    import tempfile  # local import: only needed for the visualization file

    # Normalize the input to a 3-channel PIL image. The .convert("RGB")
    # keeps grayscale/RGBA/palette uploads from crashing the 3-channel
    # Normalize step in `transform`.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs, hla_out, _, grad_magnitude, _ = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
        pred_label = torch.argmax(probs, dim=1).item()
        pred_label_name = class_mapping[pred_label]
        probabilities = probs.cpu().numpy()[0]

    # Per-class probability dict for the Gradio Label component.
    prob_dict = {class_mapping[i]: float(prob) for i, prob in enumerate(probabilities)}

    # Channel-averaged HLA activation as a coarse attention heatmap.
    heatmap = hla_out[0].mean(dim=0).detach().cpu().numpy()

    # Undo ImageNet normalization so the input displays with true colors.
    img = image_tensor[0].permute(1, 2, 0).detach().cpu().numpy()
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)

    fig, axs = plt.subplots(1, 2, figsize=(8, 4))
    axs[0].imshow(img)
    axs[0].set_title(f"Input Image\nPredicted: {pred_label_name}")
    axs[0].axis("off")
    axs[1].imshow(heatmap, cmap="jet")
    axs[1].set_title("HLA Heatmap")
    axs[1].axis("off")
    plt.tight_layout()

    # Unique temp file per request so concurrent Gradio calls don't
    # overwrite each other's output (a fixed "visualization.png" path
    # races between simultaneous requests).
    out_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    fig.savefig(out_file.name)
    out_file.close()
    plt.close(fig)

    return pred_label_name, prob_dict, out_file.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
# Gradio Interface
# Only advertise example images that actually exist in the repo; a missing
# example file would otherwise crash Gradio at startup.
_example_paths = [
    "examples/Surprise.jpg",
    "examples/happy.JPEG",
    "examples/sadness.jpg",
]
_examples = [[p] for p in _example_paths if os.path.exists(p)]

iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Emotion"),
        gr.Label(label="Probabilities"),
        gr.Image(label="Input Image and HLA Heatmap")
    ],
    title="Facial Emotion Recognition with TripleStreamHLAViT",
    description="Upload an image to predict the facial emotion (Surprise, Fear, Disgust, Happiness, Sadness, Anger, Neutral). The model also visualizes the HLA heatmap showing where it focuses.",
    examples=_examples or None,
)
270
 
271
# Launch the interface
if __name__ == "__main__":
    # share=False: serve locally only (Hugging Face Spaces provides the
    # public URL itself, so no Gradio share tunnel is needed).
    iface.launch(share=False)