| import torch |
| import argparse |
| import torchvision.transforms as transforms |
| from PIL import Image |
| from q2l_labeller.pl_modules.query2label_train_module import Query2LabelTrainModule |
| from q2l_labeller.data.dataset import SeaThruAugmentation |
|
|
| |
| transform = transforms.Compose([ |
| transforms.Resize((384, 384)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
| ]) |
|
|
| def load_model(checkpoint_path, num_classes): |
| """Load the trained Query2Label model from a checkpoint.""" |
| checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu") |
| |
| model = Query2LabelTrainModule( |
| data=None, |
| backbone_desc=checkpoint["hyper_parameters"]["backbone_desc"], |
| conv_out_dim=checkpoint["hyper_parameters"]["conv_out_dim"], |
| hidden_dim=checkpoint["hyper_parameters"]["hidden_dim"], |
| num_encoders=checkpoint["hyper_parameters"]["num_encoders"], |
| num_decoders=checkpoint["hyper_parameters"]["num_decoders"], |
| num_heads=checkpoint["hyper_parameters"]["num_heads"], |
| batch_size=checkpoint["hyper_parameters"]["batch_size"], |
| image_dim=checkpoint["hyper_parameters"]["image_dim"], |
| learning_rate=checkpoint["hyper_parameters"]["learning_rate"], |
| momentum=checkpoint["hyper_parameters"]["momentum"], |
| weight_decay=checkpoint["hyper_parameters"]["weight_decay"], |
| n_classes=num_classes, |
| thresh=0.4, |
| use_cutmix=checkpoint["hyper_parameters"]["use_cutmix"], |
| use_pos_encoding=checkpoint["hyper_parameters"]["use_pos_encoding"], |
| loss=checkpoint["hyper_parameters"]["loss"], |
| ) |
|
|
| model.load_state_dict(checkpoint["state_dict"]) |
| model.eval() |
| model.to("cuda" if torch.cuda.is_available() else "cpu") |
| return model |
|
|
|
|
| def preprocess_image(image_path, seathru_transform=None): |
| """Load and preprocess an image for inference.""" |
| image = Image.open(image_path).convert("RGB") |
|
|
| if seathru_transform: |
| image = seathru_transform(image_path, image) |
|
|
| image = transform(image).unsqueeze(0) |
| return image.to("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| def predict(model, image_tensor, class_labels): |
| """Run inference and return predicted labels and probabilities.""" |
| with torch.no_grad(): |
| outputs = model(image_tensor) |
| probs = torch.sigmoid(outputs).squeeze().cpu().numpy() |
|
|
| predictions = [(class_labels[i], probs[i]) for i in range(len(probs)) if probs[i] > 0.4] |
| return predictions |
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Inference with Depth-Jitter Query2Label model.") |
| parser.add_argument("--image", type=str, required=True, help="Path to the input image.") |
| parser.add_argument("--checkpoint", type=str, required=True, help="Path to the model checkpoint (.ckpt).") |
| parser.add_argument("--num_classes", type=int, required=True, help="Number of classes (4 for UTDAC, 290 for FathomNet).") |
| parser.add_argument("--seathru", action="store_true", help="Apply SeaThru transformation if available.") |
| args = parser.parse_args() |
|
|
| |
| class_labels = [f"Class {i}" for i in range(args.num_classes)] |
|
|
| |
| model = load_model(args.checkpoint, args.num_classes) |
|
|
| |
| seathru_transform = None |
| if args.seathru: |
| seathru_transform = SeaThruAugmentation( |
| image_folder="", |
| depth_image_folder="", |
| depth_npy_folder="", |
| seathru_parameters_path="parameters_train.json", |
| depth_variance_path="depth_variance.json", |
| threshold=7.5 |
| ) |
|
|
| |
| image_tensor = preprocess_image(args.image, seathru_transform) |
|
|
| |
| predictions = predict(model, image_tensor, class_labels) |
|
|
| |
| print("\n🎯 Predictions:") |
| for label, prob in predictions: |
| print(f"- {label}: {prob:.4f}") |
|
|
| if __name__ == "__main__": |
| main() |
|
|