Fiixq commited on
Commit
25c4656
·
verified ·
1 Parent(s): a7fbc52

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +3 -1
src/streamlit_app.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
  import torchvision.models as models
 
10
 
11
 
12
  def rgb2lab2(r0, g0, b0):
@@ -329,7 +330,8 @@ def prepare_test_image(img, dim=150):
329
 
330
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
331
 
332
- model_path = "Hyper_U_NET_pytorch-MAE-30Epoch.pth"
 
333
 
334
  test_model = load_model_for_inference(model_path, device)
335
 
 
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
  import torchvision.models as models
10
+ import os
11
 
12
 
13
  def rgb2lab2(r0, g0, b0):
 
330
 
331
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
332
 
333
+ model_path = os.path.join(os.getcwd(), 'Hyper_U_NET_pytorch-MAE-30Epoch.pth')
334
+ # model_path = "Hyper_U_NET_pytorch-MAE-30Epoch.pth"
335
 
336
  test_model = load_model_for_inference(model_path, device)
337