Subh775 commited on
Commit
043f7ed
·
verified ·
1 Parent(s): af1c553

Update Batch_Inference.py

Browse files
Files changed (1) hide show
  1. Batch_Inference.py +49 -35
Batch_Inference.py CHANGED
@@ -10,13 +10,13 @@ from tqdm import tqdm
10
 
11
 
12
  HF_USERNAME = "Subh75"
13
- HF_ORGNAME="LeafNet75"
14
  MODEL_NAME = "Leaf-Annotate-v2"
15
  HF_MODEL_REPO_ID = f"{HF_ORGNAME}/{MODEL_NAME}"
16
 
17
- # Set to your original image and output folder respectively
18
- INPUT_IMAGE_DIR = "toast"
19
- OUTPUT_MASK_DIR = "masks"
20
 
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
  IMG_SIZE = 256
@@ -26,11 +26,11 @@ CONFIDENCE_THRESHOLD = 0.5
26
  def load_model_from_hub(repo_id: str):
27
  """Loads the interactive segmentation model from the Hub."""
28
  print(f"Loading model '{repo_id}' from Hugging Face Hub...")
29
-
30
  model = smp.Unet(
31
  encoder_name="mobilenet_v2",
32
  encoder_weights=None,
33
- in_channels=4, # RGB + Scribble
34
  classes=1,
35
  )
36
 
@@ -41,71 +41,85 @@ def load_model_from_hub(repo_id: str):
41
  print("Model loaded successfully.")
42
  return model
43
 
 
44
  def predict_scribble(model, pil_image, scribble_mask):
45
  """Runs inference using a scribble and returns a binary mask."""
46
- # Preprocess image and scribble
47
- img_resized = np.array(pil_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR))
48
- scribble_resized = cv2.resize(scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
 
 
 
49
 
50
- img_tensor = torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
51
- scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
 
 
 
 
52
 
53
  input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
54
-
55
  with torch.no_grad():
56
  output = model(input_tensor)
57
-
58
  probs = torch.sigmoid(output)
59
  binary_mask_resized = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
60
-
61
- final_mask = cv2.resize(binary_mask_resized, (pil_image.width, pil_image.height), interpolation=cv2.INTER_NEAREST)
 
 
62
  return (final_mask * 255).astype(np.uint8)
63
 
 
64
  def main():
65
- """
66
- Main function to run batch inference on a folder of images.
67
- """
68
  if not os.path.isdir(INPUT_IMAGE_DIR):
69
  print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
70
- print("Please update the 'INPUT_IMAGE_DIR' variable in the script.")
71
  return
72
-
73
  os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
74
-
75
  model = load_model_from_hub(HF_MODEL_REPO_ID)
76
-
77
- image_files = [f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
 
 
78
 
79
  print(f"\nFound {len(image_files)} images to process.")
80
-
81
  for filename in tqdm(image_files, desc="Generating Masks"):
82
  image_path = os.path.join(INPUT_IMAGE_DIR, filename)
83
-
84
  try:
85
  original_image = Image.open(image_path).convert("RGB")
86
  h, w = original_image.height, original_image.width
87
 
 
88
  scribble = np.zeros((h, w), dtype=np.uint8)
89
  center_x, center_y = w // 2, h // 2
90
- length = int(min(w, h) * 0.2)
91
-
92
  start_point = (center_x - length // 2, center_y)
93
  end_point = (center_x + length // 2, center_y)
94
  cv2.line(scribble, start_point, end_point, 255, thickness=25)
95
-
 
96
  predicted_mask = predict_scribble(model, original_image, scribble)
97
-
98
  mask_image = Image.fromarray(predicted_mask)
99
-
 
100
  base_name = Path(filename).stem
101
- output_path = os.path.join(OUTPUT_MASK_DIR, f"{base_name}_mask.png")
102
-
103
  mask_image.save(output_path)
104
 
105
  except Exception as e:
106
- print(f"\nCould not process {filename}. Error: {e}")
 
 
107
 
108
- print(f"\n Done! All generated masks have been saved to the '{OUTPUT_MASK_DIR}' folder.")
109
 
110
  if __name__ == "__main__":
111
- main()
 
10
 
11
 
12
  HF_USERNAME = "Subh75"
13
+ HF_ORGNAME = "LeafNet75"
14
  MODEL_NAME = "Leaf-Annotate-v2"
15
  HF_MODEL_REPO_ID = f"{HF_ORGNAME}/{MODEL_NAME}"
16
 
17
+ # Set to your original image and output folder respectively
18
+ INPUT_IMAGE_DIR = "newimgs/images"
19
+ OUTPUT_MASK_DIR = "newimgs/masks"
20
 
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
  IMG_SIZE = 256
 
26
  def load_model_from_hub(repo_id: str):
27
  """Loads the interactive segmentation model from the Hub."""
28
  print(f"Loading model '{repo_id}' from Hugging Face Hub...")
29
+
30
  model = smp.Unet(
31
  encoder_name="mobilenet_v2",
32
  encoder_weights=None,
33
+ in_channels=4, # RGB + Scribble
34
  classes=1,
35
  )
36
 
 
41
  print("Model loaded successfully.")
42
  return model
43
 
44
+
45
  def predict_scribble(model, pil_image, scribble_mask):
46
  """Runs inference using a scribble and returns a binary mask."""
47
+ img_resized = np.array(
48
+ pil_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR)
49
+ )
50
+ scribble_resized = cv2.resize(
51
+ scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
52
+ )
53
 
54
+ img_tensor = (
55
+ torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
56
+ )
57
+ scribble_tensor = (
58
+ torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
59
+ )
60
 
61
  input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
62
+
63
  with torch.no_grad():
64
  output = model(input_tensor)
65
+
66
  probs = torch.sigmoid(output)
67
  binary_mask_resized = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
68
+
69
+ final_mask = cv2.resize(
70
+ binary_mask_resized, (pil_image.width, pil_image.height), interpolation=cv2.INTER_NEAREST
71
+ )
72
  return (final_mask * 255).astype(np.uint8)
73
 
74
+
75
  def main():
76
+ """Main function to run batch inference on a folder of images."""
 
 
77
  if not os.path.isdir(INPUT_IMAGE_DIR):
78
  print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
 
79
  return
80
+
81
  os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
82
+
83
  model = load_model_from_hub(HF_MODEL_REPO_ID)
84
+
85
+ image_files = [
86
+ f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith((".png", ".jpg", ".jpeg"))
87
+ ]
88
 
89
  print(f"\nFound {len(image_files)} images to process.")
90
+
91
  for filename in tqdm(image_files, desc="Generating Masks"):
92
  image_path = os.path.join(INPUT_IMAGE_DIR, filename)
93
+
94
  try:
95
  original_image = Image.open(image_path).convert("RGB")
96
  h, w = original_image.height, original_image.width
97
 
98
+ # Create a dummy scribble (center line)
99
  scribble = np.zeros((h, w), dtype=np.uint8)
100
  center_x, center_y = w // 2, h // 2
101
+ length = int(min(w, h) * 0.2)
102
+
103
  start_point = (center_x - length // 2, center_y)
104
  end_point = (center_x + length // 2, center_y)
105
  cv2.line(scribble, start_point, end_point, 255, thickness=25)
106
+
107
+ # Predict mask
108
  predicted_mask = predict_scribble(model, original_image, scribble)
109
+
110
  mask_image = Image.fromarray(predicted_mask)
111
+
112
+ # Keep same base name, save as .png in OUTPUT_MASK_DIR
113
  base_name = Path(filename).stem
114
+ output_path = os.path.join(OUTPUT_MASK_DIR, f"{base_name}.png")
115
+
116
  mask_image.save(output_path)
117
 
118
  except Exception as e:
119
+ print(f"\n❌ Could not process {filename}. Error: {e}")
120
+
121
+ print(f"\n Done! Masks saved in '{OUTPUT_MASK_DIR}' with same names as input images.")
122
 
 
123
 
124
  if __name__ == "__main__":
125
+ main()