Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +3 -1
src/streamlit_app.py
CHANGED
|
@@ -7,6 +7,7 @@ import torch
|
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
import torchvision.models as models
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def rgb2lab2(r0, g0, b0):
|
|
@@ -329,7 +330,8 @@ def prepare_test_image(img, dim=150):
|
|
| 329 |
|
| 330 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 331 |
|
| 332 |
-
model_path =
|
|
|
|
| 333 |
|
| 334 |
test_model = load_model_for_inference(model_path, device)
|
| 335 |
|
|
|
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
import torchvision.models as models
|
| 10 |
+
import os
|
| 11 |
|
| 12 |
|
| 13 |
def rgb2lab2(r0, g0, b0):
|
|
|
|
| 330 |
|
| 331 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 332 |
|
| 333 |
+
model_path = os.path.join(os.getcwd(), 'Hyper_U_NET_pytorch-MAE-30Epoch.pth')
|
| 334 |
+
# model_path = "Hyper_U_NET_pytorch-MAE-30Epoch.pth"
|
| 335 |
|
| 336 |
test_model = load_model_for_inference(model_path, device)
|
| 337 |
|