Spaces:
Runtime error
Runtime error
Commit ·
a80914c
1
Parent(s): 9f804f9
Update app.py
Browse files
app.py
CHANGED
|
@@ -97,14 +97,28 @@ def sepia(input_img):
|
|
| 97 |
pred_img = pred_img.astype(np.uint8)
|
| 98 |
|
| 99 |
fig = draw_plot(pred_img, seg)
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
demo = gr.Interface(fn=sepia,
|
| 104 |
inputs=gr.Image(shape=(400, 600)),
|
| 105 |
-
outputs=['plot'],
|
| 106 |
examples=["citiscapes-1.jpeg", "citiscapes-2.jpeg"],
|
| 107 |
allow_flagging='never')
|
| 108 |
|
| 109 |
-
|
| 110 |
demo.launch()
|
|
|
|
| 97 |
pred_img = pred_img.astype(np.uint8)
|
| 98 |
|
| 99 |
fig = draw_plot(pred_img, seg)
|
| 100 |
+
|
| 101 |
+
# 각 물체에 대한 예측 클래스와 확률 얻기
|
| 102 |
+
unique_labels = np.unique(seg.numpy().astype("uint8"))
|
| 103 |
+
class_probabilities = {}
|
| 104 |
+
for label in unique_labels:
|
| 105 |
+
mask = (seg.numpy() == label)
|
| 106 |
+
class_name = LABEL_NAMES[label]
|
| 107 |
+
class_prob = np.mean(outputs.logits.numpy()[0][mask])
|
| 108 |
+
class_probabilities[class_name] = class_prob
|
| 109 |
+
|
| 110 |
+
# 정확성이 가장 높은 물체 정보 얻기
|
| 111 |
+
max_prob_class = max(class_probabilities, key=class_probabilities.get)
|
| 112 |
+
max_prob_value = class_probabilities[max_prob_class]
|
| 113 |
+
|
| 114 |
+
# 출력 및 반환
|
| 115 |
+
print(f"Predicted class with highest probability: {max_prob_class}, Probability: {max_prob_value:.4f}")
|
| 116 |
+
return fig, f"Predicted class with highest probability: {max_prob_class}, Probability: {max_prob_value:.4f}"
|
| 117 |
|
| 118 |
demo = gr.Interface(fn=sepia,
|
| 119 |
inputs=gr.Image(shape=(400, 600)),
|
| 120 |
+
outputs=['plot', 'text'],
|
| 121 |
examples=["citiscapes-1.jpeg", "citiscapes-2.jpeg"],
|
| 122 |
allow_flagging='never')
|
| 123 |
|
|
|
|
| 124 |
demo.launch()
|