Subh775 commited on
Commit
3b68505
·
verified ·
1 Parent(s): dbf0507

Update Batch_Inference.py

Browse files
Files changed (1) hide show
  1. Batch_Inference.py +111 -118
Batch_Inference.py CHANGED
@@ -1,119 +1,112 @@
1
- import os
2
- from pathlib import Path
3
- import torch
4
- from PIL import Image
5
- import numpy as np
6
- import cv2
7
- import segmentation_models_pytorch as smp
8
- from huggingface_hub import hf_hub_download
9
- from tqdm import tqdm
10
-
11
- # --- 1. CONFIGURATION ---
12
- # --- IMPORTANT: UPDATE THESE VALUES ---
13
- HF_USERNAME = "Subh75"
14
- HF_ORGNAME="LeafNet75"
15
- MODEL_NAME = "Leaf-Annotate-v2"
16
- HF_MODEL_REPO_ID = f"{HF_ORGNAME}/{MODEL_NAME}"
17
-
18
- # --- Point this to your folder of unlabeled leaf images ---
19
- INPUT_IMAGE_DIR = "toast"
20
- # --- The script will save the generated masks here ---
21
- OUTPUT_MASK_DIR = "generated_masks"
22
-
23
- # --- Model & Prediction Configuration ---
24
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
- IMG_SIZE = 256
26
- CONFIDENCE_THRESHOLD = 0.5
27
-
28
-
29
- def load_model_from_hub(repo_id: str):
30
- """Loads the interactive segmentation model from the Hub."""
31
- print(f"Loading model '{repo_id}' from Hugging Face Hub...")
32
-
33
- model = smp.Unet(
34
- encoder_name="mobilenet_v2",
35
- encoder_weights=None,
36
- in_channels=4, # RGB + Scribble
37
- classes=1,
38
- )
39
-
40
- model_weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
41
- model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
42
- model.to(DEVICE)
43
- model.eval()
44
- print("Model loaded successfully.")
45
- return model
46
-
47
- def predict_scribble(model, pil_image, scribble_mask):
48
- """Runs inference using a scribble and returns a binary mask."""
49
- # Preprocess image and scribble
50
- img_resized = np.array(pil_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR))
51
- scribble_resized = cv2.resize(scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
52
-
53
- img_tensor = torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
54
- scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
55
-
56
- input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
57
-
58
- with torch.no_grad():
59
- output = model(input_tensor)
60
-
61
- probs = torch.sigmoid(output)
62
- binary_mask_resized = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
63
-
64
- # Resize mask back to original image size
65
- final_mask = cv2.resize(binary_mask_resized, (pil_image.width, pil_image.height), interpolation=cv2.INTER_NEAREST)
66
- return (final_mask * 255).astype(np.uint8)
67
-
68
- def main():
69
- """
70
- Main function to run batch inference on a folder of images.
71
- """
72
- if not os.path.isdir(INPUT_IMAGE_DIR):
73
- print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
74
- print("Please update the 'INPUT_IMAGE_DIR' variable in the script.")
75
- return
76
-
77
- os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
78
-
79
- model = load_model_from_hub(HF_MODEL_REPO_ID)
80
-
81
- image_files = [f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
82
-
83
- print(f"\nFound {len(image_files)} images to process.")
84
-
85
- for filename in tqdm(image_files, desc="Generating Masks"):
86
- image_path = os.path.join(INPUT_IMAGE_DIR, filename)
87
-
88
- try:
89
- original_image = Image.open(image_path).convert("RGB")
90
- h, w = original_image.height, original_image.width
91
-
92
- # --- Create a simple, centered scribble as a hint ---
93
- # This is a generic hint; for best results, scribbles should be more specific
94
- scribble = np.zeros((h, w), dtype=np.uint8)
95
- center_x, center_y = w // 2, h // 2
96
- length = int(min(w, h) * 0.2) # Scribble is 20% of the smallest dimension
97
-
98
- start_point = (center_x - length // 2, center_y)
99
- end_point = (center_x + length // 2, center_y)
100
- cv2.line(scribble, start_point, end_point, 255, thickness=25)
101
-
102
- # --- Predict and save the mask ---
103
- predicted_mask = predict_scribble(model, original_image, scribble)
104
-
105
- mask_image = Image.fromarray(predicted_mask)
106
-
107
- # Construct the output path
108
- base_name = Path(filename).stem
109
- output_path = os.path.join(OUTPUT_MASK_DIR, f"{base_name}_mask.png")
110
-
111
- mask_image.save(output_path)
112
-
113
- except Exception as e:
114
- print(f"\nCould not process {filename}. Error: {e}")
115
-
116
- print(f"\n✅ Done! All generated masks have been saved to the '{OUTPUT_MASK_DIR}' folder.")
117
-
118
- if __name__ == "__main__":
119
  main()
 
1
+ import os
2
+ from pathlib import Path
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ import cv2
7
+ import segmentation_models_pytorch as smp
8
+ from huggingface_hub import hf_hub_download
9
+ from tqdm import tqdm
10
+
11
+
12
+ HF_USERNAME = "Subh75"
13
+ HF_ORGNAME="LeafNet75"
14
+ MODEL_NAME = "Leaf-Annotate-v2"
15
+ HF_MODEL_REPO_ID = f"{HF_ORGNAME}/{MODEL_NAME}"
16
+
17
+ # --- Point this to your folder of unlabeled leaf images ---
18
+ INPUT_IMAGE_DIR = "toast"
19
+ # --- The script will save the generated masks here ---
20
+ OUTPUT_MASK_DIR = "masks"
21
+
22
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
+ IMG_SIZE = 256
24
+ CONFIDENCE_THRESHOLD = 0.5
25
+
26
+
27
+ def load_model_from_hub(repo_id: str):
28
+ """Loads the interactive segmentation model from the Hub."""
29
+ print(f"Loading model '{repo_id}' from Hugging Face Hub...")
30
+
31
+ model = smp.Unet(
32
+ encoder_name="mobilenet_v2",
33
+ encoder_weights=None,
34
+ in_channels=4, # RGB + Scribble
35
+ classes=1,
36
+ )
37
+
38
+ model_weights_path = hf_hub_download(repo_id=repo_id, filename="best_model.pth")
39
+ model.load_state_dict(torch.load(model_weights_path, map_location=DEVICE))
40
+ model.to(DEVICE)
41
+ model.eval()
42
+ print("Model loaded successfully.")
43
+ return model
44
+
45
+ def predict_scribble(model, pil_image, scribble_mask):
46
+ """Runs inference using a scribble and returns a binary mask."""
47
+ # Preprocess image and scribble
48
+ img_resized = np.array(pil_image.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.BILINEAR))
49
+ scribble_resized = cv2.resize(scribble_mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
50
+
51
+ img_tensor = torch.from_numpy(img_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
52
+ scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0
53
+
54
+ input_tensor = torch.cat([img_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)
55
+
56
+ with torch.no_grad():
57
+ output = model(input_tensor)
58
+
59
+ probs = torch.sigmoid(output)
60
+ binary_mask_resized = (probs > CONFIDENCE_THRESHOLD).float().squeeze().cpu().numpy()
61
+
62
+ final_mask = cv2.resize(binary_mask_resized, (pil_image.width, pil_image.height), interpolation=cv2.INTER_NEAREST)
63
+ return (final_mask * 255).astype(np.uint8)
64
+
65
+ def main():
66
+ """
67
+ Main function to run batch inference on a folder of images.
68
+ """
69
+ if not os.path.isdir(INPUT_IMAGE_DIR):
70
+ print(f"Error: Input directory not found at '{INPUT_IMAGE_DIR}'")
71
+ print("Please update the 'INPUT_IMAGE_DIR' variable in the script.")
72
+ return
73
+
74
+ os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
75
+
76
+ model = load_model_from_hub(HF_MODEL_REPO_ID)
77
+
78
+ image_files = [f for f in os.listdir(INPUT_IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
79
+
80
+ print(f"\nFound {len(image_files)} images to process.")
81
+
82
+ for filename in tqdm(image_files, desc="Generating Masks"):
83
+ image_path = os.path.join(INPUT_IMAGE_DIR, filename)
84
+
85
+ try:
86
+ original_image = Image.open(image_path).convert("RGB")
87
+ h, w = original_image.height, original_image.width
88
+
89
+ scribble = np.zeros((h, w), dtype=np.uint8)
90
+ center_x, center_y = w // 2, h // 2
91
+ length = int(min(w, h) * 0.2) # Scribble is 20% of the smallest dimension
92
+
93
+ start_point = (center_x - length // 2, center_y)
94
+ end_point = (center_x + length // 2, center_y)
95
+ cv2.line(scribble, start_point, end_point, 255, thickness=25)
96
+
97
+ predicted_mask = predict_scribble(model, original_image, scribble)
98
+
99
+ mask_image = Image.fromarray(predicted_mask)
100
+
101
+ base_name = Path(filename).stem
102
+ output_path = os.path.join(OUTPUT_MASK_DIR, f"{base_name}_mask.png")
103
+
104
+ mask_image.save(output_path)
105
+
106
+ except Exception as e:
107
+ print(f"\nCould not process {filename}. Error: {e}")
108
+
109
+ print(f"\n Done! All generated masks have been saved to the '{OUTPUT_MASK_DIR}' folder.")
110
+
111
+ if __name__ == "__main__":
 
 
 
 
 
 
 
112
  main()