hamsteryang commited on
Commit
7998818
·
1 Parent(s): 99fe99c

update app.py

Browse files

edit by chatGPT

Files changed (1) hide show
  1. app.py +18 -37
app.py CHANGED
@@ -9,24 +9,13 @@ import torchvision
9
 
10
  from torch import nn
11
 
 
12
 
13
- def create_effnetb2_model(num_classes: int = 1,
14
- seed: int = 42):
15
- """Creates an EfficientNetB2 feature extractor model and transforms.
16
-
17
- Args:
18
- num_classes (int, optional): number of classes in the classifier head.
19
- Defaults to 3.
20
- seed (int, optional): random seed value. Defaults to 42.
21
-
22
- Returns:
23
- model (torch.nn.Module): EffNetB2 feature extractor model.
24
- transforms (torchvision.transforms): EffNetB2 image transforms.
25
- """
26
- # Create EffNetB2 pretrained weights, transforms and model
27
- weights = torchvision.models.AlexNet_Weights.DEFAULT
28
- transforms = weights.transforms()
29
- model = torchvision.models.alexnet(weights=weights)
30
 
31
  # Freeze all layers in base model
32
  for param in model.parameters():
@@ -34,31 +23,23 @@ def create_effnetb2_model(num_classes: int = 1,
34
 
35
  # Change classifier head with random seed for reproducibility
36
  torch.manual_seed(seed)
37
- model.classifier = nn.Sequential(
38
- nn.Dropout(p=0.2,),
39
- nn.Linear(in_features=9216, out_features=1),
40
- )
41
-
42
- return model, transforms
43
-
44
 
45
- # Setup class names
46
- class_names = ["Normal", "Pneumonia"]
 
 
 
 
47
 
48
- ### 2. Model and transforms preparation ###
49
 
50
- # Create EffNetB2 model
51
- effnetb2, effnetb2_transforms = create_effnetb2_model(
52
- num_classes=1, # len(class_names) would also work
53
- )
54
 
55
  # Load saved weights
56
- effnetb2.load_state_dict(
57
- torch.load(
58
- f="FL_global_model.pt",
59
- map_location=torch.device("cpu"), # load to CPU
60
- )
61
- )
62
 
63
 
64
  def predict(img) -> Tuple[Dict, float]:
 
9
 
10
  from torch import nn
11
 
12
+ from torchvision.models import densenet121
13
 
14
+ def create_densenet121_model(num_classes: int = 1, seed: int = 42):
15
+ """Creates a DenseNet121 model and transforms."""
16
+
17
+ # Create DenseNet121 model
18
+ model = densenet121(pretrained=False) # Set to False since we will be loading our own weights
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Freeze all layers in base model
21
  for param in model.parameters():
 
23
 
24
  # Change classifier head with random seed for reproducibility
25
  torch.manual_seed(seed)
26
+ model.classifier = nn.Linear(model.classifier.in_features, num_classes)
 
 
 
 
 
 
27
 
28
+ # You might want to use the appropriate transforms for densenet121 here
29
+ transforms = torchvision.transforms.Compose([
30
+ torchvision.transforms.Resize(224),
31
+ torchvision.transforms.ToTensor(),
32
+ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33
+ ])
34
 
35
+ return model, transforms
36
 
37
+ # Create densenet121 model
38
+ densenet, densenet_transforms = create_densenet121_model(num_classes=1)
 
 
39
 
40
  # Load saved weights
41
+ densenet.load_state_dict(torch.load("FL_global_model.pt", map_location=torch.device("cpu")))
42
+
 
 
 
 
43
 
44
 
45
  def predict(img) -> Tuple[Dict, float]: