| import torch |
| import os |
| from PIL import Image |
| from model import FM_PhysMamba_UNET, ODESolver |
| from utils import predict_large_image_vectorized, preprocess_single_image |
| from data.utils import restandardize_tensor |
| import matplotlib.pyplot as plt |
|
|
| |
| |
| |
| |
# Pretrained checkpoint handed to load_dehazing_model by the script entry point
# (the directory name suggests DENSE_HAZE training data — confirm).
WEIGHT_PATH = "checkpoints/final_weights/DENSE_HAZE/pytorch_model.bin"

# Hazy image to dehaze, and the folder that receives the comparison figure.
INPUT_IMG_PATH = "test_imgs/test_images.jpeg"
OUTPUT_DIR = "test_imgs/results"

# Run on the GPU when PyTorch reports CUDA support, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
def load_dehazing_model(weight_path, device, model_size="small"):
    """Build an FM_PhysMamba_UNET, load pretrained weights, and set eval mode.

    Args:
        weight_path: Path to a state-dict file readable by ``torch.load``.
        device: Device the model is moved to and the checkpoint tensors are
            mapped onto (e.g. ``"cuda"`` or ``"cpu"``).
        model_size: Variant name passed to ``FM_PhysMamba_UNET``. Defaults to
            ``"small"`` (the previously hard-coded value); it must match the
            architecture the checkpoint was saved from.

    Returns:
        The model, on ``device``, in evaluation mode.

    Raises:
        FileNotFoundError: If ``weight_path`` is not an existing regular file.
    """
    # Validate the path *before* constructing the network so a bad path fails
    # fast instead of after a full (possibly GPU) model allocation. isfile()
    # also rejects directories, which exists() would have let through to
    # torch.load.
    if not os.path.isfile(weight_path):
        raise FileNotFoundError(f"Weights not found at: {weight_path}")

    print(f"Loading model to {device}...")
    model = FM_PhysMamba_UNET(model_size).to(device)

    # weights_only=True limits unpickling to tensors and plain containers,
    # so loading a checkpoint cannot execute arbitrary pickled code.
    checkpoint = torch.load(weight_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    model.eval()
    return model
|
|
def _to_display_array(tensor):
    """Convert a model-space tensor into a NumPy array suitable for ``imshow``."""
    # squeeze(0) drops a leading size-1 dim and permute(1, 2, 0) moves dim 0
    # to the end; presumably input is (1, C, H, W) -> output (H, W, C) — confirm.
    return (
        restandardize_tensor(tensor.detach().squeeze(0).cpu())
        .permute(1, 2, 0)
        .numpy()
    )


def run_inference(model, img_path, device, output_dir, tile_size=256,
                  overlap_ratio=0.25, show=True):
    """Dehaze one image with tiled inference and save a side-by-side figure.

    Args:
        model: Dehazing network; it is wrapped in an ``ODESolver``.
        img_path: Path of the input image; it is converted to RGB.
        device: Device passed to preprocessing and to the tiled predictor.
        output_dir: Folder for the comparison figure; created if missing.
        tile_size: Tile size forwarded to ``predict_large_image_vectorized``.
        overlap_ratio: Overlap ratio forwarded to
            ``predict_large_image_vectorized``.
        show: If True (the default, matching the original behavior), call
            ``plt.show()`` after saving.

    Returns:
        The path of the saved ``comparison_result.png``.
    """
    solver = ODESolver(model)

    # Open via a context manager so the source file handle is released right
    # away; convert() returns a new image that does not depend on the file.
    with Image.open(img_path) as img:
        raw_img = img.convert("RGB")
    input_tensor = preprocess_single_image(raw_img, device=device)

    print(f"Starting tiled inference for {img_path}...")
    with torch.no_grad():
        restored_tensor = predict_large_image_vectorized(
            solver=solver,
            full_img_tensor=input_tensor,
            device=device,
            tile_size=tile_size,
            overlap_ratio=overlap_ratio,
        )

    panels = (
        (_to_display_array(input_tensor), "Original Hazy Input"),
        (_to_display_array(restored_tensor), "FM-PhysMamba Restored"),
    )

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "comparison_result.png")

    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    try:
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis("off")
        # Save this specific figure rather than whatever pyplot considers
        # "current", which could differ if other figures are open.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Success! Result saved to: {save_path}")
        if show:
            plt.show()
    finally:
        # Release the figure so repeated calls (or non-interactive backends,
        # where show() is a no-op) do not accumulate open figures.
        plt.close(fig)
    return save_path
|
|
if __name__ == "__main__":
    # Script entry point: load the checkpoint once, then dehaze the configured
    # input image and write the comparison figure.
    dehazing_model = load_dehazing_model(weight_path=WEIGHT_PATH, device=DEVICE)
    run_inference(
        model=dehazing_model,
        img_path=INPUT_IMG_PATH,
        device=DEVICE,
        output_dir=OUTPUT_DIR,
    )