Winston de Jong
commited on
Commit
·
1bf5f86
1
Parent(s):
0da216f
implement model
Browse files
app.py
CHANGED
|
@@ -44,7 +44,7 @@ def load_model(repo_id):
|
|
| 44 |
# Initialize the ResNet-18 architecture
|
| 45 |
model = torchvision.models.resnet18(pretrained=True) # TODO: does it matter if this is set to true or false?
|
| 46 |
num_ftrs = model.fc.in_features
|
| 47 |
-
model.fc = nn.Linear(num_ftrs,
|
| 48 |
# TODO: check if this number^^ corresponds to the number of classes
|
| 49 |
|
| 50 |
# Load the model weights
|
|
@@ -78,10 +78,13 @@ def process_image_str(groupImageFilePath: str):
|
|
| 78 |
# do AI stuff here
|
| 79 |
with torch.no_grad():
|
| 80 |
outputs_t = model(intputTensor)
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
outputLabels.append(pred_t.item())
|
| 83 |
|
| 84 |
#return gr.Image(image)
|
|
|
|
| 85 |
return outputLabels.pop(0)
|
| 86 |
|
| 87 |
|
|
|
|
| 44 |
# Initialize the ResNet-18 architecture
|
| 45 |
model = torchvision.models.resnet18(pretrained=True) # TODO: does it matter if this is set to true or false?
|
| 46 |
num_ftrs = model.fc.in_features
|
| 47 |
+
model.fc = nn.Linear(num_ftrs, 128) # Adjust for your task (e.g., 128 classes)
|
| 48 |
# TODO: check if this number^^ corresponds to the number of classes
|
| 49 |
|
| 50 |
# Load the model weights
|
|
|
|
| 78 |
# do AI stuff here
|
| 79 |
with torch.no_grad():
|
| 80 |
outputs_t = model(intputTensor)
|
| 81 |
+
print(outputs_t)
|
| 82 |
+
temp, pred_t = torch.max(outputs_t, dim=1)
|
| 83 |
+
print(temp)
|
| 84 |
outputLabels.append(pred_t.item())
|
| 85 |
|
| 86 |
#return gr.Image(image)
|
| 87 |
+
print(outputLabels)
|
| 88 |
return outputLabels.pop(0)
|
| 89 |
|
| 90 |
|