File size: 2,157 Bytes
566e7ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch.hub
from PIL import Image

# --- IMPORTANT: TorchHub Dependencies ---
# Install the dependencies via:
# pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision

# Fetch the SegFormer++ network through TorchHub with a fixed configuration.
print("Loading SegFormer++ Model...")
# Keyword arguments handed through to the 'segformer_plusplus' hub entry point.
segformer_options = dict(
    pretrained=True,
    backbone='b5',
    tome_strategy='bsm_hq',
    checkpoint_url='https://mediastore.rz.uni-augsburg.de/get/yzE65lzm6N/',
    out_channels=19,
)
# First argument: the repository providing the entry point;
# second argument: the name of the entry point to call.
model = torch.hub.load(
    'KieDani/SegformerPlusPlus', 'segformer_plusplus', **segformer_options
)
# Call eval() so the model is in evaluation mode for the inference below.
model.eval()
print("Model loaded successfully.")

# Obtain the preprocessing callable exposed by the same repository under the
# 'data_transforms' hub entry point.
print("Loading data transformations...")
transform = torch.hub.load('KieDani/SegformerPlusPlus', 'data_transforms')
print("Transformations loaded successfully.")

# --- Example for Image Preparation and Inference ---
# A synthetic, uniformly red RGB image stands in for real data so the demo
# does not depend on any image file being present.
# With real data you would read a file instead, e.g.
# image = Image.open('path_to_your_image.jpg').convert('RGB')
print("Creating a dummy image for demonstration...")
dummy_size = (1300, 1300)
dummy_image = Image.new(mode='RGB', size=dummy_size, color='red')
print("Original image size:", dummy_image.size)

# Run the Hub-provided preprocessing on the image, then prepend a batch axis.
print("Applying transformations to the image...")
transformed_image = transform(dummy_image)
input_tensor = transformed_image.unsqueeze(0)  # new leading (batch) dimension
print("Transformed image tensor size:", input_tensor.shape)

# Forward pass with gradient tracking disabled.
print("Running inference...")
with torch.no_grad():
    output = model(input_tensor)

# The model output is assumed to be shaped [1, num_classes, height, width];
# dropping dimension 0 removes the batch axis of size 1.
output_tensor = torch.squeeze(output, 0)

print(f"\nInference completed. Output tensor size: {output_tensor.shape}")

# Take the index of the largest value along dimension 0 (the class axis under
# the assumption above) to obtain the final segmentation map.
segmentation_map = output_tensor.argmax(dim=0)
print(f"Size of the generated segmentation map: {segmentation_map.shape}")