Spaces:
Build error
Build error
Prediction includes rations
Browse files- .gitignore +2 -1
- S1_CNN_Model.py +4 -1
- app.py +18 -9
.gitignore
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
test.*
|
| 2 |
-
*.pt
|
|
|
|
|
|
| 1 |
test.*
|
| 2 |
+
*.pt
|
| 3 |
+
__pycache__/
|
S1_CNN_Model.py
CHANGED
|
@@ -110,8 +110,11 @@ class CNN_Model(nn.Module):
|
|
| 110 |
|
| 111 |
preds = self.forward(patches)
|
| 112 |
_, preds = torch.max(preds,1)
|
|
|
|
|
|
|
| 113 |
preds = torch.mode(preds, 0).values
|
| 114 |
-
|
|
|
|
| 115 |
|
| 116 |
class_count = 41
|
| 117 |
|
|
|
|
| 110 |
|
| 111 |
preds = self.forward(patches)
|
| 112 |
_, preds = torch.max(preds,1)
|
| 113 |
+
|
| 114 |
+
ratios = preds
|
| 115 |
preds = torch.mode(preds, 0).values
|
| 116 |
+
|
| 117 |
+
return ratios, preds
|
| 118 |
|
| 119 |
class_count = 41
|
| 120 |
|
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import gdown
|
|
|
|
| 2 |
import os
|
| 3 |
import torch
|
| 4 |
from S1_CNN_Model import CNN_Model
|
|
@@ -45,24 +46,28 @@ def resize_image(img):
|
|
| 45 |
|
| 46 |
PD_COLS=["image","predicted species"]
|
| 47 |
MAX_HISTORY = 10
|
|
|
|
| 48 |
|
| 49 |
def classify(image: np.array, history):
|
| 50 |
if history == None: history = []
|
| 51 |
|
| 52 |
with torch.no_grad():
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
history += [(resize_image(image),
|
| 57 |
hist = history[-MAX_HISTORY:]
|
| 58 |
|
| 59 |
-
return pred, *toggle_history_components(hist), history
|
| 60 |
|
| 61 |
def toggle_history_components(history: list[History]):
|
| 62 |
n_hidden = MAX_HISTORY - len(history)
|
| 63 |
images, names = list(zip(*history))
|
| 64 |
|
| 65 |
-
components =
|
| 66 |
components += [gr.Image(visible=False)] * n_hidden
|
| 67 |
components += [gr.Markdown(x, visible=True) for x in names]
|
| 68 |
components += [gr.Markdown(visible=False)] * n_hidden
|
|
@@ -75,9 +80,13 @@ def classification_tab():
|
|
| 75 |
with gr.Row():
|
| 76 |
submit = gr.Button("Submit", variant='primary')
|
| 77 |
clear = gr.ClearButton(image)
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
-
return image, submit, clear, pred
|
| 81 |
|
| 82 |
MAX_SAMPLE_COUNT = max([len(os.listdir(x)) for x in listdir_full(SAMPLE_DIR)])
|
| 83 |
|
|
@@ -136,7 +145,7 @@ with gr.Blocks() as demo:
|
|
| 136 |
history = gr.State([])
|
| 137 |
with gr.Tabs() as tabs:
|
| 138 |
with gr.Tab("Classification", id=0):
|
| 139 |
-
image, submit, clear, pred = classification_tab()
|
| 140 |
|
| 141 |
with gr.Tab("Samples", id=1):
|
| 142 |
sample_tab(image, tabs)
|
|
@@ -145,6 +154,6 @@ with gr.Blocks() as demo:
|
|
| 145 |
table_contents = history_tab()
|
| 146 |
|
| 147 |
# history = gr.Gallery(interactive=False)
|
| 148 |
-
submit.click(classify,[image, history],[pred, *table_contents, history])
|
| 149 |
|
| 150 |
demo.launch()
|
|
|
|
| 1 |
import gdown
|
| 2 |
+
from collections import Counter
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
from S1_CNN_Model import CNN_Model
|
|
|
|
| 46 |
|
| 47 |
PD_COLS=["image","predicted species"]
|
| 48 |
MAX_HISTORY = 10
|
| 49 |
+
MAX_PREDS = 10
|
| 50 |
|
| 51 |
def classify(image: np.array, history):
|
| 52 |
if history == None: history = []
|
| 53 |
|
| 54 |
with torch.no_grad():
|
| 55 |
+
r, p = model.predict_large_image(cv2.cvtColor(image, cv2. COLOR_RGB2BGR))
|
| 56 |
+
ratios = [gr.Textbox(f"{labels[label]}: {count/len(r)*100:.2f}%",visible=True)
|
| 57 |
+
for label, count in Counter(r.tolist()).most_common()][-MAX_PREDS:]
|
| 58 |
+
ratios += [gr.Textbox(visible=False)] * (MAX_PREDS - len(ratios))
|
| 59 |
+
pred = gr.Markdown(f"## Predictions {labels[p.item()]}")
|
| 60 |
|
| 61 |
+
history += [(resize_image(image), labels[p.item()])]
|
| 62 |
hist = history[-MAX_HISTORY:]
|
| 63 |
|
| 64 |
+
return pred, *ratios, *toggle_history_components(hist), history
|
| 65 |
|
| 66 |
def toggle_history_components(history: list[History]):
|
| 67 |
n_hidden = MAX_HISTORY - len(history)
|
| 68 |
images, names = list(zip(*history))
|
| 69 |
|
| 70 |
+
components = [gr.Image(x, visible=True) for x in images]
|
| 71 |
components += [gr.Image(visible=False)] * n_hidden
|
| 72 |
components += [gr.Markdown(x, visible=True) for x in names]
|
| 73 |
components += [gr.Markdown(visible=False)] * n_hidden
|
|
|
|
| 80 |
with gr.Row():
|
| 81 |
submit = gr.Button("Submit", variant='primary')
|
| 82 |
clear = gr.ClearButton(image)
|
| 83 |
+
with gr.Column():
|
| 84 |
+
pred = gr.Markdown("## Predictions")
|
| 85 |
+
ratios = []
|
| 86 |
+
for _ in range(MAX_PREDS):
|
| 87 |
+
ratios.append(gr.Textbox(show_label=False,visible=False))
|
| 88 |
|
| 89 |
+
return image, submit, clear, pred, ratios
|
| 90 |
|
| 91 |
MAX_SAMPLE_COUNT = max([len(os.listdir(x)) for x in listdir_full(SAMPLE_DIR)])
|
| 92 |
|
|
|
|
| 145 |
history = gr.State([])
|
| 146 |
with gr.Tabs() as tabs:
|
| 147 |
with gr.Tab("Classification", id=0):
|
| 148 |
+
image, submit, clear, pred, ratios = classification_tab()
|
| 149 |
|
| 150 |
with gr.Tab("Samples", id=1):
|
| 151 |
sample_tab(image, tabs)
|
|
|
|
| 154 |
table_contents = history_tab()
|
| 155 |
|
| 156 |
# history = gr.Gallery(interactive=False)
|
| 157 |
+
submit.click(classify,[image, history],[pred, *ratios, *table_contents, history])
|
| 158 |
|
| 159 |
demo.launch()
|