"""
DeepLabV3+ Inference Script for Coffee Leaf Rust
================================================
This script runs inference using a trained DeepLabV3+ model.
It generates BINARY masks for diseased areas:
- 0 : Background / Healthy
- 255 : Lesion (Rust)
Requirements:
- torch
- albumentations
- segmentation-models-pytorch
- opencv-python
"""
import os
import cv2
import torch
import numpy as np
import albumentations as A
import segmentation_models_pytorch as smp
from tqdm import tqdm
# ================= Configuration =================
# Path to the trained model checkpoint (state_dict saved during training)
MODEL_PATH = "./checkpoints/deeplab_binary_best.pth"
# Input Directory (Extracted leaves)
INPUT_DIR = "./data/extracted_leaves"
# Output Directory (Where masks will be saved)
OUTPUT_DIR = "./data/inference_deeplab"
# Threshold to binarize probabilities (usually 0.5)
THRESHOLD = 0.5
# Side length (pixels) images are resized to before entering the network;
# presumably matches the resolution used during training — confirm.
IMG_SIZE = 512
# =================================================
def _load_model(device):
    """Build the DeepLabV3+ architecture and load trained weights.

    Args:
        device: Torch device string ("cuda" or "cpu") to map weights onto.

    Returns:
        The model in eval mode on ``device``, or ``None`` if the checkpoint
        is missing or its weights fail to load (errors are printed, not raised).
    """
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model checkpoint not found at {MODEL_PATH}")
        return None
    print("Loading DeepLabV3+ model...")
    model = smp.DeepLabV3Plus(
        encoder_name="resnet50",
        encoder_weights=None,  # Weights come from the checkpoint, not ImageNet
        in_channels=3,
        classes=1,  # Single-channel logit: lesion vs. background
    )
    try:
        # NOTE(review): torch.load unpickles arbitrary objects; if the
        # checkpoint could come from an untrusted source, pass
        # weights_only=True (available in recent torch versions).
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    except Exception as e:
        print(f"Error loading model weights: {e}")
        return None
    model.to(device)
    model.eval()
    return model


def _predict_mask(model, transform, img, device):
    """Segment one BGR image and return its binary lesion mask.

    Args:
        model: DeepLabV3+ model in eval mode.
        transform: Albumentations pipeline (resize + normalize).
        img: BGR image as loaded by cv2.imread.
        device: Torch device the model lives on.

    Returns:
        uint8 mask at the image's original resolution — 0 for
        background/healthy, 255 for lesion (rust).
    """
    original_h, original_w = img.shape[:2]
    # Model was trained on RGB; cv2 loads BGR.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    chw = transform(image=img_rgb)["image"]
    # (H,W,C) -> (1,C,H,W); from_numpy avoids an extra copy vs. torch.tensor().
    input_tensor = torch.from_numpy(chw).permute(2, 0, 1).unsqueeze(0).float().to(device)
    with torch.no_grad():
        probs = torch.sigmoid(model(input_tensor))
    mask = (probs > THRESHOLD).float().cpu().numpy()[0, 0]
    mask_uint8 = (mask * 255).astype(np.uint8)
    # Nearest-neighbor keeps the mask strictly binary after resizing.
    return cv2.resize(mask_uint8, (original_w, original_h), interpolation=cv2.INTER_NEAREST)


def run_inference():
    """Generate binary lesion masks for every image in INPUT_DIR.

    Loads the checkpoint at MODEL_PATH, runs each readable
    .png/.jpg/.jpeg/.tif through the network, and writes one PNG mask
    per image (same stem) into OUTPUT_DIR. Prints progress and exits
    early (without raising) if the checkpoint cannot be loaded.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    model = _load_model(device)
    if model is None:
        return

    # Preprocessing must match training: resize + ImageNet normalization.
    transform = A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ], is_check_shapes=False)

    print(f"Processing images from: {INPUT_DIR}")
    processed_count = 0
    for fname in tqdm(sorted(os.listdir(INPUT_DIR)), desc="Running Inference"):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg", ".tif")):
            continue
        img = cv2.imread(os.path.join(INPUT_DIR, fname))
        if img is None:
            print(f"Skipping unreadable file: {fname}")
            continue
        final_mask = _predict_mask(model, transform, img, device)
        # Masks are always written as lossless PNG, regardless of input format.
        out_name = os.path.splitext(fname)[0] + ".png"
        cv2.imwrite(os.path.join(OUTPUT_DIR, out_name), final_mask)
        processed_count += 1
    print(f"✅ Inference complete. Processed {processed_count} images.")
    print(f"Binary masks saved in: {OUTPUT_DIR}")
if __name__ == "__main__":
    # Run inference only when executed as a script, not on import.
    run_inference()