Chukwuka commited on
Commit
cdb9930
·
1 Parent(s): fa1b904

Added Torch Library

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -1,6 +1,6 @@
1
 
2
  # import torch
3
- # from torch import nn
4
  import torch.nn.functional as F
5
  import torchvision
6
  from torchvision.models import resnet50, ResNet50_Weights
@@ -51,7 +51,7 @@ class FlowerClassificationModel(ImageClassificationBase):
51
  self.network = self.network = resnet50(weights=ResNet50_Weights.DEFAULT)
52
 
53
  else:
54
- # 1. Get the base mdoel with pretrained weights and send to target device
55
  self.network = torchvision.models.resnet50(pretrained=True)
56
 
57
  for param in self.network.parameters():
 
1
 
2
  # import torch
3
+ from torch import nn
4
  import torch.nn.functional as F
5
  import torchvision
6
  from torchvision.models import resnet50, ResNet50_Weights
 
51
  self.network = self.network = resnet50(weights=ResNet50_Weights.DEFAULT)
52
 
53
  else:
54
+ # 1. Get the base model with pretrained weights and send to target device
55
  self.network = torchvision.models.resnet50(pretrained=True)
56
 
57
  for param in self.network.parameters():