Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,7 @@ import torch.nn.functional as F
|
|
| 6 |
from typing import Literal, Any
|
| 7 |
import gradio as gr
|
| 8 |
import spaces
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class Classifier:
|
|
@@ -74,7 +75,14 @@ def draw_bar_chart(data: dict[str, list[str | float]]):
|
|
| 74 |
|
| 75 |
plt.tight_layout()
|
| 76 |
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
def get_layout():
|
|
@@ -137,7 +145,7 @@ def get_layout():
|
|
| 137 |
|
| 138 |
with gr.Row(equal_height=True):
|
| 139 |
image_input = gr.Image(label="上傳影像", type="pil")
|
| 140 |
-
|
| 141 |
|
| 142 |
start_button = gr.Button("開始分類", variant="primary")
|
| 143 |
gr.HTML(
|
|
@@ -146,7 +154,7 @@ def get_layout():
|
|
| 146 |
start_button.click(
|
| 147 |
fn=Classifier().predict,
|
| 148 |
inputs=image_input,
|
| 149 |
-
outputs=
|
| 150 |
)
|
| 151 |
|
| 152 |
return demo
|
|
|
|
| 6 |
from typing import Literal, Any
|
| 7 |
import gradio as gr
|
| 8 |
import spaces
|
| 9 |
+
from io import BytesIO
|
| 10 |
|
| 11 |
|
| 12 |
class Classifier:
|
|
|
|
| 75 |
|
| 76 |
plt.tight_layout()
|
| 77 |
|
| 78 |
+
bimg = BytesIO()
|
| 79 |
+
plt.save(img, format="png")
|
| 80 |
+
plt.close
|
| 81 |
+
|
| 82 |
+
bimg.seek(0)
|
| 83 |
+
img = Image.open(bimg)
|
| 84 |
+
|
| 85 |
+
return img
|
| 86 |
|
| 87 |
|
| 88 |
def get_layout():
|
|
|
|
| 145 |
|
| 146 |
with gr.Row(equal_height=True):
|
| 147 |
image_input = gr.Image(label="上傳影像", type="pil")
|
| 148 |
+
chart = gr.Image(label="分類結果")
|
| 149 |
|
| 150 |
start_button = gr.Button("開始分類", variant="primary")
|
| 151 |
gr.HTML(
|
|
|
|
| 154 |
start_button.click(
|
| 155 |
fn=Classifier().predict,
|
| 156 |
inputs=image_input,
|
| 157 |
+
outputs=chart,
|
| 158 |
)
|
| 159 |
|
| 160 |
return demo
|