Dollaya Piumsuwan commited on
Commit
014812f
·
verified ·
1 Parent(s): eccdd97

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +5 -8
src/streamlit_app.py CHANGED
@@ -16,13 +16,10 @@ from pathlib import Path
16
  #########################
17
  # SETTINGS
18
  # ########################
19
- base_path = Path(__file__).resolve().parent.parent
20
- model_file = base_path / "models" / "momaclassifier_resnet50.pt"
21
- image_csv = base_path / "data" / "demo_artworks.csv"
22
- image_folder = base_path / "demo_images"
23
- model_file = str(model_file)
24
- image_csv = str(image_csv)
25
- image_folder = str(image_folder)
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  class_index = {"Drawing": 0, "Photograph": 1, "Print": 2}
28
 
@@ -49,7 +46,7 @@ metadata_df = load_metadata()
49
  @st.cache_resource
50
  def load_model():
51
  num_class = len(class_index)
52
- model = models.resnet50(pretrained=False)
53
  model.fc = nn.Linear(model.fc.in_features, num_class)
54
  model.load_state_dict(torch.load(model_file, map_location=device))
55
  model.to(device)
 
16
  #########################
17
  # SETTINGS
18
  # ########################
19
+ base_path = Path(__file__).resolve().parent
20
+ model_file = str(base_path / "models" / "momaclassifier_resnet50.pt")
21
+ image_csv = str(base_path / "data" / "demo_artworks.csv")
22
+ image_folder = str(base_path / "demo_images")
 
 
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  class_index = {"Drawing": 0, "Photograph": 1, "Print": 2}
25
 
 
46
  @st.cache_resource
47
  def load_model():
48
  num_class = len(class_index)
49
+ model = models.resnet50(weights=None)
50
  model.fc = nn.Linear(model.fc.in_features, num_class)
51
  model.load_state_dict(torch.load(model_file, map_location=device))
52
  model.to(device)