Dhanushlevi commited on
Commit
1007aeb
·
1 Parent(s): 1385d7b

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +78 -0
  2. model.py +48 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from model import create_resnet50_model
7
+ from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
+
10
+ # Setup class names
11
+ class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
12
+ ### 2. Model and transforms preparation ###
13
+
14
+ # Create model
15
+ resnet50, resnet50_transforms = create_resnet50_model(num_classes=36,
16
+ seed=42)
17
+
18
+ # Load saved weights
19
+ resnet50.load_state_dict(
20
+ torch.load(
21
+ f="AMS.pth",
22
+ map_location=torch.device("cpu"), # load to CPU
23
+ )
24
+ )
25
+
26
+ ### 3. Predict function ###
27
+
28
+ # Create predict function
29
+ def predict(img) -> Tuple[Dict, float]:
30
+ """Transforms and performs a prediction on img and returns prediction and time taken.
31
+ """
32
+ # Start the timer
33
+ start_time = timer()
34
+
35
+ img = img.convert('RGB')
36
+ # Transform the target image using the ResNet50 transforms
37
+ img = resnet50_transforms(img).unsqueeze(0)
38
+
39
+ # Put the ResNet50 model into evaluation mode
40
+ resnet50.eval()
41
+ with torch.inference_mode():
42
+ # Pass the transformed image through the model and obtain the prediction logits
43
+ pred_logits = resnet50(img)
44
+
45
+ # Convert the prediction logits to probabilities using softmax
46
+ pred_probs = torch.softmax(pred_logits, dim=1)
47
+
48
+ # Create a prediction label and prediction probability dictionary for each prediction class
49
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
50
+
51
+ # Calculate the prediction time
52
+ pred_time = round(timer() - start_time, 5)
53
+
54
+ # Return the prediction dictionary and prediction time
55
+ return pred_labels_and_probs, pred_time
56
+
57
+ ### 4. Gradio app ###
58
+
59
+ import gradio as gr
60
+
61
+ # Create title, description and article strings
62
+ title = "AMERICA SIGN LAGNGUAGE"
63
+ description = "An resnet50 feature extractor computer vision model to classify american sign language ."
64
+ #article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."
65
+
66
+ # Create the Gradio demo
67
+ demo = gr.Interface(fn=predict, # mapping function from input to output
68
+ inputs=gr.Image(type="pil"), # what are the inputs?
69
+ outputs=[gr.Label(num_top_classes=5, label="Predictions"), # what are the outputs?
70
+ gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
71
+ examples=example_list,
72
+ title=title,
73
+ description=description,
74
+ )
75
+
76
+ # Launch the demo!
77
+ demo.launch()
78
+
model.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import torch.nn as nn
4
+
5
+ def create_resnet50_model(num_classes: int = 2, seed: int = 42):
6
+ """Creates a ResNet50 feature extractor model and transforms.
7
+
8
+ Args:
9
+ num_classes (int, optional): Number of classes in the classifier head.
10
+ Defaults to 2.
11
+ seed (int, optional): Random seed value. Defaults to 42.
12
+
13
+ Returns:
14
+ model (torch.nn.Module): ResNet50 feature extractor model.
15
+ transforms (torchvision.transforms): ResNet50 image transforms.
16
+ """
17
+ # 1. Create ResNet50 pretrained weights and transforms
18
+ weights = torchvision.models.resnet50(pretrained=True)
19
+ transforms = torchvision.transforms.Compose([
20
+ torchvision.transforms.Resize(256),
21
+ torchvision.transforms.CenterCrop(224),
22
+ torchvision.transforms.ToTensor(),
23
+ torchvision.transforms.Normalize(
24
+ mean=[0.485, 0.456, 0.406],
25
+ std=[0.229, 0.224, 0.225]
26
+ )
27
+ ])
28
+
29
+ # 2. Create ResNet50 model with pretrained weights
30
+ model = torchvision.models.resnet50(pretrained=False)
31
+
32
+ # 3. Load the pretrained weights into the model
33
+ model.load_state_dict(weights.state_dict())
34
+
35
+ # 4. Freeze all layers in the base model
36
+ for param in model.parameters():
37
+ param.requires_grad = False
38
+
39
+ # 5. Change classifier head with random seed for reproducibility
40
+ torch.manual_seed(seed)
41
+ num_features = model.fc.in_features
42
+ model.fc = nn.Sequential(
43
+ nn.Dropout(p=0.3, inplace=True),
44
+ nn.Linear(in_features=num_features, out_features=num_classes),
45
+ )
46
+
47
+ return model, transforms
48
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4