Spaces:
Sleeping
Sleeping
| """ | |
| Batch prediction script for all 4 segmentation models. | |
| Runs inference on all images from filtered_images folder and saves masks as PNGs. | |
| Also generates a manifest.json for the comparison web tool. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import glob | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
| from tqdm import tqdm | |
# ── Paths and inference settings ──────────────────────────────────────────────
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INPUT_DIR = "/home/mohamed-ennhiri/Desktop/images_dept_69/69_/filtered_images"
OUTPUT_DIR = os.path.join(BASE_DIR, "predictions")

# Side length (in pixels) of the square every input image is resized to.
IMAGE_SIZE = 128
# Cut-off used to binarize model outputs into a mask.
THRESHOLD = 0.5

# One dict per model, with the keys:
#   name          – display name
#   short_name    – identifier used for dispatch and for the output sub-folder
#   model_dir     – folder (under BASE_DIR) holding the model's source code
#   checkpoint    – <model_dir>/checkpoints/best_model.pth
#   needs_sigmoid – whether sigmoid must be applied to the model output
MODEL_CONFIGS = [
    {
        "name": display_name,
        "short_name": short_name,
        "model_dir": os.path.join(BASE_DIR, folder),
        "checkpoint": os.path.join(BASE_DIR, folder, "checkpoints", "best_model.pth"),
        "needs_sigmoid": needs_sigmoid,
    }
    for display_name, short_name, folder, needs_sigmoid in (
        # SegNet applies sigmoid in forward(), so none is added afterwards.
        ("SegNet (CNN)", "segnet", "cnn_model", False),
        ("UNet", "unet", "unet_model", True),
        ("SegFormer-B0", "segformer_b0", "vit_model", True),
        ("SegFormer-B5", "segformer_b5", "segformer_b5_model", True),
    )
]
def load_model(config, device):
    """Build the model described by ``config`` and load its checkpoint.

    Args:
        config: One model configuration dict. The keys read here are
            ``"short_name"``, ``"model_dir"``, ``"checkpoint"`` and ``"name"``.
        device: Device handed to ``torch.load(map_location=...)`` and
            ``model.to()``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``config["short_name"]`` is not one of the known models.
        ImportError: If the model's module cannot be imported from
            ``config["model_dir"]``.
    """
    import importlib

    # short_name -> (module inside model_dir, class name, constructor kwargs)
    builders = {
        "segnet": ("cnn_segmenter", "SegNet", {"in_channels": 3, "out_channels": 1}),
        "unet": ("unet_model", "UNet", {"in_channels": 3, "out_channels": 1}),
        "segformer_b0": (
            "segformer_model",
            "SegformerModel",
            {"pretrained_name": "nvidia/mit-b0", "num_classes": 1},
        ),
        "segformer_b5": (
            "segformer_model",
            "SegformerModel",
            {"pretrained_name": "nvidia/mit-b5", "num_classes": 1},
        ),
    }

    short = config["short_name"]
    if short not in builders:
        # Fail with a clear message instead of hitting an unbound `model` later.
        raise ValueError(
            f"Unknown model short_name {short!r}; expected one of {sorted(builders)}"
        )
    module_name, class_name, kwargs = builders[short]
    model_dir = config["model_dir"]

    # Put the model's folder first on sys.path so its module wins the lookup.
    sys.path.insert(0, model_dir)
    try:
        # Several model folders ship a module with the same name (both SegFormer
        # variants import `segformer_model`). A previously imported copy would be
        # served from the sys.modules cache regardless of sys.path, so evict it
        # to make sure the class comes from *this* model_dir.
        sys.modules.pop(module_name, None)
        model_cls = getattr(importlib.import_module(module_name), class_name)
        model = model_cls(**kwargs)

        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(config["checkpoint"], map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
    finally:
        # Always undo the sys.path change, even when import or loading failed.
        if model_dir in sys.path:
            sys.path.remove(model_dir)

    model.to(device)
    model.eval()
    print(f" ✓ Loaded {config['name']} from {os.path.basename(config['checkpoint'])}")
    return model
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    The image is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and converted to a
    tensor. No normalization step is included; per the original author's note
    this is meant to match what was used during training.
    """
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
def predict_single(model, image_tensor, needs_sigmoid, device):
    """Predict a binary mask for one image tensor.

    Args:
        model: Callable model; invoked on a batch containing the single image.
        image_tensor: Image tensor without a batch dimension (one is added here).
        needs_sigmoid: When true, sigmoid is applied to the model output before
            thresholding.
        device: Device the batch is moved to before the forward pass.

    Returns:
        NumPy array of 0.0/1.0 values, with all size-1 dimensions squeezed out.
    """
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0).to(device)
        raw = model(batch)
        scores = torch.sigmoid(raw) if needs_sigmoid else raw
        # Strictly-greater comparison against the global cut-off.
        binary = (scores > THRESHOLD).float()
        return binary.squeeze().cpu().numpy()
def save_mask(mask, save_path, original_size):
    """Write ``mask`` to ``save_path`` as a grayscale PNG at ``original_size``.

    Args:
        mask: Array that is scaled by 255 and cast to ``uint8``; expected to hold
            0/1 values so the result is black/white.
        save_path: Destination file path.
        original_size: Target size passed to ``Image.resize`` (the caller passes
            the source image's ``(W, H)``).
    """
    as_bytes = (mask * 255).astype(np.uint8)
    # Nearest-neighbour resampling keeps the mask strictly two-valued.
    resized = Image.fromarray(as_bytes, mode='L').resize(original_size, Image.NEAREST)
    resized.save(save_path)
def count_images(input_dir):
    """Return how many ``*.jpg`` files sit directly inside ``input_dir``."""
    pattern = os.path.join(input_dir, "*.jpg")
    return sum(1 for _ in glob.iglob(pattern))
def main():
    """Run every loadable model over the input images and write masks + manifest.

    Steps:
      1. Create one output sub-folder per configured model under ``OUTPUT_DIR``.
      2. Collect the ``*.jpg`` files in ``INPUT_DIR`` (stop if there are none).
      3. Load each configured model; a model that fails to load is reported and
         skipped. Stop if no model could be loaded at all.
      4. For every image, predict a mask with each loaded model and save it as
         ``<OUTPUT_DIR>/<short_name>/<image stem>.png``.
      5. Write ``manifest.json`` into ``OUTPUT_DIR``.

    An image whose processing raises is reported and left out of the manifest.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Create output directories
    for config in MODEL_CONFIGS:
        os.makedirs(os.path.join(OUTPUT_DIR, config["short_name"]), exist_ok=True)

    # Get image list
    image_paths = sorted(glob.glob(os.path.join(INPUT_DIR, "*.jpg")))
    print(f"Found {len(image_paths)} images to process\n")
    if not image_paths:
        print("No images found!")
        return

    transform = get_transform()

    # Load all models; a failure for one model must not prevent the others.
    print("Loading models...")
    models = {}
    for config in MODEL_CONFIGS:
        try:
            models[config["short_name"]] = load_model(config, device)
        except Exception as e:
            print(f" ✗ Failed to load {config['name']}: {e}")
    print(f"\nLoaded {len(models)}/{len(MODEL_CONFIGS)} models\n")

    if not models:
        # Without any model there is nothing to predict; bail out instead of
        # walking every image and reporting an empty run as a success.
        print("No models could be loaded; nothing to predict.")
        return

    # Only the configs whose model actually loaded, in configuration order.
    active_configs = [c for c in MODEL_CONFIGS if c["short_name"] in models]

    # Run predictions
    manifest = {
        "image_dir": INPUT_DIR,
        "prediction_dir": OUTPUT_DIR,
        "models": [c["name"] for c in active_configs],
        "model_short_names": [c["short_name"] for c in active_configs],
        "images": [],
    }

    for img_path in tqdm(image_paths, desc="Predicting"):
        img_name = os.path.basename(img_path)
        img_stem = os.path.splitext(img_name)[0]
        try:
            # Load original image; the context manager closes the file even if
            # decoding fails.
            with Image.open(img_path) as source_img:
                original_img = source_img.convert("RGB")
            original_size = original_img.size  # (W, H)

            # Preprocess
            img_tensor = transform(original_img)

            # Predict with each model
            image_entry = {
                "filename": img_name,
                "original_path": f"source_images/{img_name}",
                "masks": {},
            }
            for config in active_configs:
                short = config["short_name"]
                mask = predict_single(models[short], img_tensor, config["needs_sigmoid"], device)
                mask_filename = f"{img_stem}.png"
                mask_path = os.path.join(OUTPUT_DIR, short, mask_filename)
                save_mask(mask, mask_path, original_size)
                # Manifest paths always use "/" (like original_path above),
                # independent of the OS path separator.
                image_entry["masks"][short] = f"predictions/{short}/{mask_filename}"
            manifest["images"].append(image_entry)
        except Exception as e:
            print(f"\n ✗ Error processing {img_name}: {e}")

    # Save manifest
    manifest_path = os.path.join(OUTPUT_DIR, "manifest.json")
    with open(manifest_path, 'w') as f:
        json.dump(manifest, f, indent=2)

    print("\n✓ Predictions complete!")
    print(f" Images processed: {len(manifest['images'])}")
    print(f" Output directory: {OUTPUT_DIR}")
    print(f" Manifest: {manifest_path}")


if __name__ == "__main__":
    main()