Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -66,7 +66,7 @@ def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
|
|
| 66 |
|
| 67 |
|
| 68 |
# Load the CSRNet model
|
| 69 |
-
csrmodel = CSRNet()
|
| 70 |
checkpoint = torch.load("model.pt")
|
| 71 |
csrmodel.load_state_dict(checkpoint)
|
| 72 |
csrmodel.eval()
|
|
@@ -82,7 +82,7 @@ transform = transforms.Compose([
|
|
| 82 |
# Define the prediction function
|
| 83 |
def predict_count(input_image):
|
| 84 |
# Preprocess the input image
|
| 85 |
-
image = transform(input_image).unsqueeze(0)
|
| 86 |
|
| 87 |
# Perform the forward pass
|
| 88 |
output = csrmodel(image)
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
# Load the CSRNet model
|
| 69 |
+
csrmodel = CSRNet(load_weights=True).cpu()
|
| 70 |
checkpoint = torch.load("model.pt")
|
| 71 |
csrmodel.load_state_dict(checkpoint)
|
| 72 |
csrmodel.eval()
|
|
|
|
| 82 |
# Define the prediction function
|
| 83 |
def predict_count(input_image):
|
| 84 |
# Preprocess the input image
|
| 85 |
+
image = transform(input_image).unsqueeze(0).cpu()
|
| 86 |
|
| 87 |
# Perform the forward pass
|
| 88 |
output = csrmodel(image)
|