Update model.py
Browse files
model.py
CHANGED
|
@@ -2,24 +2,26 @@ import torch
|
|
| 2 |
import torchvision
|
| 3 |
from torch import nn
|
| 4 |
|
| 5 |
-
def create_model(num_classes:
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
param.requires_grad = False
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
# Replace the final classifier layer with a custom head
|
| 17 |
-
model.fc = nn.Sequential(
|
| 18 |
-
nn.Linear(2048, 1000),
|
| 19 |
-
nn.ReLU(),
|
| 20 |
-
nn.Linear(1000, 500),
|
| 21 |
-
nn.Dropout(),
|
| 22 |
-
nn.Linear(in_features=500, out_features=num_classes, bias=True)
|
| 23 |
)
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
| 2 |
import torchvision
|
| 3 |
from torch import nn
|
| 4 |
|
| 5 |
+
def create_model(num_classes:int=2,
|
| 6 |
+
seed:int=42):
|
| 7 |
+
weights=torchvision.models.ResNet50_Weights.DEFAULT
|
| 8 |
+
transforms=weights.transforms()
|
| 9 |
+
model=torchvision.models.resnet50(weights=weights)
|
| 10 |
|
| 11 |
+
for param in model.parameters():
|
| 12 |
+
param.requires_grad=False
|
|
|
|
| 13 |
|
| 14 |
+
torch.manual_seed(42)
|
| 15 |
+
model.fc= torch.nn.Sequential(
|
| 16 |
+
torch.nn.Linear(2048,1000),
|
| 17 |
+
torch.nn.ReLU(),
|
| 18 |
+
torch.nn.Linear(1000,500),
|
| 19 |
+
torch.nn.Dropout(),
|
| 20 |
+
torch.nn.Linear(in_features=500,
|
| 21 |
+
out_features=num_classes,
|
| 22 |
+
bias=True)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
)
|
| 25 |
|
| 26 |
+
return model,transforms
|
| 27 |
+
|