setsosie commited on
Commit
ea90889
·
verified ·
1 Parent(s): 5a71792

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -8,8 +8,8 @@ import torch.nn.functional as F
8
  import torchvision.transforms as T
9
 
10
  # Get pre-trained model
11
- #weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
12
- model = torchvision.models.resnet18()# weights=weights)
13
 
14
  # Set model to evaluation mode
15
  model.eval()
@@ -20,10 +20,13 @@ labels = r.text.split("\n")
20
 
21
  # Define prediction function
22
  def predict(img):
 
 
 
 
23
  # Transform image to pytorch tensor of shape [1, 3, 224, 224]
24
  img = T.PILToTensor()(img).unsqueeze(0)
25
  img = T.Resize(size=(224, 224))(img)
26
-
27
 
28
  # Use model without gradients to reduce computation
29
  with torch.no_grad():
@@ -38,4 +41,4 @@ gr.Interface(fn=predict,
38
  inputs=gr.Image(type="pil"),
39
  outputs=gr.Label(num_top_classes=10),
40
  theme="default",
41
- )
 
8
  import torchvision.transforms as T
9
 
10
  # Get pre-trained model
11
+ weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
12
+ model = torchvision.models.resnet18(weights=weights)
13
 
14
  # Set model to evaluation mode
15
  model.eval()
 
20
 
21
  # Define prediction function
22
  def predict(img):
23
+ '''
24
+ img: PIL image to be predicted
25
+ confidences: python dictionary containing confidences
26
+ '''
27
  # Transform image to pytorch tensor of shape [1, 3, 224, 224]
28
  img = T.PILToTensor()(img).unsqueeze(0)
29
  img = T.Resize(size=(224, 224))(img)
 
30
 
31
  # Use model without gradients to reduce computation
32
  with torch.no_grad():
 
41
  inputs=gr.Image(type="pil"),
42
  outputs=gr.Label(num_top_classes=10),
43
  theme="default",
44
+ ).launch()