Yapp99 commited on
Commit
1fb7b02
·
1 Parent(s): ab1042a

Prediction includes rations

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. S1_CNN_Model.py +4 -1
  3. app.py +18 -9
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  test.*
2
- *.pt
 
 
1
  test.*
2
+ *.pt
3
+ __pycache__/
S1_CNN_Model.py CHANGED
@@ -110,8 +110,11 @@ class CNN_Model(nn.Module):
110
 
111
  preds = self.forward(patches)
112
  _, preds = torch.max(preds,1)
 
 
113
  preds = torch.mode(preds, 0).values
114
- return preds
 
115
 
116
  class_count = 41
117
 
 
110
 
111
  preds = self.forward(patches)
112
  _, preds = torch.max(preds,1)
113
+
114
+ ratios = preds
115
  preds = torch.mode(preds, 0).values
116
+
117
+ return ratios, preds
118
 
119
  class_count = 41
120
 
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gdown
 
2
  import os
3
  import torch
4
  from S1_CNN_Model import CNN_Model
@@ -45,24 +46,28 @@ def resize_image(img):
45
 
46
  PD_COLS=["image","predicted species"]
47
  MAX_HISTORY = 10
 
48
 
49
  def classify(image: np.array, history):
50
  if history == None: history = []
51
 
52
  with torch.no_grad():
53
- pred = model.predict_large_image(cv2.cvtColor(image, cv2. COLOR_RGB2BGR)).item()
54
- pred = labels[pred]
 
 
 
55
 
56
- history += [(resize_image(image), pred)]
57
  hist = history[-MAX_HISTORY:]
58
 
59
- return pred, *toggle_history_components(hist), history
60
 
61
  def toggle_history_components(history: list[History]):
62
  n_hidden = MAX_HISTORY - len(history)
63
  images, names = list(zip(*history))
64
 
65
- components = [gr.Image(x, visible=True) for x in images]
66
  components += [gr.Image(visible=False)] * n_hidden
67
  components += [gr.Markdown(x, visible=True) for x in names]
68
  components += [gr.Markdown(visible=False)] * n_hidden
@@ -75,9 +80,13 @@ def classification_tab():
75
  with gr.Row():
76
  submit = gr.Button("Submit", variant='primary')
77
  clear = gr.ClearButton(image)
78
- pred = gr.Textbox(label="Prediction")
 
 
 
 
79
 
80
- return image, submit, clear, pred
81
 
82
  MAX_SAMPLE_COUNT = max([len(os.listdir(x)) for x in listdir_full(SAMPLE_DIR)])
83
 
@@ -136,7 +145,7 @@ with gr.Blocks() as demo:
136
  history = gr.State([])
137
  with gr.Tabs() as tabs:
138
  with gr.Tab("Classification", id=0):
139
- image, submit, clear, pred = classification_tab()
140
 
141
  with gr.Tab("Samples", id=1):
142
  sample_tab(image, tabs)
@@ -145,6 +154,6 @@ with gr.Blocks() as demo:
145
  table_contents = history_tab()
146
 
147
  # history = gr.Gallery(interactive=False)
148
- submit.click(classify,[image, history],[pred, *table_contents, history])
149
 
150
  demo.launch()
 
1
  import gdown
2
+ from collections import Counter
3
  import os
4
  import torch
5
  from S1_CNN_Model import CNN_Model
 
46
 
47
  PD_COLS=["image","predicted species"]
48
  MAX_HISTORY = 10
49
+ MAX_PREDS = 10
50
 
51
  def classify(image: np.array, history):
52
  if history == None: history = []
53
 
54
  with torch.no_grad():
55
+ r, p = model.predict_large_image(cv2.cvtColor(image, cv2. COLOR_RGB2BGR))
56
+ ratios = [gr.Textbox(f"{labels[label]}: {count/len(r)*100:.2f}%",visible=True)
57
+ for label, count in Counter(r.tolist()).most_common()][-MAX_PREDS:]
58
+ ratios += [gr.Textbox(visible=False)] * (MAX_PREDS - len(ratios))
59
+ pred = gr.Markdown(f"## Predictions {labels[p.item()]}")
60
 
61
+ history += [(resize_image(image), labels[p.item()])]
62
  hist = history[-MAX_HISTORY:]
63
 
64
+ return pred, *ratios, *toggle_history_components(hist), history
65
 
66
  def toggle_history_components(history: list[History]):
67
  n_hidden = MAX_HISTORY - len(history)
68
  images, names = list(zip(*history))
69
 
70
+ components = [gr.Image(x, visible=True) for x in images]
71
  components += [gr.Image(visible=False)] * n_hidden
72
  components += [gr.Markdown(x, visible=True) for x in names]
73
  components += [gr.Markdown(visible=False)] * n_hidden
 
80
  with gr.Row():
81
  submit = gr.Button("Submit", variant='primary')
82
  clear = gr.ClearButton(image)
83
+ with gr.Column():
84
+ pred = gr.Markdown("## Predictions")
85
+ ratios = []
86
+ for _ in range(MAX_PREDS):
87
+ ratios.append(gr.Textbox(show_label=False,visible=False))
88
 
89
+ return image, submit, clear, pred, ratios
90
 
91
  MAX_SAMPLE_COUNT = max([len(os.listdir(x)) for x in listdir_full(SAMPLE_DIR)])
92
 
 
145
  history = gr.State([])
146
  with gr.Tabs() as tabs:
147
  with gr.Tab("Classification", id=0):
148
+ image, submit, clear, pred, ratios = classification_tab()
149
 
150
  with gr.Tab("Samples", id=1):
151
  sample_tab(image, tabs)
 
154
  table_contents = history_tab()
155
 
156
  # history = gr.Gallery(interactive=False)
157
+ submit.click(classify,[image, history],[pred, *ratios, *table_contents, history])
158
 
159
  demo.launch()