Spaces:
Runtime error
Runtime error
Commit ·
6f47252
1
Parent(s): 3294cf8
working example
Browse files- .devcontainer/devcontainer.json +1 -1
- .gitattributes +14 -0
- app.py +125 -21
- data/test-data.csv +3 -0
- data/weights/h1_w2vec/args.yaml +3 -0
- data/weights/h1_w2vec/model.pt +3 -0
- data/weights/h1_w2vec/train.log +3 -0
- data/weights/h3_w2vec/args.yaml +3 -0
- data/weights/h3_w2vec/model.pt +3 -0
- data/weights/h3_w2vec/train.log +3 -0
- data/weights/h5_w2vec/args.yaml +3 -0
- data/weights/h5_w2vec/model.pt +3 -0
- data/weights/h5_w2vec/train.log +3 -0
- data/weights/h7_w2vec/args.yaml +3 -0
- data/weights/h7_w2vec/model.pt +3 -0
- data/weights/h7_w2vec/train.log +3 -0
- data/weights/h9_w2vec/args.yaml +3 -0
- data/weights/h9_w2vec/model.pt +3 -0
- data/weights/h9_w2vec/train.log +3 -0
- data/weights/r1_w2vec/args.yaml +3 -0
- data/weights/r1_w2vec/model.pt +3 -0
- data/weights/r1_w2vec/train.log +3 -0
- data/weights/r3_local/args.yaml +3 -0
- data/weights/r3_local/model.pt +3 -0
- data/weights/r3_local/train.log +3 -0
- data/weights/r3_nbrs/args.yaml +3 -0
- data/weights/r3_nbrs/model.pt +3 -0
- data/weights/r3_nbrs/train.log +3 -0
- data/weights/r3_w2vec/args.yaml +3 -0
- data/weights/r3_w2vec/model.pt +3 -0
- data/weights/r3_w2vec/train.log +3 -0
- data/weights/r5_w2vec/args.yaml +3 -0
- data/weights/r5_w2vec/model.pt +3 -0
- data/weights/r5_w2vec/train.log +3 -0
- data/weights/r7_w2vec/args.yaml +3 -0
- data/weights/r7_w2vec/model.pt +3 -0
- data/weights/r7_w2vec/train.log +3 -0
- data/weights/r9_local/args.yaml +3 -0
- data/weights/r9_local/model.pt +3 -0
- data/weights/r9_local/train.log +3 -0
- data/weights/r9_nbrs/args.yaml +3 -0
- data/weights/r9_nbrs/model.pt +3 -0
- data/weights/r9_nbrs/train.log +3 -0
- data/weights/r9_w2vec/args.yaml +3 -0
- data/weights/r9_w2vec/model.pt +3 -0
- data/weights/r9_w2vec/train.log +3 -0
- models.py +6 -0
- requirements.txt +1 -0
- utils.py +58 -0
.devcontainer/devcontainer.json
CHANGED
|
@@ -21,7 +21,7 @@
|
|
| 21 |
// "features": {},
|
| 22 |
|
| 23 |
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
| 24 |
-
"forwardPorts": [
|
| 25 |
|
| 26 |
// Uncomment the next line to run commands after the container is created.
|
| 27 |
// "postCreateCommand": "cat /etc/os-release",
|
|
|
|
| 21 |
// "features": {},
|
| 22 |
|
| 23 |
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
| 24 |
+
// "forwardPorts": [],
|
| 25 |
|
| 26 |
// Uncomment the next line to run commands after the container is created.
|
| 27 |
// "postCreateCommand": "cat /etc/os-release",
|
.gitattributes
CHANGED
|
@@ -37,3 +37,17 @@ data/training_data.pkl filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 38 |
**/*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 39 |
data filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 38 |
**/*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 39 |
data filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
data/weights/h1_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
data/weights/h7_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
data/weights/r3_local/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
data/weights/r3_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
data/weights/r5_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
data/weights/r9_local/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
data/weights/r9_nbrs/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
data/weights/r9_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
data/weights/h3_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
data/weights/r7_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
data/weights/r1_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
data/weights/r3_nbrs/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
data/weights/h5_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
data/weights/h9_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
from shiny import App, ui, render
|
|
|
|
| 2 |
import shinyswatch
|
| 3 |
|
| 4 |
import torch
|
|
@@ -12,10 +13,9 @@ from collections import defaultdict
|
|
| 12 |
from tqdm import tqdm
|
| 13 |
import itertools as it
|
| 14 |
from torch import nn
|
|
|
|
| 15 |
|
| 16 |
-
import
|
| 17 |
-
from models import UNetEncoder, Decoder
|
| 18 |
-
from utils import load_training_data
|
| 19 |
|
| 20 |
MONTHS= {
|
| 21 |
0: "Jan",
|
|
@@ -61,35 +61,139 @@ C, NAMES, Y, M = load_training_data(
|
|
| 61 |
)
|
| 62 |
_, _, YRAW, MRAW = load_training_data(path="data/training_data.pkl")
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
# Part 1: ui ----
|
| 66 |
app_ui = ui.page_fluid(
|
| 67 |
shinyswatch.theme.minty(),
|
| 68 |
-
|
| 69 |
-
ui.
|
| 70 |
-
|
| 71 |
-
ui.
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
|
| 82 |
# Part 2: server ----
|
| 83 |
def server(input, output, session):
|
| 84 |
# make a plot
|
| 85 |
-
@
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
|
|
|
|
|
|
|
|
|
| 92 |
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# Combine into a shiny app.
|
| 95 |
# Note that the variable must be "app".
|
|
|
|
| 1 |
+
from shiny import App, ui, render, reactive
|
| 2 |
+
from shiny.ui import HTML, tags
|
| 3 |
import shinyswatch
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 13 |
from tqdm import tqdm
|
| 14 |
import itertools as it
|
| 15 |
from torch import nn
|
| 16 |
+
import io
|
| 17 |
|
| 18 |
+
from utils import load_training_data, load_models
|
|
|
|
|
|
|
| 19 |
|
| 20 |
MONTHS= {
|
| 21 |
0: "Jan",
|
|
|
|
| 61 |
)
|
| 62 |
_, _, YRAW, MRAW = load_training_data(path="data/training_data.pkl")
|
| 63 |
|
| 64 |
+
prefix = "h"
|
| 65 |
+
dirs = {
|
| 66 |
+
"r1": f"./data/weights/{prefix}1_w2vec",
|
| 67 |
+
"r3": f"./data/weights/{prefix}3_w2vec",
|
| 68 |
+
"r5": f"./data/weights/{prefix}5_w2vec",
|
| 69 |
+
"r7": f"./data/weights/{prefix}7_w2vec",
|
| 70 |
+
"r9": f"./data/weights/{prefix}9_w2vec",
|
| 71 |
+
}
|
| 72 |
+
MODELS = load_models(dirs, prefix=prefix, nd=5)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
multicol_html = tags.head(
|
| 76 |
+
tags.style(
|
| 77 |
+
HTML(
|
| 78 |
+
".multicol {"
|
| 79 |
+
# "height: 150px; "
|
| 80 |
+
"-webkit-column-count: 3;" # chrome, safari, opera
|
| 81 |
+
"-moz-column-count: 3;" # firefox
|
| 82 |
+
"column-count: 3;"
|
| 83 |
+
"-moz-column-fill: auto;"
|
| 84 |
+
"-column-fill: auto;"
|
| 85 |
+
)
|
| 86 |
+
)
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
instructions = f"""
|
| 90 |
+
### Instructions
|
| 91 |
+
|
| 92 |
+
Upload a CSV file with columns (id, lat, lon) using the `Browse` button on the sidebar.
|
| 93 |
+
Below is an example of the contents of the file:
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
```
|
| 97 |
+
id,lat,lon
|
| 98 |
+
0,47.5,-122.5
|
| 99 |
+
1,47.5,-122.25
|
| 100 |
+
2,47.5,-122.0
|
| 101 |
+
3,47.5,-121.75
|
| 102 |
+
4,47.5,-121.5
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
The id column can be any identifier, or the column can be ommited, in which case the row number will be used as the id.
|
| 107 |
+
Make sure that the latitude is before the longitude column in the CSV file. The valid range for latitude is
|
| 108 |
+
{YMIN} to {YMAX} and longitude is {XMIN} to {XMAX}, which cover the contiguous United States.
|
| 109 |
+
|
| 110 |
+
The resolution corresponds to how much neighboring information is captured by the embedding. If `local` is selected,
|
| 111 |
+
the original weather covariates will be returned. Currently, all the embeddings correspond to the variables:
|
| 112 |
+
air temperature (2m), precipitation, relative humidity (2m), vertical wind speed (10m), and horizontal wind speed (10m).
|
| 113 |
+
The native resolution of the covariates is ~32 km for a grid size of 128 x 256.
|
| 114 |
+
|
| 115 |
+
### Results
|
| 116 |
+
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
# After uploading the file, the app will generate a CSV, a download link will appear here.
|
| 120 |
+
# The CSV will contain the following columns:
|
| 121 |
+
|
| 122 |
|
| 123 |
# Part 1: ui ----
|
| 124 |
app_ui = ui.page_fluid(
|
| 125 |
shinyswatch.theme.minty(),
|
| 126 |
+
multicol_html,
|
| 127 |
+
ui.panel_title("Welcome to the Weather2vec Embedding Generator!"),
|
| 128 |
+
ui.layout_sidebar(
|
| 129 |
+
ui.panel_sidebar(
|
| 130 |
+
ui.input_file("df", "Upload CSV File", accept=".csv"),
|
| 131 |
+
tags.div(
|
| 132 |
+
ui.input_checkbox_group("months", HTML("<b>Months</b>"), MONTHS),
|
| 133 |
+
class_="multicol",
|
| 134 |
+
align="left",
|
| 135 |
+
inline=False
|
| 136 |
+
),
|
| 137 |
+
HTML("<b>Note:</b> Embedding of multiple months will be added.<br>True multi-temporal embeddings will be supported in the future.<br><br>"),
|
| 138 |
+
tags.div(
|
| 139 |
+
ui.input_radio_buttons("years", HTML("<b>Year</b>"), YEARS),
|
| 140 |
+
class_="multicol",
|
| 141 |
+
align="left",
|
| 142 |
+
inline=False
|
| 143 |
+
),
|
| 144 |
+
HTML("<br>"),
|
| 145 |
+
tags.div(
|
| 146 |
+
ui.input_radio_buttons("resolutions", HTML("<b>Resolution</b>"), RESOLUTIONS),
|
| 147 |
+
class_="multicol",
|
| 148 |
+
align="left",
|
| 149 |
+
inline=False
|
| 150 |
+
),
|
| 151 |
+
width=3,
|
| 152 |
+
),
|
| 153 |
+
ui.panel_main(
|
| 154 |
+
ui.markdown(instructions),
|
| 155 |
+
ui.download_button("download", "Download Embeddings"),
|
| 156 |
+
),
|
| 157 |
+
)
|
| 158 |
)
|
| 159 |
|
| 160 |
|
| 161 |
# Part 2: server ----
|
| 162 |
def server(input, output, session):
|
| 163 |
# make a plot
|
| 164 |
+
@session.download(filename="embeddings.csv")
|
| 165 |
+
def download():
|
| 166 |
+
# read input file
|
| 167 |
+
print(input.df()[-1].keys())
|
| 168 |
+
fname = input.df()[-1]['datapath']
|
| 169 |
+
df = pd.read_csv(fname)
|
| 170 |
+
|
| 171 |
+
# dfcols = []
|
| 172 |
+
# for k, v in D.items():
|
| 173 |
+
# mod = v["mod"]
|
| 174 |
+
# mod.load_state_dict(torch.load(os.path.join(k, "model.pt")))
|
| 175 |
+
# with torch.no_grad():
|
| 176 |
+
# Z = mod["enc"](Ct)
|
| 177 |
+
# Z = Z.mean(0).cpu().numpy()
|
| 178 |
+
# Zmat = Z[:, row, col].T
|
| 179 |
+
# colnames = [f"C{i:02d}" for i in range(Zmat.shape[-1])]
|
| 180 |
+
# Z = pd.DataFrame(Zmat, columns=colnames)
|
| 181 |
+
# Z = pd.DataFrame(Zmat, columns=[x + f"_{len(dfcols)}" for x in colnames])
|
| 182 |
+
# dfcols.append(Z)
|
| 183 |
+
# Z = pd.concat([locs, Z], axis=1)
|
| 184 |
+
# Z.to_csv(f"{savedir}/kms_{32 * int(v['radius']):03d}.csv", index=False)
|
| 185 |
+
|
| 186 |
+
# Cloc = Corig[ix].mean(0)[:, row, col].T
|
| 187 |
+
# Z = pd.DataFrame(Cloc, columns=colnames[:Cloc.shape[1]])
|
| 188 |
+
# Z = pd.concat([locs, Z], axis=1)
|
| 189 |
+
# Z.to_csv(f"{savedir}/kms_000.csv", index=False)
|
| 190 |
|
| 191 |
+
with io.BytesIO() as f:
|
| 192 |
+
df.to_csv(f, index=False)
|
| 193 |
+
yield f.getvalue()
|
| 194 |
|
| 195 |
+
# # dfcols = pd.concat(dfcols, axis=1)
|
| 196 |
+
# # dfcols = pd.concat([Z, dfcols], axis=1)
|
| 197 |
|
| 198 |
# Combine into a shiny app.
|
| 199 |
# Note that the variable must be "app".
|
data/test-data.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:053c897211e68701fc2930f66e30c30d319c9858ae8f89a5293999a86da1c1c8
|
| 3 |
+
size 35225
|
data/weights/h1_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1350b39d48c739af0bb100b777c5ecab4e2ba5d2da9ec0020988d265aee9fecd
|
| 3 |
+
size 320
|
data/weights/h1_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a9a3f67afdfe9ed2f2f08b6e71872c215865b4c61405115a79c7e26c9688111
|
| 3 |
+
size 492319
|
data/weights/h1_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6898b26aafd53d108e078214fed89eae06caa5294bbda98b4b931d5d8cdffafc
|
| 3 |
+
size 123
|
data/weights/h3_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6852ae1a0b4514ee2391bb10e4e57563faeb252b447c672390dccb0551f40167
|
| 3 |
+
size 320
|
data/weights/h3_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a81e01e7ce2048ebb19f9e8446cf626aefacf334539687ff8f488ca9483e98e
|
| 3 |
+
size 492319
|
data/weights/h3_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24991bdccad89102340b4f77f062dd1db046055efacdba6cecc507c2612525f4
|
| 3 |
+
size 22518
|
data/weights/h5_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea17a68d758e186172b873db76bf18c27eccbb32fb60b21b112603b75cd9cca0
|
| 3 |
+
size 320
|
data/weights/h5_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:88611910163b4a9ddb7324a688c6b964321d821629250d38ef12624f4204783e
|
| 3 |
+
size 492319
|
data/weights/h5_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:525e8d69818e41e807220e6eead955704a31353dc6153ed8b964eab4e7e51f5b
|
| 3 |
+
size 22519
|
data/weights/h7_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c92d6f6fdd8f0e4fa19bfb7257e2fd3b3319bd418895d79dbfd50ff1365d13f1
|
| 3 |
+
size 320
|
data/weights/h7_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:150fdad043c8a9bf3c4548b97b019216c02dd241f4a267d544af1f6464b0856a
|
| 3 |
+
size 492319
|
data/weights/h7_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37f45005ff95ea52753df4b6771fdaae9c1320d84817b89e4d835d5f1ea7a70e
|
| 3 |
+
size 22520
|
data/weights/h9_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24098fc0d0c99e75b09345d27af6793b438383010d8dc56fabea5857769f8cea
|
| 3 |
+
size 320
|
data/weights/h9_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e39a8c1a079f68c7bb4ef90aa7cdd5b797474c582670f730d4d371e389c0ebf
|
| 3 |
+
size 492319
|
data/weights/h9_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6037749e8582a3954d0701ee8212497f8e2c75ab6764ef9f2e20614a306d2bcd
|
| 3 |
+
size 22522
|
data/weights/r1_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f438e42212b15d6eff79ef4fef2fa4e8a7af19567b24dda0ff7189dccddb4f83
|
| 3 |
+
size 298
|
data/weights/r1_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77c8865887c1eec2a710c7aa8f54977564b0e76276e427497b64d6122926ced9
|
| 3 |
+
size 4866847
|
data/weights/r1_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7da97ac5558d1c2d220fed659d64b6e8c83041ee53e129802a1aa17ba8dd6a4d
|
| 3 |
+
size 22003
|
data/weights/r3_local/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb68f0f9f23c5b0e9465357a032516f3d82da4d1e654db9e7a69fe311c802dd9
|
| 3 |
+
size 320
|
data/weights/r3_local/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:63e59b43d2633c39cb0b0906d052291b918a6c57d879452c956e1b13cd50c824
|
| 3 |
+
size 35323
|
data/weights/r3_local/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bfe90fc358e1068101c5a089be7c7ea6c7716e1fdcaefa221906185244f9b87d
|
| 3 |
+
size 21634
|
data/weights/r3_nbrs/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b121a5683b604ed17120febdf94f66f628fcf5e1df818ef3a12f5712bb402751
|
| 3 |
+
size 320
|
data/weights/r3_nbrs/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:015b4cbe386bd17782869275d5ff1db5ecc9b45a685600a81e617bd9795861ce
|
| 3 |
+
size 35323
|
data/weights/r3_nbrs/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24e28b6418688b93c2e10854ee2f95d66292efd99a19a01c0b2ab0f16d22d43b
|
| 3 |
+
size 21633
|
data/weights/r3_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d3471f6a8bbb72c9eac7f674601969edf9ae5830bf75556604215ad11f8905a
|
| 3 |
+
size 298
|
data/weights/r3_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1dd70f70e737d7ad349f0cde7661d509bd593f146a6de08710ef585376d37f13
|
| 3 |
+
size 4866847
|
data/weights/r3_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b73925d08e1f71f7333765fa5109b77e41e31f6ac2d35ff0b4351c83145bc321
|
| 3 |
+
size 22139
|
data/weights/r5_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3232f7d47d486442c966a5ede893ce491146c91c12183e41ad1502aaa794f94e
|
| 3 |
+
size 321
|
data/weights/r5_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:232950462057e8844d6cd2bba6ac9f1d875309d7d17277aba3645174f0768cd3
|
| 3 |
+
size 4866847
|
data/weights/r5_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:695093769fefcccc7a42b38ef0c1b6a032945aafbd50e4809115cd55d4195ac7
|
| 3 |
+
size 198
|
data/weights/r7_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1e0224de6018d8ddb966f853fc8c08e5d8eef5b7057b7542d82a9ae4befca40
|
| 3 |
+
size 298
|
data/weights/r7_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:14fbfaeb91e6066f59c894638fe7ba5ca9d019298d5ba80e29d7584dab0afee6
|
| 3 |
+
size 4866847
|
data/weights/r7_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f731ec381f0dea7d7fe99936aad0dceb4a442af251093c78116c2889ae1db64
|
| 3 |
+
size 22245
|
data/weights/r9_local/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f7b79f12cbbf692de41095bff5cfeab906e97efe18343c74125e91748812ca3
|
| 3 |
+
size 320
|
data/weights/r9_local/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1bf0fa150c7cb85123497d5cf784386f7351665b196c635b26856ae044d52151
|
| 3 |
+
size 35323
|
data/weights/r9_local/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:973371a1053fd9a84b93188111cb732ae6610abaf3dae65aaac31b27e96db7b1
|
| 3 |
+
size 21912
|
data/weights/r9_nbrs/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:66acb2a59e43e975f09dd0acfaebf214ef822dfb4c2e04df71946988f0a1ba63
|
| 3 |
+
size 320
|
data/weights/r9_nbrs/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a515d0cdc564b3db7c2357ea8a0b1495fd482dd83d4ebd407616c4a65f9d8cbf
|
| 3 |
+
size 35323
|
data/weights/r9_nbrs/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2fb3010bc18246eb2eb81e9e670583a507d4497601e2939bf94825e868a0484c
|
| 3 |
+
size 21912
|
data/weights/r9_w2vec/args.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:373b05bda0a2c96a04500f850a143a6bc164c16744f6c1efad3bbfe0fc2a3261
|
| 3 |
+
size 298
|
data/weights/r9_w2vec/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3a53428d08e1f68cffdb78f6499d8fae7d42f7878f65d4e56ce6f0eeb37fa3a
|
| 3 |
+
size 4866847
|
data/weights/r9_w2vec/train.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf2141e8f42e65470cd391b93f3dfa9ce8d82b6f30cb27eb4b8cbb60d3744e83
|
| 3 |
+
size 22279
|
models.py
CHANGED
|
@@ -3,6 +3,7 @@ from torch import nn
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch import Tensor
|
| 5 |
from typing import Optional, List
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class LayerNorm(nn.Module):
|
|
@@ -303,6 +304,11 @@ class UNetEncoder(nn.Module):
|
|
| 303 |
x = self.final(x)
|
| 304 |
|
| 305 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
|
| 308 |
class Decoder(nn.Module):
|
|
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch import Tensor
|
| 5 |
from typing import Optional, List
|
| 6 |
+
from timm.models.layers import trunc_normal_
|
| 7 |
|
| 8 |
|
| 9 |
class LayerNorm(nn.Module):
|
|
|
|
| 304 |
x = self.final(x)
|
| 305 |
|
| 306 |
return x
|
| 307 |
+
|
| 308 |
+
def _init_weights(self, m):
|
| 309 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
| 310 |
+
trunc_normal_(m.weight, std=.02)
|
| 311 |
+
nn.init.constant_(m.bias, 0)
|
| 312 |
|
| 313 |
|
| 314 |
class Decoder(nn.Module):
|
requirements.txt
CHANGED
|
@@ -6,3 +6,4 @@ torch
|
|
| 6 |
numpy
|
| 7 |
shiny
|
| 8 |
shinyswatch
|
|
|
|
|
|
| 6 |
numpy
|
| 7 |
shiny
|
| 8 |
shinyswatch
|
| 9 |
+
timm
|
utils.py
CHANGED
|
@@ -1,5 +1,11 @@
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import pickle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
def load_training_data(
|
|
@@ -50,3 +56,55 @@ def load_training_data(
|
|
| 50 |
return C, names, Y, M
|
| 51 |
else:
|
| 52 |
return C, names, Y, M, data["pp_locs"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
import numpy as np
|
| 3 |
import pickle
|
| 4 |
+
import os
|
| 5 |
+
import yaml
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from models import UNetEncoder, Decoder
|
| 9 |
|
| 10 |
|
| 11 |
def load_training_data(
|
|
|
|
| 56 |
return C, names, Y, M
|
| 57 |
else:
|
| 58 |
return C, names, Y, M, data["pp_locs"]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def radius_from_dir(s: str, prefix: str):
|
| 62 |
+
return int(s.split("/")[-1].split("_")[0].replace(prefix, ""))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_models(dirs: dict, prefix="h", nd=5):
|
| 66 |
+
D = {}
|
| 67 |
+
for name, datadir in dirs.items():
|
| 68 |
+
radius = radius_from_dir(datadir, prefix)
|
| 69 |
+
args = argparse.Namespace()
|
| 70 |
+
with open(os.path.join(datadir, "args.yaml"), "r") as io:
|
| 71 |
+
for k, v in yaml.load(io, Loader=yaml.FullLoader).items():
|
| 72 |
+
setattr(args, k, v)
|
| 73 |
+
if k == "nbrs_av":
|
| 74 |
+
setattr(args, "av_nbrs", v)
|
| 75 |
+
elif k == "av_nbrs":
|
| 76 |
+
setattr(args, "nbrs_av", v)
|
| 77 |
+
|
| 78 |
+
bn_type ="frn" if not hasattr(args, "bn_type") else args.bn_type
|
| 79 |
+
mkw = dict(
|
| 80 |
+
n_hidden=args.nhidden,
|
| 81 |
+
depth=args.depth,
|
| 82 |
+
num_res=args.nres,
|
| 83 |
+
ksize=args.ksize,
|
| 84 |
+
groups=args.groups,
|
| 85 |
+
batchnorm=True,
|
| 86 |
+
batchnorm_type=bn_type,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)
|
| 90 |
+
dev = "cuda" if torch.cuda.is_available() else "cpu"
|
| 91 |
+
if not args.local and args.nbrs_av == 0:
|
| 92 |
+
enc = UNetEncoder(nd, args.nhidden, **mkw).to(dev)
|
| 93 |
+
dec = Decoder(args.nhidden, nd, args.nhidden, **dkw).to(dev)
|
| 94 |
+
else:
|
| 95 |
+
enc = nn.Identity()
|
| 96 |
+
dec = Decoder(nd, nd, args.nhidden, **dkw).to(dev)
|
| 97 |
+
mod = nn.ModuleDict({"enc": enc, "dec": dec})
|
| 98 |
+
objs = dict(
|
| 99 |
+
mod=mod,
|
| 100 |
+
args=args,
|
| 101 |
+
radius=radius,
|
| 102 |
+
nbrs_av=args.nbrs_av,
|
| 103 |
+
local=args.local,
|
| 104 |
+
)
|
| 105 |
+
mod.eval()
|
| 106 |
+
for p in mod.parameters():
|
| 107 |
+
p.requires_grad = False
|
| 108 |
+
mod = mod.to(dev)
|
| 109 |
+
D[datadir] = objs
|
| 110 |
+
return D
|