|
|
|
|
|
import os
import sys

import torch
from PIL import Image
from transformers import AutoModel, AutoConfig, AutoImageProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Root of the local FILMUnet2D transformers checkout. The literal is already
# absolute, so abspath() only normalises it (and keeps it absolute if someone
# later edits it to a relative path).
repo_or_path = os.path.abspath(
    "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers"
)
# Subfolder inside repo_or_path that from_pretrained() loads the 4-stage
# processor/model from.
subfolder_4_stages = "unet_4_stages"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Loading 4-stage model and processor...")

# Load the image processor and the model from the 4-stage subfolder.
# trust_remote_code=True lets transformers execute the Python modeling code
# shipped with the checkpoint, so repo_or_path must point at a trusted source.
# Only the two loading calls sit inside the try so that unrelated errors are
# not misreported as checkpoint-compatibility problems.
try:
    proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
except Exception as e:
    # Top-level boundary: report on stderr and stop with a non-zero status.
    # (The previous bare exit() returned status 0, signalling success to the
    # shell, and relied on the site-only `exit` builtin.)
    print(f"Error loading the 4-stage model: {e}", file=sys.stderr)
    print("Please ensure the 'model.safetensors' file in the 'unet_4_stages' directory is compatible with the 4-stage architecture.", file=sys.stderr)
    sys.exit(1)

# Inference mode for torch modules (affects layers such as dropout or batch
# norm, if the model contains any).
model.eval()

print("Model loaded successfully.")
|
|
|
|
|
|
|
|
image_path = "/home/nicola/Downloads/45.png"

# isfile() rather than exists(): a directory at this path would pass exists()
# and only fail later inside Image.open with a less helpful error.
if not os.path.isfile(image_path):
    # Report on stderr and exit non-zero so callers/shells see the failure
    # (a bare exit() would report success with status 0).
    print(f"Error: Image file not found at {image_path}", file=sys.stderr)
    sys.exit(1)
|
|
|
|
|
print(f"Loading image from {image_path}...")

# Image.open is lazy and keeps the file handle open; the with-block closes it.
# convert() loads the pixel data and returns a new image, so `image` stays
# valid after the file is closed.
with Image.open(image_path) as img:
    image = img.convert("RGB")

# Preprocess into model inputs; return_tensors="pt" requests PyTorch tensors.
inputs = proc(images=image, return_tensors="pt")
|
|
|
|
|
|
|
|
# Integer conditioning label passed to the model next to the pixel inputs.
# One entry, matching the single image processed above.
# NOTE(review): the organ that id 4 denotes is defined by the checkpoint's
# label mapping, which is not visible here — confirm.
organ_id = torch.tensor([4], dtype=torch.long)

print("Running inference...")

# Forward pass only; skip autograd bookkeeping.
with torch.no_grad():
    out = model(organ_id=organ_id, **inputs)
|
|
|
|
|
|
|
|
# Convert raw model outputs to PIL masks via the processor's remote code.
# NOTE(review): threshold=0.7 is presumably a per-pixel probability cutoff —
# confirm against the processor implementation.
masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)

# Only one image was processed, so only the first mask is written out.
output_path = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp_4_stages.png"
first_mask = masks[0]
first_mask.save(output_path)

print(f"✅ Test complete. Segmentation mask saved to {output_path}")
|
|
|