File size: 4,050 Bytes
d72a22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""

DeepLabV3+ Inference Script for Coffee Leaf Rust

================================================



This script runs inference using a trained DeepLabV3+ model.

It generates BINARY masks for diseased areas:

    - 0   : Background / Healthy

    - 255 : Lesion (Rust)



Requirements:

    - torch

    - albumentations

    - segmentation-models-pytorch

    - opencv-python

"""

import os
import cv2
import torch
import numpy as np
import albumentations as A
import segmentation_models_pytorch as smp
from tqdm import tqdm

# ================= Configuration =================
# Path to the trained model checkpoint (a state_dict saved by the training script)
MODEL_PATH = "./checkpoints/deeplab_binary_best.pth"

# Input Directory (Extracted leaves)
INPUT_DIR  = "./data/extracted_leaves"

# Output Directory (Where masks will be saved)
OUTPUT_DIR = "./data/inference_deeplab"

# Threshold to binarize probabilities (usually 0.5)
THRESHOLD = 0.5
# Side length the network input is resized to; must match training resolution
IMG_SIZE = 512
# =================================================

def _load_model(device):
    """Build the DeepLabV3+ architecture and load the trained weights.

    Returns the model on *device* in eval mode, or ``None`` if the
    checkpoint is missing or cannot be loaded (an error is printed).
    """
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model checkpoint not found at {MODEL_PATH}")
        return None

    print("Loading DeepLabV3+ model...")
    model = smp.DeepLabV3Plus(
        encoder_name="resnet50",
        encoder_weights=None,  # Weights are loaded from checkpoint
        in_channels=3,
        classes=1
    )

    try:
        # weights_only=True: the checkpoint is a plain state_dict, so avoid
        # unpickling arbitrary objects (safer against tampered files).
        # Requires torch >= 1.13; drop the kwarg on older installs.
        state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model weights: {e}")
        return None

    model.to(device)
    model.eval()
    return model


def _build_transform():
    """Return the preprocessing pipeline (must match training transforms)."""
    # ImageNet normalization statistics, matching the resnet50 encoder.
    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(mean=mean, std=std),
    ], is_check_shapes=False)


def _predict_mask(model, device, transform, img):
    """Run the model on one BGR image and return a binary uint8 mask.

    The mask is 0 (background/healthy) or 255 (lesion), resized back to
    the original image dimensions with nearest-neighbor interpolation so
    it stays strictly binary.
    """
    original_h, original_w = img.shape[:2]

    # Convert BGR -> RGB and apply the training-time transforms.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    augmented = transform(image=img_rgb)

    # (H,W,C) float ndarray -> (1,C,H,W) tensor; from_numpy avoids an
    # extra copy compared to torch.tensor().
    input_tensor = (
        torch.from_numpy(augmented["image"])
        .permute(2, 0, 1)
        .unsqueeze(0)
        .float()
        .to(device)
    )

    with torch.no_grad():
        logits = model(input_tensor)
        probs = torch.sigmoid(logits)
        # Binarize: > THRESHOLD becomes 1.0, else 0.0
        mask = (probs > THRESHOLD).float().cpu().numpy()[0, 0]

    # Scale to 0/255 for a standard binary image.
    mask_uint8 = (mask * 255).astype(np.uint8)

    # Resize back to original size; nearest neighbor keeps it binary.
    return cv2.resize(mask_uint8, (original_w, original_h),
                      interpolation=cv2.INTER_NEAREST)


def run_inference():
    """Run DeepLabV3+ inference over INPUT_DIR and save binary masks.

    For every readable image in INPUT_DIR, writes a PNG mask with the
    same basename to OUTPUT_DIR (0 = background/healthy, 255 = lesion).
    Prints an error and returns early if the checkpoint or the input
    directory is missing.
    """
    # Ensure output directory exists
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    # --- Load Model ---
    model = _load_model(device)
    if model is None:
        return

    # --- Preprocessing ---
    transform = _build_transform()

    print(f"Processing images from: {INPUT_DIR}")

    # Guard against a missing input directory instead of letting
    # os.listdir raise an unhandled FileNotFoundError.
    if not os.path.isdir(INPUT_DIR):
        print(f"Error: Input directory not found at {INPUT_DIR}")
        return

    processed_count = 0

    # --- Inference Loop ---
    img_files = sorted(os.listdir(INPUT_DIR))
    for fname in tqdm(img_files, desc="Running Inference"):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg", ".tif")):
            continue

        img = cv2.imread(os.path.join(INPUT_DIR, fname))
        if img is None:
            print(f"Skipping unreadable file: {fname}")
            continue

        final_mask = _predict_mask(model, device, transform, img)

        # Masks are always written as PNG, regardless of input extension.
        out_name = os.path.splitext(fname)[0] + ".png"
        cv2.imwrite(os.path.join(OUTPUT_DIR, out_name), final_mask)
        processed_count += 1

    print(f"✅ Inference complete. Processed {processed_count} images.")
    print(f"Binary masks saved in: {OUTPUT_DIR}")

if __name__ == "__main__":
    run_inference()