itsyogesh commited on
Commit
c32414c
·
verified ·
1 Parent(s): 2f0e3c8

Update load image

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -43,12 +43,33 @@ class GOSNormalize(object):
43
 
44
  transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
45
 
46
- def load_image(im_path, hypar):
47
- im = im_reader(im_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  im, im_shp = im_preprocess(im, hypar["cache_size"])
49
- im = torch.divide(im,255.0)
50
  shape = torch.from_numpy(np.array(im_shp))
51
- return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape
 
 
 
 
 
52
 
53
 
54
  def build_model(hypar,device):
 
43
 
44
  transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
45
 
46
+ def load_image(image, hypar):
47
+ """
48
+ Load and preprocess an image.
49
+ :param image: The image to load. This can be either a file path or a PIL.Image object.
50
+ :param hypar: Hyperparameters for preprocessing.
51
+ :return: A tuple of the preprocessed image tensor and its original shape.
52
+ """
53
+ # Check if the image is a file path or a PIL.Image object
54
+ if isinstance(image, str):
55
+ # If it's a file path, read the image from disk
56
+ im = im_reader(image)
57
+ elif isinstance(image, Image.Image):
58
+ # If it's a PIL.Image object, convert it to a NumPy array
59
+ im = np.array(image)
60
+ else:
61
+ raise TypeError("Unsupported image type")
62
+
63
+ # Preprocess the image
64
  im, im_shp = im_preprocess(im, hypar["cache_size"])
65
+ im = torch.divide(im, 255.0)
66
  shape = torch.from_numpy(np.array(im_shp))
67
+
68
+ # Normalize and add batch dimension
69
+ im = transform(im).unsqueeze(0)
70
+ shape = shape.unsqueeze(0) # Add batch dimension to shape
71
+
72
+ return im, shape
73
 
74
 
75
  def build_model(hypar,device):