# Using SegFormer++ without OpenMMLab

## Building a Model

- Use [build_model.py](../../model_without_OpenMMLab/segformer_plusplus/build_model.py) to build preset and custom SegFormer++ models

Navigate to the `model_without_OpenMMLab` directory, then run:
```python
from segformer_plusplus.build_model import create_model

# backbone: choose from ['b0', 'b1', 'b2', 'b3', 'b4', 'b5']
# tome_strategy: choose from ['bsm_hq', 'bsm_fast', 'n2d_2x2']
out_channels = 19  # number of classes, e.g. 19 for Cityscapes

model = create_model('b5', 'bsm_hq', out_channels=out_channels, pretrained=True)
```
Running this code snippet yields our SegFormer++<sub>HQ</sub> model pretrained on ImageNet.

- Use [random_benchmark.py](../../model_without_OpenMMLab/segformer_plusplus/random_benchmark.py) to evaluate a model's inference speed (FPS)

```python
from segformer_plusplus.random_benchmark import random_benchmark

# Benchmarks the model on randomly generated inputs and reports the FPS
v = random_benchmark(model)
```

## Loading a Checkpoint

[Checkpoints](../../README.md) are provided in this repository.
They can be loaded and integrated into the model via PyTorch:
```python
import torch

checkpoint_path = "path_to_your_checkpoint.pth"  # downloaded checkpoint (links in the README)

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
```
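Depending on how a checkpoint file was saved, the weights may be nested under a key such as `"state_dict"` rather than stored as a plain state dict (this wrapping is an assumption, not something this repo states; inspect `checkpoint.keys()` to see your file's layout). A small hypothetical helper can handle both cases:

```python
def unwrap_state_dict(checkpoint):
    """Return the raw weight mapping from a loaded checkpoint.

    Handles both a plain state dict and one nested under a
    'state_dict' key (a common convention, not guaranteed here).
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint

# Both layouts resolve to the same mapping:
plain = {"decode_head.weight": 0}
wrapped = {"state_dict": plain}
```

With such a helper, `model.load_state_dict(unwrap_state_dict(checkpoint))` works regardless of the layout.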
An example can be found in [start_cityscape_benchmark.py](../../model_without_OpenMMLab/segformer_plusplus/start_cityscape_benchmark.py).

## Image Preparation

Images can be loaded via PIL and converted to RGB:

```python
from PIL import Image

image_path = "path_to_your_image.png"
image = Image.open(image_path).convert("RGB")
```

After that, convert the image into a torch tensor:

```python
import torch
import numpy as np

# Scale pixel values to [0, 1] and rearrange to (1, C, H, W)
img_tensor = torch.from_numpy(np.array(image) / 255.0)
img_tensor = img_tensor.permute(2, 0, 1).float().unsqueeze(0)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
img_tensor = img_tensor.to(device)
```
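Note that the preprocessing above only scales pixels to [0, 1]. If the checkpoint you use was trained with additional per-channel mean/std normalization (e.g. ImageNet statistics; this is an assumption to verify against your training configuration, not something stated here), apply it after scaling:

```python
import numpy as np

# Hypothetical ImageNet channel statistics; verify against your training config.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

img = np.full((4, 4, 3), 0.5)   # stand-in for np.array(image) / 255.0
img_norm = (img - mean) / std   # broadcasts over the channel axis
```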

Now we can load the model:

```python
from segformer_plusplus.build_model import create_model

out_channels = 19
model = create_model(
    backbone='b5',
    tome_strategy='bsm_hq',
    out_channels=out_channels,
    pretrained=False,
).to(device)

model.load_state_dict(torch.load("path_to_checkpoint", map_location=device))
model.eval()
```

Inference:
```python
with torch.no_grad():
    output = model(img_tensor)
    # Per-pixel class index with the highest score
    segmentation_map = torch.argmax(output, dim=1).squeeze().cpu().numpy()
```

Visualize the results (using the Cityscapes classes):

```python
import numpy as np

# Official Cityscapes colors for train IDs 0-18
cityscapes_colors = np.array([
    [128,  64, 128],  # 0: road
    [244,  35, 232],  # 1: sidewalk
    [ 70,  70,  70],  # 2: building
    [102, 102, 156],  # 3: wall
    [190, 153, 153],  # 4: fence
    [153, 153, 153],  # 5: pole
    [250, 170,  30],  # 6: traffic light
    [220, 220,   0],  # 7: traffic sign
    [107, 142,  35],  # 8: vegetation
    [152, 251, 152],  # 9: terrain
    [ 70, 130, 180],  # 10: sky
    [220,  20,  60],  # 11: person
    [255,   0,   0],  # 12: rider
    [  0,   0, 142],  # 13: car
    [  0,   0,  70],  # 14: truck
    [  0,  60, 100],  # 15: bus
    [  0,  80, 100],  # 16: train
    [  0,   0, 230],  # 17: motorcycle
    [119,  11,  32],  # 18: bicycle
], dtype=np.uint8)

# Map each class index to its corresponding color
height, width = segmentation_map.shape
color_image = np.zeros((height, width, 3), dtype=np.uint8)
for class_index in range(len(cityscapes_colors)):
    color_image[segmentation_map == class_index] = cityscapes_colors[class_index]
```
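The per-class loop can also be written as a single NumPy fancy-indexing operation, which maps every pixel's class index to its RGB color in one step. A self-contained sketch with a tiny two-class palette and index map:

```python
import numpy as np

palette = np.array([[128, 64, 128], [244, 35, 232]], dtype=np.uint8)  # road, sidewalk
seg = np.array([[0, 1], [1, 0]])  # stand-in for segmentation_map

# Indexing a (K, 3) palette with an (H, W) index array yields an (H, W, 3) image
color = palette[seg]
```

On the full arrays, `color_image = cityscapes_colors[segmentation_map]` produces the same result as the loop.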

Display and save output:
```python
import matplotlib.pyplot as plt

plt.figure(figsize=(6, 6))
plt.imshow(color_image)
plt.title("Semantic Segmentation Visualization")
plt.axis('off')
plt.show()

# Save the colorized output as an image - useful on systems without a GUI
plt.imsave("segmentation_output.png", color_image)
```

> Note: matplotlib must be installed for the visualization.


## Token Merging Settings

For information on the token merging settings, see [token_merging.md](../../docs/run/token_merging.md).