File size: 2,109 Bytes
aee1a39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# test_load_film_unet_4_stages.py
import torch, os
from transformers import AutoModel, AutoConfig, AutoImageProcessor
from PIL import Image

# This script tests the 4-stage U-Net model.

# Root folder of the local FILMUnet2D transformers repository. It is passed to
# `from_pretrained` as a local path; the 4-stage checkpoint is read from the
# `unet_4_stages` subfolder beneath it.
repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")
subfolder_4_stages = "unet_4_stages"

# --- IMPORTANT ---
# The weights for the 4-stage U-Net must be placed in the 'unet_4_stages'
# subfolder of `repo_or_path`, in a file named 'model.safetensors', i.e.
#   <repo_or_path>/unet_4_stages/model.safetensors
# -----------------

print("Loading 4-stage model and processor...")
try:
    # trust_remote_code=True lets transformers import and run the custom
    # modelling/processing Python code shipped inside the repository folder,
    # so only point it at code you trust.
    proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
except Exception as e:
    print(f"Error loading the 4-stage model: {e}")
    print("Please ensure the 'model.safetensors' file in the 'unet_4_stages' directory is compatible with the 4-stage architecture.")
    # Exit with a non-zero status so shells/CI notice the failure. The
    # site-module `exit()` used previously exits with status 0 and is meant
    # only for the interactive prompt.
    raise SystemExit(1) from e

# Switch to evaluation mode (affects layers such as dropout / batch-norm, if
# the model contains any) before running inference.
model.eval()
print("Model loaded successfully.")

# --- Inference ---
image_path = "/home/nicola/Downloads/45.png"
# isfile (rather than exists) also rejects a directory at this path.
if not os.path.isfile(image_path):
    print(f"Error: Image file not found at {image_path}")
    # Non-zero exit status so the failure is visible to the caller.
    raise SystemExit(1)

print(f"Loading image from {image_path}...")
# Open inside a context manager so the file handle is released promptly;
# `convert` returns a new, fully loaded image that outlives the `with` block.
with Image.open(image_path) as src_image:
    image = src_image.convert("RGB")
# Preprocess into framework tensors ("pt" = PyTorch) for the model.
inputs = proc(images=image, return_tensors="pt")

# Organ identifier fed to the model together with the pixel inputs; change it
# to the ID that matches the anatomy in your test image.
organ_id = torch.tensor([4])

print("Running inference...")
with torch.no_grad():  # forward pass only; no autograd graph is needed
    out = model(organ_id=organ_id, **inputs)

# Convert raw outputs to masks using the processor's post-processing with a
# 0.7 threshold, requesting PIL images back (presumably one per input image).
masks = proc.post_process_semantic_segmentation(
    out,
    inputs,
    threshold=0.7,
    return_as_pil=True,
)

# Persist the first mask of the batch.
first_mask = masks[0]
output_path = (
    "/home/nicola/Downloads/FILMUnet2D_transformers_repo/"
    "FILMUnet2D_transformers/tmp_4_stages.png"
)
first_mask.save(output_path)

print(f"✅ Test complete. Segmentation mask saved to {output_path}")