koesan commited on
Commit
28fa4f2
·
verified ·
1 Parent(s): b4d1c34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -55,22 +55,28 @@ def preprocess_image(image_path):
55
  # Get original shape
56
  original_shape = img.shape
57
 
58
- # Resize to model input size (assuming 256x256)
59
- img_resized = cv2.resize(img, (256, 256))
60
 
61
  # Normalize to [0, 1]
62
  img_normalized = img_resized.astype(np.float32) / 255.0
63
 
64
- # Add batch and channel dimensions
65
- img_input = np.expand_dims(img_normalized, axis=0)
66
- img_input = np.expand_dims(img_input, axis=-1)
 
67
 
68
  return img_input, original_shape
69
 
70
  def postprocess_mask(mask, original_shape):
71
  """Postprocess segmentation mask"""
72
- # Remove batch dimension
73
- mask = np.squeeze(mask)
 
 
 
 
 
74
 
75
  # Resize back to original shape
76
  mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))
 
55
  # Get original shape
56
  original_shape = img.shape
57
 
58
+ # Resize to model input size: 160x160 (NOT 256x256!)
59
+ img_resized = cv2.resize(img, (160, 160))
60
 
61
  # Normalize to [0, 1]
62
  img_normalized = img_resized.astype(np.float32) / 255.0
63
 
64
+ # Add batch and channel dimensions in channels_first format (NCHW)
65
+ # Model expects: (batch, channels, height, width) = (None, 1, 160, 160)
66
+ img_input = np.expand_dims(img_normalized, axis=0) # (1, 160, 160)
67
+ img_input = np.expand_dims(img_input, axis=0) # (1, 1, 160, 160)
68
 
69
  return img_input, original_shape
70
 
71
  def postprocess_mask(mask, original_shape):
72
  """Postprocess segmentation mask"""
73
+ # Mask comes in channels_first format: (batch, channels, height, width)
74
+ # Squeeze to remove batch and channel dimensions
75
+ mask = np.squeeze(mask) # (160, 160)
76
+
77
+ # If mask still has extra dimensions, squeeze again
78
+ while len(mask.shape) > 2:
79
+ mask = np.squeeze(mask)
80
 
81
  # Resize back to original shape
82
  mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))