CircleStar commited on
Commit
cfe30ee
·
verified ·
1 Parent(s): 14b719f

Update predict_utils.py

Browse files
Files changed (1) hide show
  1. predict_utils.py +12 -7
predict_utils.py CHANGED
@@ -3,8 +3,7 @@ import random
3
  import torch
4
  from PIL import Image
5
 
6
- from config import IMAGE_SIZE
7
- from data_utils import get_transform, load_charcoal_dataset
8
  from train_utils import load_model, get_runtime_device
9
 
10
 
@@ -19,7 +18,7 @@ def predict_uploaded_image(model_name: str, image: Image.Image):
19
  model, meta = load_model(model_name, device)
20
 
21
  class_names = meta["config"]["class_names"]
22
- transform = get_transform()
23
 
24
  image = image.convert("RGB")
25
  tensor = transform(image).unsqueeze(0).to(device)
@@ -48,17 +47,23 @@ def test_random_sample(model_name: str):
48
  device = get_runtime_device()
49
  model, meta = load_model(model_name, device)
50
 
51
- raw, class_names = load_charcoal_dataset()
52
- test_dataset = raw["test"]
 
53
 
54
  idx = random.randint(0, len(test_dataset) - 1)
55
  item = test_dataset[idx]
56
 
57
- image = item["image"].convert("RGB")
 
 
 
 
 
58
  label = int(item["label"])
59
  label_name = class_names[label]
60
 
61
- transform = get_transform()
62
  tensor = transform(image).unsqueeze(0).to(device)
63
 
64
  with torch.no_grad():
 
3
  import torch
4
  from PIL import Image
5
 
6
+ from data_utils import get_eval_transform, prepare_splits, get_class_names
 
7
  from train_utils import load_model, get_runtime_device
8
 
9
 
 
18
  model, meta = load_model(model_name, device)
19
 
20
  class_names = meta["config"]["class_names"]
21
+ transform = get_eval_transform()
22
 
23
  image = image.convert("RGB")
24
  tensor = transform(image).unsqueeze(0).to(device)
 
47
  device = get_runtime_device()
48
  model, meta = load_model(model_name, device)
49
 
50
+ splits = prepare_splits()
51
+ class_names = get_class_names()
52
+ test_dataset = splits["test"]
53
 
54
  idx = random.randint(0, len(test_dataset) - 1)
55
  item = test_dataset[idx]
56
 
57
+ image = item["image"]
58
+ if not isinstance(image, Image.Image):
59
+ image = Image.open(image)
60
+
61
+ image = image.convert("RGB")
62
+
63
  label = int(item["label"])
64
  label_name = class_names[label]
65
 
66
+ transform = get_eval_transform()
67
  tensor = transform(image).unsqueeze(0).to(device)
68
 
69
  with torch.no_grad():