Upload 8 files
Browse files- README.md +24 -12
- app.py +45 -0
- demo_model.pkl +3 -0
- training/cnn_learner.py +31 -0
- training/rnn (does not work)/ConvLSTM.py +194 -0
- training/rnn (does not work)/generate_mi_tensors.py +37 -0
- training/rnn (does not work)/generate_norm_tensors.py +37 -0
- training/rnn (does not work)/rnn.py +134 -0
README.md
CHANGED
|
@@ -1,12 +1,24 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HeartNet
|
| 2 |
+
|
| 3 |
+
A joint project by [oapostrophe](https://github.com/oapostrophe), [gkenderova](https://github.com/gkenderova), [soksamnanglim](https://github.com/soksamnanglim), [syaa2018](https://github.com/syaa2018)
|
| 4 |
+
|
| 5 |
+
For a high-level overview of this project, check out this [blog post](https://oapostrophe.github.io/heartnet/) and [90-second demo](https://www.youtube.com/watch?v=EqAU-FRu6C4). For a full presentation and more detailed writeup on our methodology, check out the report on our [project website](https://oapostrophe.github.io/HeartNet/).
|
| 6 |
+
|
| 7 |
+
The trained model can be demoed by downloading `app.py` and `demo_model.pkl`, installing [streamlit](https://anaconda.org/conda-forge/streamlit) and [fastai](https://pypi.org/project/fastai/), then running:
|
| 8 |
+
```shell
|
| 9 |
+
streamlit run app.py
|
| 10 |
+
```
|
| 11 |
+
You can then visit the provided url in your browser; for convenience, sample generated MI and Normal EKG images are provided in the `/test files` directory.
|
| 12 |
+
|
| 13 |
+
To use any of the other files, you'll have to download the [PTB-XL](https://physionet.org/content/ptb-xl/1.0.1/) dataset.
|
| 14 |
+
|
| 15 |
+
The important files are the following:
|
| 16 |
+
- `app.py` Streamlit-based web interface using a trained model
|
| 17 |
+
- `dataset generation/generate_imgset1.py` our first iteration generating a dataset directly with MatPlotLib; these images look rough.
|
| 18 |
+
- `dataset generation/generate_imgset2.py` our second iteration that generates nicer-looking images
|
| 19 |
+
- `dataset generation/generate_imgset3.py` adds random simulated shadows overlaying generated images
|
| 20 |
+
- `dataset generation/generate_rnn_imgset.py` generates individual images for each of 12 leads, for input into an RNN (rnn code currently fails to learn).
|
| 21 |
+
- `dataset generation/automold.py` library with image augmentation code for adding shadows
|
| 22 |
+
- `training/cnn_learner.py` trains and saves a cnn on generated images.
|
| 23 |
+
|
| 24 |
+
Feel free to [email me](swow2015@mymail.pomona.edu) with any questions!
|
app.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastai.vision.all import *
from io import BytesIO
import requests
import streamlit as st

"""
# HeartNet
This is a classifier for images of 12-lead EKGs. It will attempt to detect whether the EKG indicates an acute MI. It was trained on simulated images.
"""

def predict(img):
    """Display *img* and render the model's MI / Normal verdict.

    Uses the module-level `learn_inf` fastai learner; the bare f-string
    below is Streamlit "magic" output (rendered as markdown).
    """
    st.image(img, caption="Your image", use_column_width=True)
    pred, _, probs = learn_inf.predict(img)

    # NOTE(review): assumes the learner's vocab orders classes so that
    # probs[0] is P(mi) and probs[1] is P(normal) -- confirm against the
    # training DataBlock.
    # Fixed: the original `'is '` left a space before the closing `**`,
    # which breaks markdown bold rendering.
    f"""
    ## This **{'is' if pred == 'mi' else 'is not'}** an MI (heart attack).
    ### Probability of MI: {probs[0].item()*100: .2f}%
    ### Probability Normal: {probs[1].item()*100: .2f}%
    """


path = "./"
learn_inf = load_learner(path + "demo_model.pkl")

option = st.radio("", ["Upload Image", "Image URL"])

if option == "Upload Image":
    uploaded_file = st.file_uploader("Please upload an image.")

    if uploaded_file is not None:
        img = PILImage.create(uploaded_file)
        predict(img)

else:
    url = st.text_input("Please input a url.")

    if url != "":
        try:
            response = requests.get(url)
            # Surface HTTP errors (404 etc.) instead of trying to decode
            # an error page as an image.
            response.raise_for_status()
            pil_img = PILImage.create(BytesIO(response.content))
            predict(pil_img)

        # Fixed: was a bare `except:` followed by a two-argument
        # `st.text(msg, url)` call, which itself raises TypeError.
        except Exception:
            st.text(f"Problem reading image from {url}")
|
demo_model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a787a70f3c2197c26beafd169d2e6b3364414fec7066eb7468a794d0a209ea68
|
| 3 |
+
size 50315747
|
training/cnn_learner.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a CNN classifier on the generated EKG image set and export it."""

import fastbook
fastbook.setup_book()
from fastbook import *
from fastai.vision.widgets import *

# Pick a GPU with free resources (change this accordingly)
torch.cuda.set_device(0)

# Get images
image_path = Path('/raid/heartnet/data/imgset2')
images = get_image_files(image_path)

# Initialize metric functions
# pos_label=0 treats class index 0 as the positive class; presumably that
# is the "mi" folder (alphabetical vocab order) -- TODO confirm via dls.vocab.
recall_function = Recall(pos_label=0)
precision_function = Precision(pos_label=0)
f1_score = F1Score(pos_label=0)

# Initialize DataLoader
images_datablock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    # 20% held-out validation split; fixed seed keeps it reproducible.
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    # Each image's parent directory name is its class label.
    get_y=parent_label,
    # Flips disabled, presumably because EKG traces are orientation-sensitive.
    batch_tfms=aug_transforms(do_flip=False)
)
dls = images_datablock.dataloaders(image_path, bs=16)

# Create, train, and save model
# NOTE(review): the export name says "_50" but the backbone is resnet152;
# confirm which architecture the shipped demo_model.pkl actually uses.
learn = cnn_learner(dls, resnet152, metrics=[error_rate, recall_function, precision_function, f1_score])
learn.fine_tune(16)
learn.export('demo_model_50.pkl')
|
training/rnn (does not work)/ConvLSTM.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
from fastai.vision.all import *
|
| 9 |
+
|
| 10 |
+
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    All four gate pre-activations are produced by one fused convolution
    over the channel-wise concatenation of the input and the previous
    hidden state; "same" padding keeps the spatial size unchanged.

    Parameters
    ----------
    input_dim: int
        Number of channels of input tensor.
    hidden_dim: int
        Number of channels of hidden state.
    kernel_size: (int, int)
        Size of the convolutional kernel.
    bias: bool
        Whether or not to add the bias.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # "Same" padding: half the kernel extent in each spatial direction.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # 4x hidden channels out: one slab per gate (i, f, o, g).
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, input_tensor, cur_state):
        prev_hidden, prev_cell = cur_state

        # Stack input and previous hidden state along the channel axis.
        stacked = torch.cat([input_tensor, prev_hidden], dim=1)
        gate_preacts = self.conv(stacked)
        pre_i, pre_f, pre_o, pre_g = torch.split(
            gate_preacts, self.hidden_dim, dim=1)

        i = torch.sigmoid(pre_i)
        f = torch.sigmoid(pre_f)
        o = torch.sigmoid(pre_o)
        g = torch.tanh(pre_g)

        # Standard LSTM update, element-wise over the feature maps.
        next_cell = f * prev_cell + i * g
        next_hidden = o * torch.tanh(next_cell)

        return next_hidden, next_cell

    def init_hidden(self, batch_size, image_size):
        """Return zero-filled (h, c) tensors on the conv weight's device."""
        height, width = image_size
        shape = (batch_size, self.hidden_dim, height, width)
        device = self.conv.weight.device
        return (torch.zeros(*shape, device=device),
                torch.zeros(*shape, device=device))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ConvLSTM(nn.Module):

    """
    Multi-layer convolutional LSTM built from stacked ConvLSTMCell layers.

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels (int or one int per layer)
        kernel_size: Size of kernel in convolutions (tuple or one per layer)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
    Note: Will do same padding.
    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of lists of length T of each output
            1 - last_state_list is the list of last states
                    each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # Layer 0 consumes the raw input; deeper layers consume the
            # hidden state of the layer below.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            Must be None; stateful operation is not implemented yet.
        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # Implement stateful ConvLSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward. Can send image size here
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            # Unroll the cell over the time axis, threading (h, c) through.
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            # Re-stack the per-step hidden maps into (b, t, hidden, h, w);
            # this sequence feeds the next layer.
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            # Keep only the deepest layer's outputs/states.
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        # One zero-initialized (h, c) pair per layer.
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        # Accept a single (kh, kw) tuple or a list of such tuples.
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        # Broadcast a scalar hyper-parameter to one entry per layer.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
training/rnn (does not work)/generate_mi_tensors.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert each MI record's per-lead PNG images into one stacked tensor file."""

import torch.nn as nn
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path
import os
import shutil


mi_src = "./imgset_rnn/mi/"
mi_dest = "./tensorfiles_rnn/mi"

tensornum = 0

for dirs, sub_dirs, files in os.walk(mi_src):
    # Each sub-directory under mi_src holds the lead images for one EKG
    # record; the walk's first yield is the root itself, so skip it.
    if dirs != mi_src:
        tensors = []
        # Sorted for a deterministic lead order (Path.glob order is
        # filesystem-dependent).
        for path in sorted(Path(dirs).glob('**/*.png')):
            pil_img = Image.open(str(path)).convert("RGB")
            tensors.append(transforms.ToTensor()(pil_img))

        # Guard: torch.stack raises on an empty list if a directory has
        # no PNGs.
        if not tensors:
            continue

        # Stack leads into (T, C, H, W), then add a batch axis -> (1, T, C, H, W).
        stacked = torch.unsqueeze(torch.stack(tensors), 0)

        # Fixed: previously saved to the CWD and shutil.move'd into a
        # hard-coded path, ignoring the mi_dest variable entirely (and the
        # None return of torch.save was pointlessly assigned). Save
        # directly into the destination instead.
        torch.save(stacked, os.path.join(mi_dest, 'tensor' + str(tensornum) + '.pt'))
        print('moved ' + str(tensornum) + '!')
        tensornum += 1


#generated 5486 mi tensors :)
|
training/rnn (does not work)/generate_norm_tensors.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert each normal record's per-lead PNG images into one stacked tensor file."""

import torch.nn as nn
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path
import os
import shutil


norm_src = "./imgset_rnn/normal/"
norm_dest = "./tensorfiles_rnn/norm"

tensornum = 0

for dirs, sub_dirs, files in os.walk(norm_src):
    # Each sub-directory under norm_src holds the lead images for one EKG
    # record; the walk's first yield is the root itself, so skip it.
    if dirs != norm_src:
        tensors = []
        # Sorted for a deterministic lead order (Path.glob order is
        # filesystem-dependent).
        for path in sorted(Path(dirs).glob('**/*.png')):
            pil_img = Image.open(str(path)).convert("RGB")
            tensors.append(transforms.ToTensor()(pil_img))

        # Guard: torch.stack raises on an empty list if a directory has
        # no PNGs.
        if not tensors:
            continue

        # Stack leads into (T, C, H, W), then add a batch axis -> (1, T, C, H, W).
        stacked = torch.unsqueeze(torch.stack(tensors), 0)

        # Fixed: previously saved to the CWD and shutil.move'd into a
        # hard-coded path, ignoring the norm_dest variable entirely (and
        # the None return of torch.save was pointlessly assigned). Save
        # directly into the destination instead.
        torch.save(stacked, os.path.join(norm_dest, 'tensor' + str(tensornum) + '.pt'))
        print('moved ' + str(tensornum) + '!')
        tensornum += 1


#generated 7547 tensors :)
|
training/rnn (does not work)/rnn.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from torchvision import datasets
|
| 4 |
+
from torchvision.transforms import ToTensor, Lambda
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import os
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from torchvision.io import read_image
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import ConvLSTM
|
| 12 |
+
import torch.optim as optim
|
| 13 |
+
|
| 14 |
+
class CustomImageDataset(Dataset):
    """Dataset of pre-built per-record EKG tensors saved as .pt files.

    Expects the layout produced by generate_norm_tensors.py /
    generate_mi_tensors.py: <root>/norm/tensorN.pt and <root>/mi/tensorN.pt.
    """

    def __init__(self, image_directory):
        # Root directory containing the "norm" and "mi" sub-directories.
        self.image_directory = image_directory

        # Count every .pt file one level down so __len__ matches disk.
        self.total_num_inputs = 0

        for dir_name in Path(image_directory).glob('*'):
            self.total_num_inputs += len(list(dir_name.glob('*.pt')))

        # Index 0 -> "norm" (label 0.0), index 1 -> "mi" (label 1.0).
        self.tensor_labels = ["norm", "mi"]

    def __len__(self):
        return self.total_num_inputs

    def __getitem__(self, idx):
        #there are 7547 norm images and 5486 mi images
        #return the correct label and the corresponding tensor (loaded by the file!)
        # NOTE(review): the 7546/7547 boundary is hard-coded to the counts
        # the generator scripts reported; regenerating the files breaks this
        # mapping -- TODO derive the split from the directory contents.

        if idx <= 7546:
            label = self.tensor_labels[0]
            tensor_file_path = "./tensorfiles_rnn/norm/tensor" + str(idx) + ".pt"
            tensor = torch.load(tensor_file_path)
        else:
            label = self.tensor_labels[1]
            tensor_file_path = "./tensorfiles_rnn/mi/tensor" + str(idx-7547) + ".pt"
            tensor = torch.load(tensor_file_path)

        # Drop the stored batch axis; label is 0.0 for norm, 1.0 for mi.
        return (tensor.squeeze(), torch.tensor(0.0 if label == self.tensor_labels[0] else 1.0))
|
| 42 |
+
|
| 43 |
+
# Build the dataset and a shuffled loader; bs=2 keeps the large
# (T, C, 480, 640) sequence tensors within memory.
training_data = CustomImageDataset("./tensorfiles_rnn")

train_dataloader = DataLoader(training_data, batch_size=2, shuffle=True)

# Display image and label: sanity-check one batch's shapes and show the
# first lead image of the first sample.
train_features, train_labels = next(iter(train_dataloader))

print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# (C, H, W) -> (H, W, C) for matplotlib display.
plt.imshow(train_features[0].squeeze()[0].squeeze().permute(1, 2, 0))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class FlatConvLSTM(torch.nn.Module):
    """Wrap a single-layer ConvLSTM and expose only its final hidden state.

    The wrapped ConvLSTM is batch-first and initializes its own hidden and
    cell states internally, so callers just pass a (B, T, C, H, W) tensor.
    """

    def __init__(self):
        super().__init__()
        # 3 input channels -> 10 hidden channels, 3x3 kernel, one layer,
        # batch_first=True, bias=True, return_all_layers=False.
        self.convlstm = ConvLSTM.ConvLSTM(3, 10, (3, 3), 1, True, True, False)

    def forward(self, x):
        _, final_states = self.convlstm(x)
        # final_states[0] is the (h, c) pair of the last layer;
        # keep only the hidden feature map.
        return final_states[0][0]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ConvLSTM features -> flatten -> single logistic unit for binary
# (norm vs mi) classification. The Linear size assumes 480x640 inputs
# with the ConvLSTM's 10 hidden channels.
model = torch.nn.Sequential(
    FlatConvLSTM(),
    torch.nn.Flatten(),
    torch.nn.Linear(10*480*640, 1),
    torch.nn.Sigmoid()
)


#create the loss function
loss = torch.nn.BCELoss()
optimizer = optim.Adam(model.parameters())
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

num_epochs = 10
model.to(device)

for epoch in range(num_epochs):

    # Set model to training mode
    model.train()

    # Update the model for each batch
    train_count = 0
    train_cost = 0
    batch = 0

    for X, y in train_dataloader:

        X = X.to(device)
        y = y.to(device)

        yhat = model(X)

        try:
            cost = loss(yhat.squeeze(), y)
            # Clear gradients on the optimizer's parameters before backprop.
            optimizer.zero_grad()
            cost.backward()
        except RuntimeError as err:
            # Fixed: a bare `except: print()` used to swallow every error,
            # then the code below used a stale/undefined `cost` and stepped
            # the optimizer with stale gradients. Log and skip the batch.
            print(f"skipping batch {batch}: {err}")
            batch += 1
            continue

        train_count += X.shape[0]
        train_cost += cost.item()
        optimizer.step()
        print(epoch, batch, cost.item())
        batch += 1

    # Set model to evaluation mode
    model.eval()

    # Validation bookkeeping (no validation loop is implemented yet).
    valid_count = 0
    valid_cost = 0
    valid_correct = 0

    # Average the summed batch losses over the samples seen; guard against
    # an epoch in which every batch was skipped.
    if train_count:
        train_cost /= train_count

    print(epoch, train_cost)

print('Done.')

#save the model in a file
# NOTE(review): ".py" is a misleading extension for torch.save checkpoints
# (".pt" is conventional); paths kept as-is so existing tooling still works.
torch.save(model, "./rnn_saved_models/model.py")
torch.save(model.state_dict(), "./rnn_saved_models/model_parameters.py")
|