# test_load_film_unet_4_stages.py
"""Smoke test for loading and running the 4-stage FILM U-Net via ``transformers``.

Run as a script::

    python test_load_film_unet_4_stages.py [--repo PATH] [--subfolder NAME]
        [--image PATH] [--output PATH] [--organ-id N] [--threshold T]

Steps:

1. Load the image processor and model from ``<repo>/<subfolder>`` with
   ``trust_remote_code=True``.
2. Segment one RGB image, conditioned on an integer organ ID.
3. Save the first mask returned by the processor's post-processing as an image file.

Every option defaults to the value this script originally hard-coded, so running
it with no arguments behaves as before. It exits with status 0 on success and 1
if the model cannot be loaded or the input image does not exist.

IMPORTANT: the correct model weights for the 4-stage U-Net must be placed in the
``unet_4_stages`` directory and named ``model.safetensors``, i.e.
/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/unet_4_stages/model.safetensors
"""

import argparse
import os
import sys

import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor, AutoModel

# Root folder of the repository that contains the model subfolders.
DEFAULT_REPO_PATH = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers"
# Subfolder holding the 4-stage U-Net config, processor config and weights.
DEFAULT_SUBFOLDER = "unet_4_stages"
DEFAULT_IMAGE_PATH = "/home/nicola/Downloads/45.png"
DEFAULT_OUTPUT_PATH = os.path.join(DEFAULT_REPO_PATH, "tmp_4_stages.png")
# Use an appropriate organ ID for your test case.
DEFAULT_ORGAN_ID = 4
# Threshold forwarded to the processor's post_process_semantic_segmentation.
DEFAULT_THRESHOLD = 0.7


def _parse_args(argv=None):
    """Parse command-line options; every option defaults to the original hard-coded value."""
    parser = argparse.ArgumentParser(
        description="Load the 4-stage FILM U-Net and segment a single image."
    )
    parser.add_argument("--repo", default=DEFAULT_REPO_PATH,
                        help="root folder of the model repository")
    parser.add_argument("--subfolder", default=DEFAULT_SUBFOLDER,
                        help="model subfolder inside the repository")
    parser.add_argument("--image", default=DEFAULT_IMAGE_PATH,
                        help="input image to segment")
    parser.add_argument("--output", default=DEFAULT_OUTPUT_PATH,
                        help="where to save the predicted mask")
    parser.add_argument("--organ-id", type=int, default=DEFAULT_ORGAN_ID,
                        help="organ ID passed to the model as `organ_id`")
    parser.add_argument("--threshold", type=float, default=DEFAULT_THRESHOLD,
                        help="threshold for post_process_semantic_segmentation")
    return parser.parse_args(argv)


def _load_processor_and_model(repo_or_path, subfolder):
    """Return ``(processor, model)`` loaded from *subfolder* of *repo_or_path*, model in eval mode."""
    proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder, trust_remote_code=True)
    model.eval()
    return proc, model


def main(argv=None):
    """Run the load -> infer -> save pipeline.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]`` (``None`` = use ``sys.argv``).

    Returns:
        Process exit status: 0 on success, 1 on a load failure or a missing input image.
    """
    args = _parse_args(argv)
    repo_or_path = os.path.abspath(args.repo)

    print("Loading 4-stage model and processor...")
    try:
        proc, model = _load_processor_and_model(repo_or_path, args.subfolder)
    except Exception as e:  # top-level boundary: report the failure and exit non-zero
        print(f"Error loading the 4-stage model: {e}", file=sys.stderr)
        print(f"Please ensure the 'model.safetensors' file in the '{args.subfolder}' directory "
              "is compatible with the 4-stage architecture.", file=sys.stderr)
        return 1
    print("Model loaded successfully.")

    # --- Inference ---
    image_path = args.image
    if not os.path.exists(image_path):
        print(f"Error: Image file not found at {image_path}", file=sys.stderr)
        return 1

    print(f"Loading image from {image_path}...")
    # convert() returns a new in-memory image, so the source file can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = proc(images=image, return_tensors="pt")

    organ_id = torch.tensor([args.organ_id])

    print("Running inference...")
    with torch.no_grad():
        out = model(**inputs, organ_id=organ_id)

    # Post-process to get the segmentation mask(s); only the first one is saved.
    masks = proc.post_process_semantic_segmentation(out, inputs, threshold=args.threshold, return_as_pil=True)

    output_path = args.output
    masks[0].save(output_path)
    print(f"✅ Test complete. Segmentation mask saved to {output_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())