basementparking commited on
Commit
0af0b1c
·
verified ·
1 Parent(s): c2faf5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -108,7 +108,7 @@ def setup_model():
108
  # Function to segment image
109
  def segment_image(image):
110
 
111
- image = cv2.imread(image_path)
112
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
113
 
114
  outputs = predictor(image)
@@ -154,14 +154,14 @@ def segment_image(image):
154
  # Load models
155
  modernity_model = models.resnet18(pretrained=True)
156
  modernity_model.fc = nn.Linear(modernity_model.fc.in_features, 5)
157
- modernity_checkpoint = torch.load(path +'modernity.pth', map_location=device)
158
  modernity_model.load_state_dict(modernity_checkpoint)
159
  modernity_model.to(device)
160
  modernity_model.eval()
161
 
162
  typicality_model = models.resnet18(pretrained=True)
163
  typicality_model.fc = nn.Linear(typicality_model.fc.in_features, 5)
164
- typicality_checkpoint = torch.load(path + 'typicality.pth', map_location=device)
165
  typicality_model.load_state_dict(typicality_checkpoint)
166
  typicality_model.to(device)
167
  typicality_model.eval()
 
108
  # Function to segment image
109
  def segment_image(image):
110
 
111
+ image = cv2.imread(image)
112
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
113
 
114
  outputs = predictor(image)
 
154
  # Load models
155
  modernity_model = models.resnet18(pretrained=True)
156
  modernity_model.fc = nn.Linear(modernity_model.fc.in_features, 5)
157
+ modernity_checkpoint = torch.load('modernity.pth', map_location=device)
158
  modernity_model.load_state_dict(modernity_checkpoint)
159
  modernity_model.to(device)
160
  modernity_model.eval()
161
 
162
  typicality_model = models.resnet18(pretrained=True)
163
  typicality_model.fc = nn.Linear(typicality_model.fc.in_features, 5)
164
+ typicality_checkpoint = torch.load('typicality.pth', map_location=device)
165
  typicality_model.load_state_dict(typicality_checkpoint)
166
  typicality_model.to(device)
167
  typicality_model.eval()