# seg-models/pv_panel_models/predict_all.py
# Author: Mohamed-ENNHIRI
# Commit 52efd90: "Solar Panel Segmentation app for HF Spaces"
"""
Batch prediction script for all 4 segmentation models.
Runs inference on every *.jpg image in the filtered_images folder (INPUT_DIR)
and saves the predicted binary masks as PNGs, one sub-folder per model.
Also generates a manifest.json for the comparison web tool.
"""
import os
import sys
import json
import glob
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
# ── Paths ──────────────────────────────────────────────────────────────────────
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Folder whose *.jpg images are segmented. The default is the author's local
# dataset; set the PV_INPUT_DIR environment variable to use another folder.
INPUT_DIR = os.environ.get(
    "PV_INPUT_DIR",
    "/home/mohamed-ennhiri/Desktop/images_dept_69/69_/filtered_images",
)
# Per-model mask folders and manifest.json are written under this directory.
OUTPUT_DIR = os.path.join(BASE_DIR, "predictions")
# Images are resized to IMAGE_SIZE x IMAGE_SIZE before inference.
IMAGE_SIZE = 128
# Pixels whose (sigmoid) output is strictly greater than this become foreground.
THRESHOLD = 0.5
def _model_config(name, short_name, subdir, needs_sigmoid):
    """Build one MODEL_CONFIGS entry for a model that lives in BASE_DIR/<subdir>.

    The checkpoint is always <model_dir>/checkpoints/best_model.pth, so it is
    derived from model_dir instead of being spelled out twice.
    """
    model_dir = os.path.join(BASE_DIR, subdir)
    return {
        "name": name,  # human-readable name (printed, stored in the manifest)
        "short_name": short_name,  # key used by load_model and as output sub-folder
        "model_dir": model_dir,  # directory holding the model's Python module
        "checkpoint": os.path.join(model_dir, "checkpoints", "best_model.pth"),
        "needs_sigmoid": needs_sigmoid,  # apply sigmoid to the raw output?
    }


# Model configs: dicts with keys name, short_name, model_dir, checkpoint,
# needs_sigmoid (see _model_config).
MODEL_CONFIGS = [
    # needs_sigmoid=False: SegNet applies sigmoid in forward()
    _model_config("SegNet (CNN)", "segnet", "cnn_model", False),
    _model_config("UNet", "unet", "unet_model", True),
    _model_config("SegFormer-B0", "segformer_b0", "vit_model", True),
    _model_config("SegFormer-B5", "segformer_b5", "segformer_b5_model", True),
]
def load_model(config, device):
    """Build a model from its config entry and load its trained weights.

    The model's source directory (``config["model_dir"]``) is temporarily put
    at the front of ``sys.path`` so its module can be imported. It is removed
    again afterwards, even if construction or loading fails.

    Args:
        config: One MODEL_CONFIGS entry (uses ``name``, ``short_name``,
            ``model_dir`` and ``checkpoint``).
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``config["short_name"]`` is not a known architecture.
    """
    import importlib

    # short_name -> (module in model_dir, class name, constructor kwargs)
    builders = {
        "segnet": ("cnn_segmenter", "SegNet",
                   {"in_channels": 3, "out_channels": 1}),
        "unet": ("unet_model", "UNet",
                 {"in_channels": 3, "out_channels": 1}),
        "segformer_b0": ("segformer_model", "SegformerModel",
                         {"pretrained_name": "nvidia/mit-b0", "num_classes": 1}),
        "segformer_b5": ("segformer_model", "SegformerModel",
                         {"pretrained_name": "nvidia/mit-b5", "num_classes": 1}),
    }
    short = config["short_name"]
    if short not in builders:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(f"Unknown model short_name: {short!r}")
    module_name, class_name, kwargs = builders[short]

    model_dir = config["model_dir"]
    added_to_path = model_dir not in sys.path
    if added_to_path:
        sys.path.insert(0, model_dir)
    try:
        # B0 and B5 import a module with the same name (segformer_model) from
        # different directories. Drop any cached copy so this model's own
        # file is imported instead of the one loaded for a previous model.
        sys.modules.pop(module_name, None)
        module = importlib.import_module(module_name)
        model = getattr(module, class_name)(**kwargs)

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(config["checkpoint"], map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
    finally:
        # Undo our sys.path change (and only ours) to avoid shadowing later imports.
        if added_to_path and model_dir in sys.path:
            sys.path.remove(model_dir)

    model.to(device)
    model.eval()
    print(f" ✓ Loaded {config['name']} from {os.path.basename(config['checkpoint'])}")
    return model
def get_transform():
    """Return the preprocessing pipeline used before inference.

    Matches the training preprocessing: resize to IMAGE_SIZE x IMAGE_SIZE and
    convert to a tensor; no mean/std normalization is applied.
    """
    side = IMAGE_SIZE
    pipeline = [transforms.Resize((side, side)), transforms.ToTensor()]
    return transforms.Compose(pipeline)
def predict_single(model, image_tensor, needs_sigmoid, device, *, threshold=None):
    """Predict a binary mask for one preprocessed image.

    Args:
        model: Callable model; it is called on a batch holding this one image.
        image_tensor: Single image tensor without a batch dimension (as made
            by get_transform(); presumably (C, H, W) -- TODO confirm).
        needs_sigmoid: If True, apply torch.sigmoid to the raw output first
            (False for models whose forward() already returns probabilities).
        device: Device the input batch is moved to.
        threshold: Cut-off; pixels strictly greater become 1.0. Defaults to
            the module-level THRESHOLD.

    Returns:
        NumPy float array of 0.0/1.0 values with all size-1 dims squeezed out.
    """
    if threshold is None:
        threshold = THRESHOLD
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0).to(device)  # add batch dim of 1
        output = model(batch)
        if needs_sigmoid:
            output = torch.sigmoid(output)
        mask = (output > threshold).float()
    return mask.squeeze().cpu().numpy()
def save_mask(mask, save_path, original_size):
    """Save a binary mask as an 8-bit grayscale PNG at the original image size.

    Args:
        mask: 2-D array of 0/1 values (as returned by predict_single).
        save_path: Destination file path.
        original_size: (width, height) of the source image, i.e. PIL's
            ``Image.size`` order; the mask is resized to it.
    """
    # Scale 0/1 to 0/255 so the mask is visible when viewed as an image.
    mask_uint8 = (np.asarray(mask) * 255).astype(np.uint8)
    # A 2-D uint8 array is converted to mode "L" automatically; the explicit
    # mode= argument of fromarray is deprecated in recent Pillow releases.
    mask_img = Image.fromarray(mask_uint8)
    # NEAREST keeps the upscaled mask strictly binary (no grey edges).
    mask_img = mask_img.resize(original_size, Image.NEAREST)
    mask_img.save(save_path)
def count_images(input_dir):
    """Return how many ``*.jpg`` files sit directly inside *input_dir*."""
    pattern = os.path.join(input_dir, "*.jpg")
    return sum(1 for _ in glob.iglob(pattern))
def main():
    """Run every loadable model on every *.jpg in INPUT_DIR and write a manifest.

    For each image, one mask PNG per loaded model is written to
    OUTPUT_DIR/<short_name>/<image stem>.png (resized to the source image's
    size), and an entry is appended to OUTPUT_DIR/manifest.json. Models that
    fail to load are skipped. Images that fail to process are reported,
    counted, and left out of the manifest.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # One output folder per model, named after its short_name.
    for config in MODEL_CONFIGS:
        os.makedirs(os.path.join(OUTPUT_DIR, config["short_name"]), exist_ok=True)

    # Sorted so the manifest order is deterministic.
    image_paths = sorted(glob.glob(os.path.join(INPUT_DIR, "*.jpg")))
    print(f"Found {len(image_paths)} images to process\n")
    if not image_paths:
        print("No images found!")
        return

    transform = get_transform()

    # Load all models; one broken model should not stop the others.
    print("Loading models...")
    models = {}
    for config in MODEL_CONFIGS:
        try:
            models[config["short_name"]] = load_model(config, device)
        except Exception as e:
            print(f" ✗ Failed to load {config['name']}: {e}")
    print(f"\nLoaded {len(models)}/{len(MODEL_CONFIGS)} models\n")
    if not models:
        # Without this guard every image would be decoded for nothing and the
        # manifest would list images with empty "masks".
        print("No models could be loaded, nothing to predict.")
        return

    # Keep MODEL_CONFIGS order for both the manifest and the prediction loop.
    active_configs = [c for c in MODEL_CONFIGS if c["short_name"] in models]
    manifest = {
        "image_dir": INPUT_DIR,
        "prediction_dir": OUTPUT_DIR,
        "models": [c["name"] for c in active_configs],
        "model_short_names": [c["short_name"] for c in active_configs],
        "images": [],
    }

    failed = 0
    for img_path in tqdm(image_paths, desc="Predicting"):
        img_name = os.path.basename(img_path)
        img_stem = os.path.splitext(img_name)[0]
        mask_filename = f"{img_stem}.png"
        try:
            # The context manager releases the file handle promptly.
            with Image.open(img_path) as raw_img:
                original_img = raw_img.convert("RGB")
            original_size = original_img.size  # (W, H)
            img_tensor = transform(original_img)

            image_entry = {
                "filename": img_name,
                # NOTE(review): assumes the web tool serves source images
                # from source_images/; confirm.
                "original_path": f"source_images/{img_name}",
                "masks": {},
            }
            for config in active_configs:
                short = config["short_name"]
                mask = predict_single(models[short], img_tensor, config["needs_sigmoid"], device)
                save_mask(mask, os.path.join(OUTPUT_DIR, short, mask_filename), original_size)
                # Manifest paths are read by the web tool, so always use "/"
                # (os.path.join would emit backslashes on Windows).
                image_entry["masks"][short] = f"predictions/{short}/{mask_filename}"
            manifest["images"].append(image_entry)
        except Exception as e:
            # Best effort: report and continue with the next image.
            failed += 1
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f" ✗ Error processing {img_name}: {e}")

    manifest_path = os.path.join(OUTPUT_DIR, "manifest.json")
    with open(manifest_path, 'w', encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)

    print("\n✓ Predictions complete!")
    print(f" Images processed: {len(manifest['images'])}")
    if failed:
        print(f" Images failed: {failed}")
    print(f" Output directory: {OUTPUT_DIR}")
    print(f" Manifest: {manifest_path}")


if __name__ == "__main__":
    main()