Depth-Jitter / inference.py
yokko123's picture
Upload folder using huggingface_hub
3c96a9e verified
import torch
import argparse
import torchvision.transforms as transforms
from PIL import Image
from q2l_labeller.pl_modules.query2label_train_module import Query2LabelTrainModule
from q2l_labeller.data.dataset import SeaThruAugmentation
# Image preprocessing shared by every inference call:
#   1. resize to a fixed 384x384 input resolution,
#   2. convert the PIL image to a tensor,
#   3. normalize each RGB channel with the ImageNet mean / std.
# NOTE(review): 384 is hard-coded; it presumably has to agree with the
# checkpoint's "image_dim" hyper-parameter — confirm for new checkpoints.
_preprocessing_steps = [
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
def load_model(checkpoint_path, num_classes, thresh=0.4):
    """Load a trained Query2Label model from a checkpoint, ready for inference.

    The network is rebuilt from the hyper-parameters stored in the checkpoint.
    Its weights are restored, it is switched to eval mode, and it is moved to
    CUDA when available (CPU otherwise).

    Args:
        checkpoint_path: Path to the ``.ckpt`` file.
        num_classes: Number of output classes. This must match the classifier
            head saved in the checkpoint, otherwise ``load_state_dict`` fails.
        thresh: Decision threshold handed to ``Query2LabelTrainModule``.
            Defaults to 0.4, the value previously hard-coded here.

    Returns:
        The model in eval mode, on the selected device.

    Raises:
        KeyError: If a required entry is missing from the checkpoint.
        RuntimeError: If the stored weights do not fit the rebuilt model.
    """
    # Decide the device once and reuse it for both loading and placement.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): on PyTorch >= 2.6, torch.load defaults to weights_only=True.
    # If this checkpoint pickles non-tensor objects, loading may need an
    # explicit weights_only=False (only for trusted files) — confirm.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    hparams = checkpoint["hyper_parameters"]

    model = Query2LabelTrainModule(
        data=None,  # Data module is not required for inference
        backbone_desc=hparams["backbone_desc"],
        conv_out_dim=hparams["conv_out_dim"],
        hidden_dim=hparams["hidden_dim"],
        num_encoders=hparams["num_encoders"],
        num_decoders=hparams["num_decoders"],
        num_heads=hparams["num_heads"],
        batch_size=hparams["batch_size"],
        image_dim=hparams["image_dim"],
        learning_rate=hparams["learning_rate"],
        momentum=hparams["momentum"],
        weight_decay=hparams["weight_decay"],
        n_classes=num_classes,  # Class count supplied by the caller
        thresh=thresh,
        use_cutmix=hparams["use_cutmix"],
        use_pos_encoding=hparams["use_pos_encoding"],
        loss=hparams["loss"],
    )
    model.load_state_dict(checkpoint["state_dict"])  # Restore trained weights
    model.eval()
    model.to(device)
    return model
def preprocess_image(image_path, seathru_transform=None, device=None):
    """Load an image from disk and turn it into a single-image input batch.

    Args:
        image_path: Path to the image file.
        seathru_transform: Optional callable invoked as
            ``seathru_transform(image_path, pil_image)``. It must return an image
            that the module-level ``transform`` pipeline accepts.
        device: Target device for the returned tensor. ``None`` (the default)
            selects CUDA when available and CPU otherwise, as before.

    Returns:
        The preprocessed image with a leading batch dimension of size 1,
        on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Close the file handle as soon as the pixels are decoded. convert() returns
    # an independent image, so it remains valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Compare against None explicitly. A transform object that defines
    # __len__/__bool__ must not be skipped just because it is "falsy".
    if seathru_transform is not None:
        image = seathru_transform(image_path, image)
    batch = transform(image).unsqueeze(0)  # Add batch dimension
    return batch.to(device)
def predict(model, image_tensor, class_labels, threshold=0.4):
    """Run the model on one preprocessed image and keep the confident labels.

    Args:
        model: Callable returning raw logits for ``image_tensor``. The output is
            assumed to be shaped ``(1, num_classes)`` (a one-image batch) —
            confirm for the model in use.
        image_tensor: Preprocessed input batch containing a single image.
        class_labels: Sequence mapping class index to display name.
        threshold: A class is reported when its sigmoid probability is
            strictly greater than this value. Defaults to 0.4.

    Returns:
        A list of ``(label, probability)`` tuples in class-index order.
    """
    with torch.no_grad():
        logits = model(image_tensor)
        # Remove only the batch axis. A bare .squeeze() would also collapse the
        # class axis for a single-class model, producing a 0-d array that
        # cannot be iterated or passed to len().
        probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    return [
        (class_labels[idx], prob)
        for idx, prob in enumerate(probs)
        if prob > threshold
    ]
def main():
    """Command-line entry point: classify a single image and print the results.

    Parses the CLI arguments, loads the checkpoint, and optionally applies the
    SeaThru transformation. It then prints every class that ``predict`` keeps.
    """
    parser = argparse.ArgumentParser(description="Inference with Depth-Jitter Query2Label model.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image.")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the model checkpoint (.ckpt).")
    parser.add_argument("--num_classes", type=int, required=True, help="Number of classes (4 for UTDAC, 290 for FathomNet).")
    parser.add_argument("--seathru", action="store_true", help="Apply SeaThru transformation if available.")
    # These SeaThru settings were hard-coded relative paths. They are now
    # configurable, and the defaults are unchanged.
    parser.add_argument("--seathru_params", type=str, default="parameters_train.json",
                        help="SeaThru parameters JSON (used with --seathru).")
    parser.add_argument("--depth_variance", type=str, default="depth_variance.json",
                        help="Depth-variance JSON (used with --seathru).")
    parser.add_argument("--seathru_threshold", type=float, default=7.5,
                        help="Value passed as `threshold` to SeaThruAugmentation (used with --seathru).")
    args = parser.parse_args()
    if args.num_classes <= 0:
        parser.error("--num_classes must be a positive integer")

    # Placeholder class names; replace with your dataset's class names.
    class_labels = [f"Class {i}" for i in range(args.num_classes)]

    model = load_model(args.checkpoint, args.num_classes)

    # Build the SeaThru transformation only when it is requested.
    seathru_transform = None
    if args.seathru:
        seathru_transform = SeaThruAugmentation(
            image_folder="",  # Not needed for inference
            depth_image_folder="",
            depth_npy_folder="",
            seathru_parameters_path=args.seathru_params,
            depth_variance_path=args.depth_variance,
            threshold=args.seathru_threshold,
        )

    image_tensor = preprocess_image(args.image, seathru_transform)
    predictions = predict(model, image_tensor, class_labels)

    print("\n🎯 Predictions:")
    if not predictions:
        print("- (no classes above the detection threshold)")
    for label, prob in predictions:
        print(f"- {label}: {prob:.4f}")


if __name__ == "__main__":
    main()