SakibRumu committed on
Commit
19f4b73
·
verified ·
1 Parent(s): f11e381

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -41
app.py CHANGED
@@ -1,55 +1,62 @@
1
- import gradio as gr
2
  import torch
3
- import timm
4
- from torch import nn
5
- from torchvision import transforms
6
  from PIL import Image
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# Build the SE-ResNet50 backbone from timm; weights come from the local
# checkpoint below, so pretrained=False avoids a pointless download.
model = timm.create_model("seresnet50", pretrained=False)
# Replace the classifier head: 2048 backbone features -> 7 emotion classes.
# Use the file's `nn` alias (from torch import nn) instead of torch.nn for consistency.
model.fc = nn.Linear(2048, 7)

# Path to the fine-tuned emotion-classification weights.
model_path = "custom_resnet50_emotion_model.pth"

try:
    # map_location keeps the checkpoint loadable on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()  # inference mode: disables dropout / freezes batch-norm stats
    print("✅ Model loaded successfully.")
except FileNotFoundError:
    # Best-effort: the app keeps running (with untrained weights) so the UI
    # still comes up; the console message flags the problem.
    print("❌ Model file not found. Please check the path.")
except Exception as e:
    print(f"❌ Error loading model: {e}")
 
25
# Define image transforms: resize to the 224x224 input the backbone expects,
# convert to a float tensor, then normalize with the standard ImageNet
# statistics (the values the pretrained ResNet family was trained with).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
31
 
32
# Emotion classes (adjust based on your dataset)
emotions = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']

def predict_emotion(image):
    """Classify the emotion in a PIL image.

    Returns a (label, confidence) pair, where confidence is a
    percentage string formatted to two decimal places.
    """
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)
        confidence, predicted_class = probs.max(1)

    label = emotions[predicted_class.item()]
    return label, f"{confidence.item() * 100:.2f}%"
47
 
48
- # Custom CSS for layout styling
49
  css = """
50
  body {
 
51
  color: white;
52
- font-family: Arial, sans-serif;
53
  }
54
  #component-1 {
55
  background-color: rgba(255, 255, 255, 0.7);
@@ -58,19 +65,15 @@ body {
58
  }
59
  #component-2 {
60
  color: black;
 
61
  }
62
  """
63
 
64
# Build the Gradio UI: one image input, two text outputs (label + confidence).
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Predicted Emotion"), gr.Textbox(label="Confidence")],
    live=True,  # re-run the prediction automatically whenever the input changes
    title="Emotion Classification",
    description="Upload an image to predict the emotion expressed in the image using a fine-tuned SE-ResNet50 model.",
    css=css
)

# Launch the app only when run as a script (safe to import without side effects)
if __name__ == "__main__":
    iface.launch()
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ import gradio as gr
4
+ from torchvision import models, transforms
5
  from PIL import Image
6
+ from transformers import ViTModel
7
 
8
# Define Hybrid CNN + Transformer
class HybridCNNTransformer(nn.Module):
    """ResNet-50 feature extractor feeding a ViT encoder, topped by an MLP head.

    NOTE(review): attribute names (cnn, channel_reduction, to_rgb,
    transformer, fc) must match the keys of the saved checkpoint loaded
    with load_state_dict — do not rename them.
    """

    def __init__(self, num_classes=7):
        super(HybridCNNTransformer, self).__init__()
        # ResNet-50 with its last two children (avg-pool and fc head) stripped,
        # leaving the 2048-channel spatial feature map.
        self.cnn = models.resnet50(pretrained=True)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])
        # 1x1 convolutions squeeze 2048 -> 64 -> 3 channels so the feature map
        # can be fed to the ViT as a pseudo-RGB image.
        self.channel_reduction = nn.Conv2d(2048, 64, kernel_size=1)
        self.to_rgb = nn.Conv2d(64, 3, kernel_size=1)
        self.transformer = ViTModel.from_pretrained("google/vit-base-patch16-224")
        # Classification head on the ViT's 768-dim [CLS] embedding.
        self.fc = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.cnn(x)                # 2048-channel feature map
        x = self.channel_reduction(x)  # -> 64 channels
        x = self.to_rgb(x)             # -> 3 channels (pseudo-RGB)
        # Upsample to the 224x224 resolution the ViT expects.
        x = nn.functional.interpolate(x, size=(224, 224), mode="bilinear")
        # Take the [CLS] token embedding (position 0 of the last hidden state).
        x = self.transformer(pixel_values=x).last_hidden_state[:, 0, :]
        return self.fc(x)
31
 
32
# Load model weights from the fine-tuned checkpoint (CPU-safe map_location).
model = HybridCNNTransformer(num_classes=7)
# strict=False tolerates key mismatches, but silently ignoring them can leave
# layers at random initialization — surface any mismatch on the console.
incompatible = model.load_state_dict(
    torch.load("transformerHybrid_emotation_model.pth", map_location=torch.device('cpu')),
    strict=False,
)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"⚠️ Checkpoint key mismatch — missing: {incompatible.missing_keys}, "
          f"unexpected: {incompatible.unexpected_keys}")
model.eval()  # inference mode: disables dropout / freezes batch-norm stats
 
 
 
 
36
 
37
# Preprocessing pipeline: 224x224 resize, tensor conversion, and
# normalization with the standard ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
])
43
 
44
# Prediction function
def predict_emotion(image):
    """Classify the emotion in an uploaded PIL image.

    Returns a (label, confidence) pair, where confidence is a
    percentage string formatted to two decimal places.
    """
    if image is None:
        # Gradio passes None when the image input is cleared — avoid a crash.
        return "No image provided", "0.00%"

    # Force 3 channels: RGBA/grayscale uploads would otherwise produce a
    # tensor whose channel count the ResNet stem cannot accept.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        output = model(batch)
        probs = torch.nn.functional.softmax(output, dim=1)
        conf, pred = torch.max(probs, 1)

    # NOTE(review): this ordering must match the class indices used at
    # training time — confirm against the training dataset.
    labels = ["Angry", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]
    return labels[pred.item()], f"{conf.item() * 100:.2f}%"
 
54
 
55
+ # Interface
56
  css = """
57
  body {
58
+ background-color: #1e1e1e;
59
  color: white;
 
60
  }
61
  #component-1 {
62
  background-color: rgba(255, 255, 255, 0.7);
 
65
  }
66
  #component-2 {
67
  color: black;
68
+ font-weight: bold;
69
  }
70
  """
71
 
72
# Build the Gradio UI: one image input, two text outputs (label + confidence).
iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Predicted Emotion"), gr.Textbox(label="Confidence")],
    title="Emotion Classification",
    description="Upload an image to predict the emotion expressed using a Hybrid CNN + ViT model.",
    css=css
)

# Launch only when run as a script, as the previous revision of this file did;
# importing the module no longer starts a server as a side effect.
if __name__ == "__main__":
    iface.launch()