Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -108,7 +108,7 @@ def setup_model():
|
|
| 108 |
# Function to segment image
|
| 109 |
def segment_image(image):
|
| 110 |
|
| 111 |
-
image = cv2.imread(
|
| 112 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 113 |
|
| 114 |
outputs = predictor(image)
|
|
@@ -154,14 +154,14 @@ def segment_image(image):
|
|
| 154 |
# Load models
|
| 155 |
modernity_model = models.resnet18(pretrained=True)
|
| 156 |
modernity_model.fc = nn.Linear(modernity_model.fc.in_features, 5)
|
| 157 |
-
modernity_checkpoint = torch.load(
|
| 158 |
modernity_model.load_state_dict(modernity_checkpoint)
|
| 159 |
modernity_model.to(device)
|
| 160 |
modernity_model.eval()
|
| 161 |
|
| 162 |
typicality_model = models.resnet18(pretrained=True)
|
| 163 |
typicality_model.fc = nn.Linear(typicality_model.fc.in_features, 5)
|
| 164 |
-
typicality_checkpoint = torch.load(
|
| 165 |
typicality_model.load_state_dict(typicality_checkpoint)
|
| 166 |
typicality_model.to(device)
|
| 167 |
typicality_model.eval()
|
|
|
|
| 108 |
# Function to segment image
|
| 109 |
def segment_image(image):
|
| 110 |
|
| 111 |
+
image = cv2.imread(image)
|
| 112 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 113 |
|
| 114 |
outputs = predictor(image)
|
|
|
|
| 154 |
# Load models
|
| 155 |
modernity_model = models.resnet18(pretrained=True)
|
| 156 |
modernity_model.fc = nn.Linear(modernity_model.fc.in_features, 5)
|
| 157 |
+
modernity_checkpoint = torch.load('modernity.pth', map_location=device)
|
| 158 |
modernity_model.load_state_dict(modernity_checkpoint)
|
| 159 |
modernity_model.to(device)
|
| 160 |
modernity_model.eval()
|
| 161 |
|
| 162 |
typicality_model = models.resnet18(pretrained=True)
|
| 163 |
typicality_model.fc = nn.Linear(typicality_model.fc.in_features, 5)
|
| 164 |
+
typicality_checkpoint = torch.load('typicality.pth', map_location=device)
|
| 165 |
typicality_model.load_state_dict(typicality_checkpoint)
|
| 166 |
typicality_model.to(device)
|
| 167 |
typicality_model.eval()
|