rnmee commited on
Commit
969f85e
·
verified ·
1 Parent(s): 8d8a5c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -1,23 +1,28 @@
1
  import torch
2
  import streamlit as st
 
3
  import numpy as np
4
  import cv2
5
  from PIL import Image
6
  import torchvision.transforms as transforms
7
 
8
- # Load models
9
  @st.cache_resource
10
  def load_models():
11
- # Load classification model (corrected filename)
12
- classification_model = torch.load("classifier.pt", map_location=torch.device("cpu"))
 
 
 
 
13
  classification_model.eval()
14
-
15
  # Load segmentation model
16
  segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
17
  segmentation_model.eval()
18
-
19
  return classification_model, segmentation_model
20
 
 
21
  classification_model, segmentation_model = load_models()
22
 
23
  # Define preprocessing function for classification
 
1
  import torch
2
  import streamlit as st
3
+ import torchvision.models as models
4
  import numpy as np
5
  import cv2
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
 
 
9
  @st.cache_resource
10
  def load_models():
11
+ # Load classification model architecture
12
+ classification_model = models.resnet152(pretrained=False) # Use the same architecture as before
13
+ num_ftrs = classification_model.fc.in_features
14
+ classification_model.fc = torch.nn.Linear(num_ftrs, 5) # Assuming 5 DR stages
15
+ classification_checkpoint = torch.load("classifier.pt", map_location=torch.device("cpu"))
16
+ classification_model.load_state_dict(classification_checkpoint["model_state_dict"])
17
  classification_model.eval()
18
+
19
  # Load segmentation model
20
  segmentation_model = torch.load("best_unet_model.pth", map_location=torch.device("cpu"))
21
  segmentation_model.eval()
22
+
23
  return classification_model, segmentation_model
24
 
25
+
26
  classification_model, segmentation_model = load_models()
27
 
28
  # Define preprocessing function for classification