kyurimsun commited on
Commit
092a74a
·
1 Parent(s): b0b3788

Add images with Git LFS

Browse files
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. app.py +128 -0
  3. image1.jpg +3 -0
  4. image2.jpg +3 -0
  5. image3.jpg +3 -0
  6. image4.jpg +3 -0
  7. image5.jpg +3 -0
  8. labels.txt +35 -0
  9. requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from matplotlib import gridspec
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
8
+
9
+ MODEL_ID = "tobiasc/segformer-b0-finetuned-segments-sidewalk"
10
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
11
+ model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
12
+
13
+ def ade_palette():
14
+ """ADE20K palette that maps each class to RGB values."""
15
+ return [
16
+ [0, 0, 0], # 0: unlabeled
17
+ [120, 120, 120], # 1: flat-road (회색)
18
+ [244, 35, 232], # 2: flat-sidewalk (분홍)
19
+ [107, 142, 35], # 3: flat-crosswalk (녹색)
20
+ [70, 130, 180], # 4: flat-cyclinglane (하늘색)
21
+ [255, 0, 0], # 5: flat-parkingdriveway (빨강)
22
+ [0, 0, 142], # 6: flat-railtrack (진청)
23
+ [220, 20, 60], # 7: flat-curb (진홍)
24
+ [220, 220, 0], # 8: human-person (노랑)
25
+ [119, 11, 32], # 9: human-rider (적갈)
26
+ [0, 0, 230], # 10: vehicle-car (파랑)
27
+ [0, 0, 70], # 11: vehicle-truck (남색)
28
+ [0, 60, 100], # 12: vehicle-bus (청록)
29
+ [0, 80, 100], # 13: vehicle-tramtrain
30
+ [0, 0, 110], # 14: vehicle-motorcycle
31
+ [111, 74, 0], # 15: vehicle-bicycle
32
+ [51, 51, 0], # 16: vehicle-caravan
33
+ [81, 0, 81], # 17: vehicle-cartrailer
34
+ [70, 70, 70], # 18: construction-building (진회색)
35
+ [150, 100, 100], # 19: construction-door
36
+ [190, 153, 153], # 20: construction-wall
37
+ [153, 153, 153], # 21: construction-fenceguardrail
38
+ [102, 102, 156], # 22: construction-bridge
39
+ [128, 64, 128], # 23: construction-tunnel (보라)
40
+ [64, 170, 64], # 24: construction-stairs
41
+ [250, 170, 30], # 25: object-pole (주황)
42
+ [255, 255, 0], # 26: object-trafficsign
43
+ [152, 251, 152], # 27: object-trafficlight
44
+ [31, 119, 180], # 28: nature-vegetation (초록)
45
+ [174, 199, 232], # 29: nature-terrain (연청)
46
+ [255, 127, 14], # 30: sky (연주황)
47
+ [140, 86, 75], # 31: void-ground
48
+ [148, 103, 189], # 32: void-dynamic
49
+ [227, 119, 194], # 33: void-static
50
+ [188, 189, 34] # 34: void-unclear
51
+ ]
52
+
53
+ labels_list = []
54
+ with open("labels.txt", "r", encoding="utf-8") as fp:
55
+ for line in fp:
56
+ labels_list.append(line.rstrip("\n"))
57
+
58
+ colormap = np.asarray(ade_palette(), dtype=np.uint8)
59
+
60
+ def label_to_color_image(label):
61
+ if label.ndim != 2:
62
+ raise ValueError("Expect 2-D input label")
63
+ if np.max(label) >= len(colormap):
64
+ raise ValueError("label value too large.")
65
+ return colormap[label]
66
+
67
+ def draw_plot(pred_img, seg_np):
68
+ fig = plt.figure(figsize=(20, 15))
69
+ grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
70
+
71
+ plt.subplot(grid_spec[0])
72
+ plt.imshow(pred_img)
73
+ plt.axis('off')
74
+
75
+ LABEL_NAMES = np.asarray(labels_list)
76
+ FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
77
+ FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
78
+
79
+ unique_labels = np.unique(seg_np.astype("uint8"))
80
+ ax = plt.subplot(grid_spec[1])
81
+ plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
82
+ ax.yaxis.tick_right()
83
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
84
+ plt.xticks([], [])
85
+ ax.tick_params(width=0.0, labelsize=25)
86
+ return fig
87
+
88
+ def run_inference(input_img):
89
+ # input: numpy array from gradio -> PIL
90
+ img = Image.fromarray(input_img.astype(np.uint8)) if isinstance(input_img, np.ndarray) else input_img
91
+ if img.mode != "RGB":
92
+ img = img.convert("RGB")
93
+
94
+ inputs = processor(images=img, return_tensors="pt")
95
+ with torch.no_grad():
96
+ outputs = model(**inputs)
97
+ logits = outputs.logits # (1, C, h/4, w/4)
98
+
99
+ # resize to original
100
+ upsampled = torch.nn.functional.interpolate(
101
+ logits, size=img.size[::-1], mode="bilinear", align_corners=False
102
+ )
103
+ seg = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8) # (H,W)
104
+
105
+ # colorize & overlay
106
+ color_seg = colormap[seg] # (H,W,3)
107
+ pred_img = (np.array(img) * 0.5 + color_seg * 0.5).astype(np.uint8)
108
+
109
+ fig = draw_plot(pred_img, seg)
110
+ return fig
111
+
112
+ demo = gr.Interface(
113
+ fn=run_inference,
114
+ inputs=gr.Image(type="numpy", label="Input Image"),
115
+ outputs=gr.Plot(label="Overlay + Legend"),
116
+ examples=[
117
+ "image1.jpg",
118
+ "image2.jpg",
119
+ "image3.jpg",
120
+ "image4.jpg",
121
+ "image5.jpg"
122
+ ],
123
+ flagging_mode="never",
124
+ cache_examples=False,
125
+ )
126
+
127
+ if __name__ == "__main__":
128
+ demo.launch()
image1.jpg ADDED

Git LFS Details

  • SHA256: e34c5344802df608d08f06799e39b711c165cec15b71b63ddf098c9f9e7411cf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
image2.jpg ADDED

Git LFS Details

  • SHA256: e57d3096fbe008d9f6917c63ae681739dd40cdba8bb959356411b4358f1e6852
  • Pointer size: 130 Bytes
  • Size of remote file: 10.6 kB
image3.jpg ADDED

Git LFS Details

  • SHA256: dd74791f4bc7269a1946fc081833887c714228ca39325cbbeca54db81b4202f6
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
image4.jpg ADDED

Git LFS Details

  • SHA256: 590bb942dc21cd9b920de3b67657e0a84eb531647b04104616f9655315c1cd99
  • Pointer size: 131 Bytes
  • Size of remote file: 514 kB
image5.jpg ADDED

Git LFS Details

  • SHA256: 0f083f2dc854c8d16a2b6b3c12391a3dd58052927ff30e8c0f7071599b5f6fb4
  • Pointer size: 131 Bytes
  • Size of remote file: 342 kB
labels.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unlabeled
2
+ flat-road
3
+ flat-sidewalk
4
+ flat-crosswalk
5
+ flat-cyclinglane
6
+ flat-parkingdriveway
7
+ flat-railtrack
8
+ flat-curb
9
+ human-person
10
+ human-rider
11
+ vehicle-car
12
+ vehicle-truck
13
+ vehicle-bus
14
+ vehicle-tramtrain
15
+ vehicle-motorcycle
16
+ vehicle-bicycle
17
+ vehicle-caravan
18
+ vehicle-cartrailer
19
+ construction-building
20
+ construction-door
21
+ construction-wall
22
+ construction-fenceguardrail
23
+ construction-bridge
24
+ construction-tunnel
25
+ construction-stairs
26
+ object-pole
27
+ object-trafficsign
28
+ object-trafficlight
29
+ nature-vegetation
30
+ nature-terrain
31
+ sky
32
+ void-ground
33
+ void-dynamic
34
+ void-static
35
+ void-unclear
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers>=4.41.0
3
+ gradio>=4.0.0
4
+ Pillow
5
+ numpy
6
+ matplotlib