""" Batch prediction script for all 4 segmentation models. Runs inference on all images from filtered_images folder and saves masks as PNGs. Also generates a manifest.json for the comparison web tool. """ import os import sys import json import glob import torch import numpy as np from PIL import Image from torchvision import transforms from tqdm import tqdm # ── Paths ────────────────────────────────────────────────────────────────────── BASE_DIR = os.path.dirname(os.path.abspath(__file__)) INPUT_DIR = "/home/mohamed-ennhiri/Desktop/images_dept_69/69_/filtered_images" OUTPUT_DIR = os.path.join(BASE_DIR, "predictions") IMAGE_SIZE = 128 THRESHOLD = 0.5 # Model configs: (name, model_dir, model_class_import, checkpoint, needs_sigmoid) MODEL_CONFIGS = [ { "name": "SegNet (CNN)", "short_name": "segnet", "model_dir": os.path.join(BASE_DIR, "cnn_model"), "checkpoint": os.path.join(BASE_DIR, "cnn_model", "checkpoints", "best_model.pth"), "needs_sigmoid": False, # SegNet applies sigmoid in forward() }, { "name": "UNet", "short_name": "unet", "model_dir": os.path.join(BASE_DIR, "unet_model"), "checkpoint": os.path.join(BASE_DIR, "unet_model", "checkpoints", "best_model.pth"), "needs_sigmoid": True, }, { "name": "SegFormer-B0", "short_name": "segformer_b0", "model_dir": os.path.join(BASE_DIR, "vit_model"), "checkpoint": os.path.join(BASE_DIR, "vit_model", "checkpoints", "best_model.pth"), "needs_sigmoid": True, }, { "name": "SegFormer-B5", "short_name": "segformer_b5", "model_dir": os.path.join(BASE_DIR, "segformer_b5_model"), "checkpoint": os.path.join(BASE_DIR, "segformer_b5_model", "checkpoints", "best_model.pth"), "needs_sigmoid": True, }, ] def load_model(config, device): """Load a model from its config and checkpoint.""" model_dir = config["model_dir"] # Add model dir to path for imports if model_dir not in sys.path: sys.path.insert(0, model_dir) short = config["short_name"] if short == "segnet": from cnn_segmenter import SegNet model = SegNet(in_channels=3, out_channels=1) elif short == "unet": from unet_model import UNet model = UNet(in_channels=3, out_channels=1) elif short == "segformer_b0": from segformer_model import SegformerModel model = SegformerModel(pretrained_name="nvidia/mit-b0", num_classes=1) elif short == "segformer_b5": from segformer_model import SegformerModel model = SegformerModel(pretrained_name="nvidia/mit-b5", num_classes=1) # Load checkpoint checkpoint = torch.load(config["checkpoint"], map_location=device, weights_only=False) model.load_state_dict(checkpoint["model_state_dict"]) model.to(device) model.eval() # Remove model dir from path to avoid import conflicts if model_dir in sys.path: sys.path.remove(model_dir) print(f" ✓ Loaded {config['name']} from {os.path.basename(config['checkpoint'])}") return model def get_transform(): """Get the preprocessing transform (same as training — no normalization).""" return transforms.Compose([ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), ]) def predict_single(model, image_tensor, needs_sigmoid, device): """Run inference on a single image tensor.""" with torch.no_grad(): image_tensor = image_tensor.unsqueeze(0).to(device) output = model(image_tensor) if needs_sigmoid: output = torch.sigmoid(output) mask = (output > THRESHOLD).float() mask = mask.squeeze().cpu().numpy() return mask def save_mask(mask, save_path, original_size): """Save a binary mask as a PNG, resized to the original image dimensions.""" mask_uint8 = (mask * 255).astype(np.uint8) mask_img = Image.fromarray(mask_uint8, mode='L') mask_img = mask_img.resize(original_size, Image.NEAREST) mask_img.save(save_path) def count_images(input_dir): """Count the number of jpg images.""" return len(glob.glob(os.path.join(input_dir, "*.jpg"))) def main(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") # Create output directories for config in MODEL_CONFIGS: os.makedirs(os.path.join(OUTPUT_DIR, config["short_name"]), exist_ok=True) # Get image list image_paths = sorted(glob.glob(os.path.join(INPUT_DIR, "*.jpg"))) print(f"Found {len(image_paths)} images to process\n") if len(image_paths) == 0: print("No images found!") return transform = get_transform() # Load all models print("Loading models...") models = {} for config in MODEL_CONFIGS: try: models[config["short_name"]] = load_model(config, device) except Exception as e: print(f" ✗ Failed to load {config['name']}: {e}") print(f"\nLoaded {len(models)}/{len(MODEL_CONFIGS)} models\n") # Run predictions manifest = { "image_dir": INPUT_DIR, "prediction_dir": OUTPUT_DIR, "models": [c["name"] for c in MODEL_CONFIGS if c["short_name"] in models], "model_short_names": [c["short_name"] for c in MODEL_CONFIGS if c["short_name"] in models], "images": [] } for img_path in tqdm(image_paths, desc="Predicting"): img_name = os.path.basename(img_path) img_stem = os.path.splitext(img_name)[0] try: # Load original image original_img = Image.open(img_path).convert("RGB") original_size = original_img.size # (W, H) # Preprocess img_tensor = transform(original_img) # Predict with each model image_entry = { "filename": img_name, "original_path": f"source_images/{img_name}", "masks": {} } for config in MODEL_CONFIGS: short = config["short_name"] if short not in models: continue mask = predict_single(models[short], img_tensor, config["needs_sigmoid"], device) mask_filename = f"{img_stem}.png" mask_path = os.path.join(OUTPUT_DIR, short, mask_filename) save_mask(mask, mask_path, original_size) image_entry["masks"][short] = os.path.join("predictions", short, mask_filename) manifest["images"].append(image_entry) except Exception as e: print(f"\n ✗ Error processing {img_name}: {e}") # Save manifest manifest_path = os.path.join(OUTPUT_DIR, "manifest.json") with open(manifest_path, 'w') as f: json.dump(manifest, f, indent=2) print(f"\n✓ Predictions complete!") print(f" Images processed: {len(manifest['images'])}") print(f" Output directory: {OUTPUT_DIR}") print(f" Manifest: {manifest_path}") if __name__ == "__main__": main()