alrichardbollans committed on
Commit
56df201
·
1 Parent(s): 2374391

Add an example ipynb to use model

Browse files
README.md CHANGED
@@ -10,14 +10,9 @@ license: cc-by-nc-sa-4.0
10
  short_description: Assessing viability of orchid seeds in TZ tests with AI.
11
  ---
12
 
13
- This is a templated Space for [Shiny for Python](https://shiny.rstudio.com/py/).
14
 
15
-
16
- To get started with a new app do the following:
17
-
18
- 1) Install Shiny with `pip install shiny`
19
- 2) Create a new app with `shiny create`
20
- 3) Then run the app with `shiny run --reload`
21
 
22
  To learn more about this framework please see the [Documentation](https://shiny.rstudio.com/py/docs/overview.html).
23
 
 
10
  short_description: Assessing viability of orchid seeds in TZ tests with AI.
11
  ---
12
 
13
+ This is an app created with [Shiny for Python](https://shiny.rstudio.com/py/).
14
 
15
+ If you create a Python environment with all the requirements, you can download and run this app locally using shiny.
 
 
 
 
 
16
 
17
  To learn more about this framework please see the [Documentation](https://shiny.rstudio.com/py/docs/overview.html).
18
 
app.py CHANGED
@@ -349,8 +349,15 @@ def server(input, output, session: Session):
349
 
350
  @render.ui
351
  def download_results_ui():
 
352
  if analysis_results.get() and not is_analyzing.get():
353
- return ui.download_button("download_results", "Download Results", class_="btn-success"), ui.download_button("download_segmented_images",
 
 
 
 
 
 
354
  "Download Segmented Images",
355
  class_="btn-success")
356
 
 
349
 
350
  @render.ui
351
  def download_results_ui():
352
+
353
  if analysis_results.get() and not is_analyzing.get():
354
+ # results = analysis_results.get()
355
+ # current_nms = input.nms_threshold()
356
+ # print(f'Current NMS threshold: {current_nms}')
357
+ # if results[0].get('NMS threshold') != current_nms:
358
+ # print('NMS changed')
359
+ # else:
360
+ return ui.download_button("download_results", "Download Results", class_="btn-success"), ui.download_button("download_segmented_images",
361
  "Download Segmented Images",
362
  class_="btn-success")
363
 
example_using_final_model.ipynb ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": [
7
+ "This file provides an example of how to use some of the underlying python methods used in this app.\n",
8
+ "\n",
9
+ "It requires a python environment with pytorch and detectron2. This can be set up on google colab using the following cell."
10
+ ],
11
+ "id": "54812cfe385be4e3"
12
+ },
13
+ {
14
+ "metadata": {},
15
+ "cell_type": "code",
16
+ "source": [
17
+ "## To run on google colab, run this cell to install requirements. You will then, however, need to copy some of our custom methods over (e.g. load_model etc..)\n",
18
+ "## In future we may package these properly, but they are just utility functions\n",
19
+ "try:\n",
20
+ " import google.colab\n",
21
+ " IN_COLAB = True\n",
22
+ "except:\n",
23
+ " IN_COLAB = False\n",
24
+ "if IN_COLAB:\n",
25
+ "\n",
26
+ " import sys, os, distutils.core\n",
27
+ " # Note: This is a faster way to install detectron2 in Colab, but it does not include all functionalities (e.g. compiled operators).\n",
28
+ " # See https://detectron2.readthedocs.io/tutorials/install.html for full installation instructions\n",
29
+ " # Issues raised:\n",
30
+ " # - pyaml install https://github.com/facebookresearch/detectron2/issues/5122 (think this is fixed)\n",
31
+ " !git clone 'https://github.com/facebookresearch/detectron2'\n",
32
+ " dist = distutils.core.run_setup(\"./detectron2/setup.py\")\n",
33
+ " !python -m pip install {' '.join([f\"'{x}'\" for x in dist.install_requires])}\n",
34
+ " sys.path.insert(0, os.path.abspath('./detectron2'))\n",
35
+ "\n",
36
+ " # Properly install detectron2. (Please do not install twice in both ways)\n",
37
+ " # !python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'\n",
38
+ " import torch, detectron2\n",
39
+ " !nvcc --version\n",
40
+ " TORCH_VERSION = \".\".join(torch.__version__.split(\".\")[:2])\n",
41
+ " CUDA_VERSION = torch.__version__.split(\"+\")[-1]\n",
42
+ " print(\"torch: \", TORCH_VERSION, \"; cuda: \", CUDA_VERSION)\n",
43
+ " print(\"detectron2:\", detectron2.__version__)\n",
44
+ " print(f'GPU available: {torch.cuda.is_available()}')"
45
+ ],
46
+ "id": "ae06666362c0995d",
47
+ "outputs": [],
48
+ "execution_count": null
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "id": "initial_id",
53
+ "metadata": {
54
+ "collapsed": true
55
+ },
56
+ "source": [
57
+ "## Import the model utilities that we will need, and load the model.\n",
58
+ "from python_utils import load_model, apply_nms, OPTIMAL_NMS_THRESHOLD\n",
59
+ "\n",
60
+ "predictor = load_model()"
61
+ ],
62
+ "outputs": [],
63
+ "execution_count": null
64
+ },
65
+ {
66
+ "metadata": {},
67
+ "cell_type": "code",
68
+ "source": [
69
+ "## Import and display an image.\n",
70
+ "import cv2\n",
71
+ "\n",
72
+ "if IN_COLAB:\n",
73
+ " from google.colab import files\n",
74
+ " uploaded = files.upload()\n",
75
+ " import cv2\n",
76
+ " from google.colab.patches import cv2_imshow\n",
77
+ "\n",
78
+ "    im = cv2.imread(next(iter(uploaded)))  # just look at the first image\n",
79
+ " cv2_imshow(im)\n",
80
+ "else:\n",
81
+ "\n",
82
+ " import os\n",
83
+ " from IPython.display import Image\n",
84
+ "\n",
85
+ " img_file = os.path.join('assets', 'rbg_kew.jpg')\n",
86
+ " display(Image(filename=img_file))\n",
87
+ " im = cv2.imread(img_file)"
88
+ ],
89
+ "id": "12ab5c1116d6ff38",
90
+ "outputs": [],
91
+ "execution_count": null
92
+ },
93
+ {
94
+ "metadata": {},
95
+ "cell_type": "code",
96
+ "source": [
97
+ "## Run the model on the image, then apply NMS to filter out overlapping masks.\n",
98
+ "raw_output = predictor(im)\n",
99
+ "prediction = apply_nms(raw_output, mask=True, cls_agnostic_nms=OPTIMAL_NMS_THRESHOLD)"
100
+ ],
101
+ "id": "d9305d58e42d6504",
102
+ "outputs": [],
103
+ "execution_count": null
104
+ },
105
+ {
106
+ "metadata": {},
107
+ "cell_type": "code",
108
+ "source": [
109
+ "# Get the seed counts (0 = viable, 1 = non-viable, 2 = empty)\n",
110
+ "classes = prediction[\"instances\"].pred_classes.tolist()\n",
111
+ "counts = {\"viable\": classes.count(0),\n",
112
+ " \"non-viable\": classes.count(1),\n",
113
+ " \"empty\": classes.count(2),\n",
114
+ " \"total\": len(classes)}\n",
115
+ "print(counts)"
116
+ ],
117
+ "id": "676bd6fab69c10f2",
118
+ "outputs": [],
119
+ "execution_count": null
120
+ },
121
+ {
122
+ "metadata": {},
123
+ "cell_type": "code",
124
+ "source": [
125
+ "# Visualise the segmentation masks\n",
126
+ "\n",
127
+ "from matplotlib import pyplot as plt\n",
128
+ "from app import get_overlayed_image_from_single_result\n",
129
+ "\n",
130
+ "prediction['image'] = im\n",
131
+ "visualiser = get_overlayed_image_from_single_result(prediction)\n",
132
+ "fig, ax = plt.subplots(figsize=(8, 6.4))\n",
133
+ "ax.imshow(cv2.cvtColor(visualiser.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB))\n",
134
+ "ax.get_xaxis().set_visible(False)\n",
135
+ "ax.get_yaxis().set_visible(False)\n",
136
+ "ax.set_title(\"Annotated\")\n",
137
+ "plt.tight_layout()"
138
+ ],
139
+ "id": "18934b98c3461a28",
140
+ "outputs": [],
141
+ "execution_count": null
142
+ }
143
+ ],
144
+ "metadata": {
145
+ "kernelspec": {
146
+ "display_name": "Python 3",
147
+ "language": "python",
148
+ "name": "python3"
149
+ },
150
+ "language_info": {
151
+ "codemirror_mode": {
152
+ "name": "ipython",
153
+ "version": 2
154
+ },
155
+ "file_extension": ".py",
156
+ "mimetype": "text/x-python",
157
+ "name": "python",
158
+ "nbconvert_exporter": "python",
159
+ "pygments_lexer": "ipython2",
160
+ "version": "2.7.6"
161
+ }
162
+ },
163
+ "nbformat": 4,
164
+ "nbformat_minor": 5
165
+ }
python_utils/get_model.py CHANGED
@@ -2,6 +2,7 @@ import urllib.request
2
  import tempfile
3
 
4
  ## Urls and model variables that might change.
 
5
  OPTIMAL_NMS_THRESHOLD = 0.7
6
  model_page = "https://huggingface.co/TZProject/final_tz_segmentor"
7
  _model_config_url = model_page + "/resolve/main/final_model_config.yaml"
 
2
  import tempfile
3
 
4
  ## Urls and model variables that might change.
5
+ ## If changing any of these, think about other places in repos where they might need changing (e.g. weights url inside config file).
6
  OPTIMAL_NMS_THRESHOLD = 0.7
7
  model_page = "https://huggingface.co/TZProject/final_tz_segmentor"
8
  _model_config_url = model_page + "/resolve/main/final_model_config.yaml"