SakibRumu commited on
Commit
6d03810
·
verified ·
1 Parent(s): c8f9aaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -59
app.py CHANGED
@@ -1,76 +1,99 @@
1
- import gradio as gr
2
  import torch
3
- import timm
4
- from torch import nn
5
- from torchvision import transforms
6
  from PIL import Image
 
7
 
8
- # Define the class mapping for the RAF-DB dataset
9
- class_mapping = {
10
- "1": "Surprise",
11
- "2": "Fear",
12
- "3": "Disgust",
13
- "4": "Happiness",
14
- "5": "Sadness",
15
- "6": "Anger",
16
- "7": "Contempt"
17
- }
18
 
19
- # Define the function to map folder names to labels
20
- def get_raf_label(file_path):
21
- # Use the class mapping to get the label
22
- return class_mapping[str(file_path.parent.name)]
23
 
24
- # Load the SeresNet50 model using timm
25
- model_path = "custom_seresnet50_emotion_model.pth" # Replace with your actual .pth model path
26
- model = timm.create_model('seresnet50', pretrained=False, num_classes=7) # Load the SeresNet50 model with 7 classes
27
- model.load_state_dict(torch.load(model_path)) # Load the saved model weights
28
- model.eval() # Set the model to evaluation mode
29
 
30
- # Define the emotion classes
31
- emotion_classes = list(class_mapping.values()) # Get emotion classes from the class mapping
32
 
33
- # Define the transformation pipeline
34
- transform = transforms.Compose([
35
- transforms.Resize((224, 224)), # Resize image to 224x224 (adjust as needed)
36
  transforms.ToTensor(),
37
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standard normalization for models trained on ImageNet
38
  ])
39
 
40
- # Function for Emotion Prediction
41
  def predict_emotion(image):
42
- img = Image.open(image).convert('RGB') # Convert the uploaded image into RGB
43
- img = transform(img).unsqueeze(0) # Apply transformations and add batch dimension
 
44
  with torch.no_grad():
45
- outputs = model(img) # Get model output
46
- _, pred_idx = torch.max(outputs, 1) # Get the predicted class index
47
- predicted_emotion = emotion_classes[pred_idx.item()]
48
- confidence = torch.nn.functional.softmax(outputs, dim=1)[0][pred_idx].item() * 100 # Get confidence
49
- return predicted_emotion, f"{confidence:.2f}%"
 
 
 
 
 
 
 
50
 
51
- # Gradio interface with xkcd theme
52
- with gr.Blocks(theme="gstaff/xkcd") as demo:
53
- gr.Markdown("# Emotion Recognition Classifier")
54
- gr.Markdown("""
55
- This app uses a deep learning model to recognize emotions in facial images.
56
- The model has been trained on a dataset to classify images into different emotion categories:
57
- * Anger
58
- * Fear
59
- * Happiness
60
- * Sadness
61
- * Surprise
62
- * Contempt
63
- """)
 
 
64
 
65
- # Upload image widget
66
- image_input = gr.Image(type="pil", label="Upload an image of a face")
 
 
67
 
68
- # Outputs
69
- label_output = gr.Textbox(label="Predicted Emotion")
70
- confidence_output = gr.Textbox(label="Confidence Percentage")
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # Button to predict the emotion
73
- image_input.upload(predict_emotion, image_input, [label_output, confidence_output])
 
 
 
 
 
 
 
 
74
 
75
- # Launch the app
76
- demo.launch(share=True)
 
 
 
1
  import torch
2
+ import gradio as gr
3
+ from transformers import AutoModel
 
4
  from PIL import Image
5
+ from torchvision import transforms
6
 
7
+ # Load your custom model from Hugging Face (replace with your actual model)
8
+ model_name = 'Sakibrumu/HybridCNNTransformer' # Replace with your Hugging Face model ID
9
+ model = AutoModel.from_pretrained(model_name)
10
+
11
+ # If you need to fine-tune or adjust the final layer
12
+ model.fc = torch.nn.Linear(2048, 7) # Adjust the final layer for 7 emotion categories
 
 
 
 
13
 
14
+ # Load the model weights (you might not need this if your model is already fine-tuned in Hugging Face)
15
+ model.load_state_dict(torch.load("transformer_emotion_recognition_model.pth"))
 
 
16
 
17
+ # Move to the appropriate device (GPU or CPU)
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
 
 
20
 
21
+ # Make sure the model is in evaluation mode
22
+ model.eval()
23
 
24
+ # Image Preprocessing (e.g., resizing and normalization)
25
+ preprocess = transforms.Compose([
26
+ transforms.Resize((224, 224)), # Resize to the expected input size
27
  transforms.ToTensor(),
28
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standard ImageNet normalization
29
  ])
30
 
31
+ # Prediction function
32
  def predict_emotion(image):
33
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
34
+ image = preprocess(image).unsqueeze(0).to(device) # Preprocess and add batch dimension
35
+
36
  with torch.no_grad():
37
+ outputs = model(image)
38
+ _, predicted = torch.max(outputs, 1) # Get the class with the highest probability
39
+
40
+ # Assuming you have an emotion label list
41
+ emotion_labels = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
42
+ predicted_label = emotion_labels[predicted.item()]
43
+
44
+ # Confidence is the probability of the predicted class
45
+ confidence = torch.nn.functional.softmax(outputs, dim=1)
46
+ predicted_confidence = confidence[0, predicted.item()].item()
47
+
48
+ return predicted_label, round(predicted_confidence * 100, 2)
49
 
50
+ # Custom CSS for layout styling
51
+ css = """
52
+ body {
53
+ background-color: #1e1e1e;
54
+ color: white;
55
+ font-family: Arial, sans-serif;
56
+ padding: 20px;
57
+ }
58
+
59
+ #component-1 {
60
+ background-color: rgba(255, 255, 255, 0.7);
61
+ padding: 20px;
62
+ border-radius: 10px;
63
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
64
+ }
65
 
66
+ #component-2 {
67
+ color: black;
68
+ font-weight: bold;
69
+ }
70
 
71
+ #title {
72
+ color: white;
73
+ font-size: 36px;
74
+ font-weight: bold;
75
+ text-align: center;
76
+ }
77
+
78
+ #description {
79
+ color: white;
80
+ font-size: 16px;
81
+ text-align: center;
82
+ margin-bottom: 20px;
83
+ }
84
+ """
85
 
86
+ # Gradio Interface
87
+ iface = gr.Interface(
88
+ fn=predict_emotion,
89
+ inputs=gr.Image(type="pil"),
90
+ outputs=[gr.Textbox(label="Predicted Emotion"), gr.Textbox(label="Confidence")],
91
+ live=True,
92
+ title="Emotion Classification",
93
+ description="Upload an image to predict the emotion expressed in the image using a fine-tuned SE-ResNet50 model.",
94
+ css=css
95
+ )
96
 
97
+ # Launch the app
98
+ if __name__ == "__main__":
99
+ iface.launch()