Hantr commited on
Commit
ba0292c
·
1 Parent(s): 48ecfaa
Files changed (1) hide show
  1. app.py +31 -24
app.py CHANGED
@@ -58,14 +58,27 @@ def label_to_color_image(label):
58
  return colormap[label]
59
 
60
 
61
- def draw_class_visualization(seg, class_id):
62
- class_mask = seg.numpy() == class_id
63
- class_color = colormap[class_id]
64
 
65
- class_visualization = np.zeros(seg.shape + (3,))
66
- class_visualization[class_mask] = class_color
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- return class_visualization
69
 
70
  def sepia(input_img):
71
  input_img = Image.fromarray(input_img)
@@ -73,35 +86,29 @@ def sepia(input_img):
73
  inputs = feature_extractor(images=input_img, return_tensors="tf")
74
  outputs = model(**inputs)
75
  logits = outputs.logits
76
-
77
  logits = tf.transpose(logits, [0, 2, 3, 1])
78
  logits = tf.image.resize(
79
  logits, input_img.size[::-1]
80
- )
81
  seg = tf.math.argmax(logits, axis=-1)[0]
82
 
83
- class_visualizations = []
84
- for class_id in range(len(colormap)):
85
- class_visualization = draw_class_visualization(seg, class_id)
86
- class_visualizations.append(class_visualization)
87
-
88
- return class_visualizations
89
-
90
 
91
- def plot_class_visualization(class_visualizations):
92
- fig, axes = plt.subplots(1, len(class_visualizations), figsize=(20, 15))
 
93
 
94
- for i, class_visualization in enumerate(class_visualizations):
95
- ax = axes[i]
96
- ax.imshow(class_visualization)
97
- ax.axis('off')
98
- ax.set_title(labels_list[i])
99
- return fig
100
 
101
 
102
  demo = gr.Interface(fn=sepia,
103
  inputs=gr.Image(shape=(1024, 1024)),
104
- outputs=gr.outputs.Image(type='plot', label="Class Visualizations"),
105
  examples=["city-1.jpg", "city-2.jpg", "city-3.jpg", "city-4.jpg", "city-5.jpg"],
106
  allow_flagging='never')
107
 
 
58
  return colormap[label]
59
 
60
 
61
+ def draw_plot(pred_img, seg):
62
+ fig = plt.figure(figsize=(20, 15))
 
63
 
64
+ grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
65
+
66
+ plt.subplot(grid_spec[0])
67
+ plt.imshow(pred_img)
68
+ plt.axis('off')
69
+ LABEL_NAMES = np.asarray(labels_list)
70
+ FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
71
+ FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
72
+
73
+ unique_labels = np.unique(seg.numpy().astype("uint8"))
74
+ ax = plt.subplot(grid_spec[1])
75
+ plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
76
+ ax.yaxis.tick_right()
77
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
78
+ plt.xticks([], [])
79
+ ax.tick_params(width=0.0, labelsize=25)
80
+ return fig
81
 
 
82
 
83
  def sepia(input_img):
84
  input_img = Image.fromarray(input_img)
 
86
  inputs = feature_extractor(images=input_img, return_tensors="tf")
87
  outputs = model(**inputs)
88
  logits = outputs.logits
 
89
  logits = tf.transpose(logits, [0, 2, 3, 1])
90
  logits = tf.image.resize(
91
  logits, input_img.size[::-1]
92
+ ) # We reverse the shape of `image` because `image.size` returns width and height.
93
  seg = tf.math.argmax(logits, axis=-1)[0]
94
 
95
+ color_seg = np.zeros(
96
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
97
+ ) # height, width, 3
98
+ for label, color in enumerate(colormap):
99
+ color_seg[seg.numpy() == label, :] = color
 
 
100
 
101
+ # Show image + mask
102
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
103
+ pred_img = pred_img.astype(np.uint8)
104
 
105
+ fig = draw_plot(pred_img, seg)
106
+ return fig
 
 
 
 
107
 
108
 
109
  demo = gr.Interface(fn=sepia,
110
  inputs=gr.Image(shape=(1024, 1024)),
111
+ outputs=['plot'],
112
  examples=["city-1.jpg", "city-2.jpg", "city-3.jpg", "city-4.jpg", "city-5.jpg"],
113
  allow_flagging='never')
114