File size: 2,109 Bytes
aee1a39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# test_load_film_unet_4_stages.py
import torch, os
from transformers import AutoModel, AutoConfig, AutoImageProcessor
from PIL import Image
# Smoke test for the 4-stage U-Net: load its processor and model through the
# transformers Auto* API, then segment one image and save the mask.

# Root folder of the local repository checkout (edit to match your machine).
_REPO_ROOT = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers"
repo_or_path = os.path.abspath(_REPO_ROOT)

# Subfolder of the repository that holds the 4-stage model files.
subfolder_4_stages = "unet_4_stages"

# --- IMPORTANT ---
# The 4-stage U-Net weights must be present as 'model.safetensors' inside the
# 'unet_4_stages' subfolder of the repository root, i.e. at:
#   /home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/unet_4_stages/model.safetensors
# -----------------
# Load the image processor and the 4-stage model from the local repository.
# trust_remote_code=True allows transformers to execute Python code shipped
# with the repository (its custom model/processor classes).
print("Loading 4-stage model and processor...")
try:
    # Keep the try body to the calls that can actually fail while loading.
    proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
except Exception as e:
    # Script boundary: report any loading failure, then stop with a non-zero
    # exit status so shells/CI see the run as failed (bare exit() would
    # report success with status 0).
    print(f"Error loading the 4-stage model: {e}")
    print("Please ensure the 'model.safetensors' file in the 'unet_4_stages' directory is compatible with the 4-stage architecture.")
    raise SystemExit(1)
else:
    # Evaluation mode: layers such as dropout / batch-norm switch to their
    # inference behaviour.
    model.eval()
    print("Model loaded successfully.")
# --- Inference ---
image_path = "/home/nicola/Downloads/45.png"
# isfile (not exists) so a directory path also gets the clear error message.
if not os.path.isfile(image_path):
    print(f"Error: Image file not found at {image_path}")
    # Non-zero status so a failed run is not mistaken for success.
    raise SystemExit(1)
print(f"Loading image from {image_path}...")
# convert() returns a new, fully loaded image, so the source file handle can
# be closed as soon as the RGB copy exists.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")
# Preprocess into PyTorch tensors ("pt") for the model.
inputs = proc(images=image, return_tensors="pt")
# Organ ID forwarded to the model together with the image inputs; pick the ID
# that matches your test case (presumably selects the target organ — confirm
# against the model code). One entry, matching the single input image.
_TEST_ORGAN_ID = 4
organ_id = torch.tensor([_TEST_ORGAN_ID], dtype=torch.long)

print("Running inference...")
# Forward pass only: no gradients are needed.
with torch.no_grad():
    out = model(**inputs, organ_id=organ_id)
# Convert the raw model output into segmentation masks returned as PIL images.
# The threshold is passed to the processor's post-processing (presumably used
# to binarize per-pixel scores — confirm in the processor code).
_MASK_THRESHOLD = 0.7
masks = proc.post_process_semantic_segmentation(
    out,
    inputs,
    threshold=_MASK_THRESHOLD,
    return_as_pil=True,
)

# Only one image was processed, so the first mask is the result to store.
output_path = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp_4_stages.png"
first_mask = masks[0]
first_mask.save(output_path)
print(f"✅ Test complete. Segmentation mask saved to {output_path}")
|