mauriciogtec commited on
Commit
6f47252
·
1 Parent(s): 3294cf8

working example

Browse files
Files changed (49) hide show
  1. .devcontainer/devcontainer.json +1 -1
  2. .gitattributes +14 -0
  3. app.py +125 -21
  4. data/test-data.csv +3 -0
  5. data/weights/h1_w2vec/args.yaml +3 -0
  6. data/weights/h1_w2vec/model.pt +3 -0
  7. data/weights/h1_w2vec/train.log +3 -0
  8. data/weights/h3_w2vec/args.yaml +3 -0
  9. data/weights/h3_w2vec/model.pt +3 -0
  10. data/weights/h3_w2vec/train.log +3 -0
  11. data/weights/h5_w2vec/args.yaml +3 -0
  12. data/weights/h5_w2vec/model.pt +3 -0
  13. data/weights/h5_w2vec/train.log +3 -0
  14. data/weights/h7_w2vec/args.yaml +3 -0
  15. data/weights/h7_w2vec/model.pt +3 -0
  16. data/weights/h7_w2vec/train.log +3 -0
  17. data/weights/h9_w2vec/args.yaml +3 -0
  18. data/weights/h9_w2vec/model.pt +3 -0
  19. data/weights/h9_w2vec/train.log +3 -0
  20. data/weights/r1_w2vec/args.yaml +3 -0
  21. data/weights/r1_w2vec/model.pt +3 -0
  22. data/weights/r1_w2vec/train.log +3 -0
  23. data/weights/r3_local/args.yaml +3 -0
  24. data/weights/r3_local/model.pt +3 -0
  25. data/weights/r3_local/train.log +3 -0
  26. data/weights/r3_nbrs/args.yaml +3 -0
  27. data/weights/r3_nbrs/model.pt +3 -0
  28. data/weights/r3_nbrs/train.log +3 -0
  29. data/weights/r3_w2vec/args.yaml +3 -0
  30. data/weights/r3_w2vec/model.pt +3 -0
  31. data/weights/r3_w2vec/train.log +3 -0
  32. data/weights/r5_w2vec/args.yaml +3 -0
  33. data/weights/r5_w2vec/model.pt +3 -0
  34. data/weights/r5_w2vec/train.log +3 -0
  35. data/weights/r7_w2vec/args.yaml +3 -0
  36. data/weights/r7_w2vec/model.pt +3 -0
  37. data/weights/r7_w2vec/train.log +3 -0
  38. data/weights/r9_local/args.yaml +3 -0
  39. data/weights/r9_local/model.pt +3 -0
  40. data/weights/r9_local/train.log +3 -0
  41. data/weights/r9_nbrs/args.yaml +3 -0
  42. data/weights/r9_nbrs/model.pt +3 -0
  43. data/weights/r9_nbrs/train.log +3 -0
  44. data/weights/r9_w2vec/args.yaml +3 -0
  45. data/weights/r9_w2vec/model.pt +3 -0
  46. data/weights/r9_w2vec/train.log +3 -0
  47. models.py +6 -0
  48. requirements.txt +1 -0
  49. utils.py +58 -0
.devcontainer/devcontainer.json CHANGED
@@ -21,7 +21,7 @@
21
  // "features": {},
22
 
23
  // Use 'forwardPorts' to make a list of ports inside the container available locally.
24
- "forwardPorts": [7860],
25
 
26
  // Uncomment the next line to run commands after the container is created.
27
  // "postCreateCommand": "cat /etc/os-release",
 
21
  // "features": {},
22
 
23
  // Use 'forwardPorts' to make a list of ports inside the container available locally.
24
+ // "forwardPorts": [],
25
 
26
  // Uncomment the next line to run commands after the container is created.
27
  // "postCreateCommand": "cat /etc/os-release",
.gitattributes CHANGED
@@ -37,3 +37,17 @@ data/training_data.pkl filter=lfs diff=lfs merge=lfs -text
37
  *.pkl filter=lfs diff=lfs merge=lfs -text
38
  **/*.pkl filter=lfs diff=lfs merge=lfs -text
39
  data filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  *.pkl filter=lfs diff=lfs merge=lfs -text
38
  **/*.pkl filter=lfs diff=lfs merge=lfs -text
39
  data filter=lfs diff=lfs merge=lfs -text
40
+ data/weights/h1_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
41
+ data/weights/h7_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
42
+ data/weights/r3_local/model.pt filter=lfs diff=lfs merge=lfs -text
43
+ data/weights/r3_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
44
+ data/weights/r5_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
45
+ data/weights/r9_local/model.pt filter=lfs diff=lfs merge=lfs -text
46
+ data/weights/r9_nbrs/model.pt filter=lfs diff=lfs merge=lfs -text
47
+ data/weights/r9_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
48
+ data/weights/h3_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
49
+ data/weights/r7_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
50
+ data/weights/r1_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
51
+ data/weights/r3_nbrs/model.pt filter=lfs diff=lfs merge=lfs -text
52
+ data/weights/h5_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
53
+ data/weights/h9_w2vec/model.pt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,4 +1,5 @@
1
- from shiny import App, ui, render
 
2
  import shinyswatch
3
 
4
  import torch
@@ -12,10 +13,9 @@ from collections import defaultdict
12
  from tqdm import tqdm
13
  import itertools as it
14
  from torch import nn
 
15
 
16
- import pickle
17
- from models import UNetEncoder, Decoder
18
- from utils import load_training_data
19
 
20
  MONTHS= {
21
  0: "Jan",
@@ -61,35 +61,139 @@ C, NAMES, Y, M = load_training_data(
61
  )
62
  _, _, YRAW, MRAW = load_training_data(path="data/training_data.pkl")
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # Part 1: ui ----
66
  app_ui = ui.page_fluid(
67
  shinyswatch.theme.minty(),
68
- ui.panel_title("Generate Weather2vec Embeddings"),
69
- ui.panel_sidebar(
70
- ui.input_file("df", "Upload CSV File", accept=".csv"),
71
- ui.input_checkbox_group("months", "Select Months", MONTHS),
72
- ui.input_checkbox_group("months", "Select Years", YEARS),
73
- ui.input_checkbox_group("months", "Select Resolutions", RESOLUTIONS),
74
- width=6,
75
- ),
76
- ui.panel_main(
77
- ui.output_plot("plot"),
78
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  )
80
 
81
 
82
  # Part 2: server ----
83
  def server(input, output, session):
84
  # make a plot
85
- @output
86
- @render.plot
87
- def plot():
88
- fig, ax = plt.subplots()
89
- ax.plot(C[0, 0, 0])
90
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
 
 
 
92
 
 
 
93
 
94
  # Combine into a shiny app.
95
  # Note that the variable must be "app".
 
1
+ from shiny import App, ui, render, reactive
2
+ from shiny.ui import HTML, tags
3
  import shinyswatch
4
 
5
  import torch
 
13
  from tqdm import tqdm
14
  import itertools as it
15
  from torch import nn
16
+ import io
17
 
18
+ from utils import load_training_data, load_models
 
 
19
 
20
  MONTHS= {
21
  0: "Jan",
 
61
  )
62
  _, _, YRAW, MRAW = load_training_data(path="data/training_data.pkl")
63
 
64
+ prefix = "h"
65
+ dirs = {
66
+ "r1": f"./data/weights/{prefix}1_w2vec",
67
+ "r3": f"./data/weights/{prefix}3_w2vec",
68
+ "r5": f"./data/weights/{prefix}5_w2vec",
69
+ "r7": f"./data/weights/{prefix}7_w2vec",
70
+ "r9": f"./data/weights/{prefix}9_w2vec",
71
+ }
72
+ MODELS = load_models(dirs, prefix=prefix, nd=5)
73
+
74
+
75
+ multicol_html = tags.head(
76
+ tags.style(
77
+ HTML(
78
+ ".multicol {"
79
+ # "height: 150px; "
80
+ "-webkit-column-count: 3;" # chrome, safari, opera
81
+ "-moz-column-count: 3;" # firefox
82
+ "column-count: 3;"
83
+ "-moz-column-fill: auto;"
84
+ "-column-fill: auto;"
85
+ )
86
+ )
87
+ )
88
+
89
+ instructions = f"""
90
+ ### Instructions
91
+
92
+ Upload a CSV file with columns (id, lat, lon) using the `Browse` button on the sidebar.
93
+ Below is an example of the contents of the file:
94
+
95
+
96
+ ```
97
+ id,lat,lon
98
+ 0,47.5,-122.5
99
+ 1,47.5,-122.25
100
+ 2,47.5,-122.0
101
+ 3,47.5,-121.75
102
+ 4,47.5,-121.5
103
+ ```
104
+
105
+
106
+ The id column can be any identifier, or the column can be ommited, in which case the row number will be used as the id.
107
+ Make sure that the latitude is before the longitude column in the CSV file. The valid range for latitude is
108
+ {YMIN} to {YMAX} and longitude is {XMIN} to {XMAX}, which cover the contiguous United States.
109
+
110
+ The resolution corresponds to how much neighboring information is captured by the embedding. If `local` is selected,
111
+ the original weather covariates will be returned. Currently, all the embeddings correspond to the variables:
112
+ air temperature (2m), precipitation, relative humidity (2m), vertical wind speed (10m), and horizontal wind speed (10m).
113
+ The native resolution of the covariates is ~32 km for a grid size of 128 x 256.
114
+
115
+ ### Results
116
+
117
+ """
118
+
119
+ # After uploading the file, the app will generate a CSV, a download link will appear here.
120
+ # The CSV will contain the following columns:
121
+
122
 
123
  # Part 1: ui ----
124
  app_ui = ui.page_fluid(
125
  shinyswatch.theme.minty(),
126
+ multicol_html,
127
+ ui.panel_title("Welcome to the Weather2vec Embedding Generator!"),
128
+ ui.layout_sidebar(
129
+ ui.panel_sidebar(
130
+ ui.input_file("df", "Upload CSV File", accept=".csv"),
131
+ tags.div(
132
+ ui.input_checkbox_group("months", HTML("<b>Months</b>"), MONTHS),
133
+ class_="multicol",
134
+ align="left",
135
+ inline=False
136
+ ),
137
+ HTML("<b>Note:</b> Embedding of multiple months will be added.<br>True multi-temporal embeddings will be supported in the future.<br><br>"),
138
+ tags.div(
139
+ ui.input_radio_buttons("years", HTML("<b>Year</b>"), YEARS),
140
+ class_="multicol",
141
+ align="left",
142
+ inline=False
143
+ ),
144
+ HTML("<br>"),
145
+ tags.div(
146
+ ui.input_radio_buttons("resolutions", HTML("<b>Resolution</b>"), RESOLUTIONS),
147
+ class_="multicol",
148
+ align="left",
149
+ inline=False
150
+ ),
151
+ width=3,
152
+ ),
153
+ ui.panel_main(
154
+ ui.markdown(instructions),
155
+ ui.download_button("download", "Download Embeddings"),
156
+ ),
157
+ )
158
  )
159
 
160
 
161
  # Part 2: server ----
162
  def server(input, output, session):
163
  # make a plot
164
+ @session.download(filename="embeddings.csv")
165
+ def download():
166
+ # read input file
167
+ print(input.df()[-1].keys())
168
+ fname = input.df()[-1]['datapath']
169
+ df = pd.read_csv(fname)
170
+
171
+ # dfcols = []
172
+ # for k, v in D.items():
173
+ # mod = v["mod"]
174
+ # mod.load_state_dict(torch.load(os.path.join(k, "model.pt")))
175
+ # with torch.no_grad():
176
+ # Z = mod["enc"](Ct)
177
+ # Z = Z.mean(0).cpu().numpy()
178
+ # Zmat = Z[:, row, col].T
179
+ # colnames = [f"C{i:02d}" for i in range(Zmat.shape[-1])]
180
+ # Z = pd.DataFrame(Zmat, columns=colnames)
181
+ # Z = pd.DataFrame(Zmat, columns=[x + f"_{len(dfcols)}" for x in colnames])
182
+ # dfcols.append(Z)
183
+ # Z = pd.concat([locs, Z], axis=1)
184
+ # Z.to_csv(f"{savedir}/kms_{32 * int(v['radius']):03d}.csv", index=False)
185
+
186
+ # Cloc = Corig[ix].mean(0)[:, row, col].T
187
+ # Z = pd.DataFrame(Cloc, columns=colnames[:Cloc.shape[1]])
188
+ # Z = pd.concat([locs, Z], axis=1)
189
+ # Z.to_csv(f"{savedir}/kms_000.csv", index=False)
190
 
191
+ with io.BytesIO() as f:
192
+ df.to_csv(f, index=False)
193
+ yield f.getvalue()
194
 
195
+ # # dfcols = pd.concat(dfcols, axis=1)
196
+ # # dfcols = pd.concat([Z, dfcols], axis=1)
197
 
198
  # Combine into a shiny app.
199
  # Note that the variable must be "app".
data/test-data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:053c897211e68701fc2930f66e30c30d319c9858ae8f89a5293999a86da1c1c8
3
+ size 35225
data/weights/h1_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1350b39d48c739af0bb100b777c5ecab4e2ba5d2da9ec0020988d265aee9fecd
3
+ size 320
data/weights/h1_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a9a3f67afdfe9ed2f2f08b6e71872c215865b4c61405115a79c7e26c9688111
3
+ size 492319
data/weights/h1_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6898b26aafd53d108e078214fed89eae06caa5294bbda98b4b931d5d8cdffafc
3
+ size 123
data/weights/h3_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6852ae1a0b4514ee2391bb10e4e57563faeb252b447c672390dccb0551f40167
3
+ size 320
data/weights/h3_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a81e01e7ce2048ebb19f9e8446cf626aefacf334539687ff8f488ca9483e98e
3
+ size 492319
data/weights/h3_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24991bdccad89102340b4f77f062dd1db046055efacdba6cecc507c2612525f4
3
+ size 22518
data/weights/h5_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea17a68d758e186172b873db76bf18c27eccbb32fb60b21b112603b75cd9cca0
3
+ size 320
data/weights/h5_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88611910163b4a9ddb7324a688c6b964321d821629250d38ef12624f4204783e
3
+ size 492319
data/weights/h5_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:525e8d69818e41e807220e6eead955704a31353dc6153ed8b964eab4e7e51f5b
3
+ size 22519
data/weights/h7_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c92d6f6fdd8f0e4fa19bfb7257e2fd3b3319bd418895d79dbfd50ff1365d13f1
3
+ size 320
data/weights/h7_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:150fdad043c8a9bf3c4548b97b019216c02dd241f4a267d544af1f6464b0856a
3
+ size 492319
data/weights/h7_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37f45005ff95ea52753df4b6771fdaae9c1320d84817b89e4d835d5f1ea7a70e
3
+ size 22520
data/weights/h9_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24098fc0d0c99e75b09345d27af6793b438383010d8dc56fabea5857769f8cea
3
+ size 320
data/weights/h9_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e39a8c1a079f68c7bb4ef90aa7cdd5b797474c582670f730d4d371e389c0ebf
3
+ size 492319
data/weights/h9_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6037749e8582a3954d0701ee8212497f8e2c75ab6764ef9f2e20614a306d2bcd
3
+ size 22522
data/weights/r1_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f438e42212b15d6eff79ef4fef2fa4e8a7af19567b24dda0ff7189dccddb4f83
3
+ size 298
data/weights/r1_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77c8865887c1eec2a710c7aa8f54977564b0e76276e427497b64d6122926ced9
3
+ size 4866847
data/weights/r1_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7da97ac5558d1c2d220fed659d64b6e8c83041ee53e129802a1aa17ba8dd6a4d
3
+ size 22003
data/weights/r3_local/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb68f0f9f23c5b0e9465357a032516f3d82da4d1e654db9e7a69fe311c802dd9
3
+ size 320
data/weights/r3_local/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63e59b43d2633c39cb0b0906d052291b918a6c57d879452c956e1b13cd50c824
3
+ size 35323
data/weights/r3_local/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfe90fc358e1068101c5a089be7c7ea6c7716e1fdcaefa221906185244f9b87d
3
+ size 21634
data/weights/r3_nbrs/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b121a5683b604ed17120febdf94f66f628fcf5e1df818ef3a12f5712bb402751
3
+ size 320
data/weights/r3_nbrs/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:015b4cbe386bd17782869275d5ff1db5ecc9b45a685600a81e617bd9795861ce
3
+ size 35323
data/weights/r3_nbrs/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24e28b6418688b93c2e10854ee2f95d66292efd99a19a01c0b2ab0f16d22d43b
3
+ size 21633
data/weights/r3_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d3471f6a8bbb72c9eac7f674601969edf9ae5830bf75556604215ad11f8905a
3
+ size 298
data/weights/r3_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dd70f70e737d7ad349f0cde7661d509bd593f146a6de08710ef585376d37f13
3
+ size 4866847
data/weights/r3_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b73925d08e1f71f7333765fa5109b77e41e31f6ac2d35ff0b4351c83145bc321
3
+ size 22139
data/weights/r5_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3232f7d47d486442c966a5ede893ce491146c91c12183e41ad1502aaa794f94e
3
+ size 321
data/weights/r5_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:232950462057e8844d6cd2bba6ac9f1d875309d7d17277aba3645174f0768cd3
3
+ size 4866847
data/weights/r5_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:695093769fefcccc7a42b38ef0c1b6a032945aafbd50e4809115cd55d4195ac7
3
+ size 198
data/weights/r7_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1e0224de6018d8ddb966f853fc8c08e5d8eef5b7057b7542d82a9ae4befca40
3
+ size 298
data/weights/r7_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14fbfaeb91e6066f59c894638fe7ba5ca9d019298d5ba80e29d7584dab0afee6
3
+ size 4866847
data/weights/r7_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f731ec381f0dea7d7fe99936aad0dceb4a442af251093c78116c2889ae1db64
3
+ size 22245
data/weights/r9_local/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f7b79f12cbbf692de41095bff5cfeab906e97efe18343c74125e91748812ca3
3
+ size 320
data/weights/r9_local/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bf0fa150c7cb85123497d5cf784386f7351665b196c635b26856ae044d52151
3
+ size 35323
data/weights/r9_local/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:973371a1053fd9a84b93188111cb732ae6610abaf3dae65aaac31b27e96db7b1
3
+ size 21912
data/weights/r9_nbrs/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66acb2a59e43e975f09dd0acfaebf214ef822dfb4c2e04df71946988f0a1ba63
3
+ size 320
data/weights/r9_nbrs/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a515d0cdc564b3db7c2357ea8a0b1495fd482dd83d4ebd407616c4a65f9d8cbf
3
+ size 35323
data/weights/r9_nbrs/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fb3010bc18246eb2eb81e9e670583a507d4497601e2939bf94825e868a0484c
3
+ size 21912
data/weights/r9_w2vec/args.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:373b05bda0a2c96a04500f850a143a6bc164c16744f6c1efad3bbfe0fc2a3261
3
+ size 298
data/weights/r9_w2vec/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3a53428d08e1f68cffdb78f6499d8fae7d42f7878f65d4e56ce6f0eeb37fa3a
3
+ size 4866847
data/weights/r9_w2vec/train.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf2141e8f42e65470cd391b93f3dfa9ce8d82b6f30cb27eb4b8cbb60d3744e83
3
+ size 22279
models.py CHANGED
@@ -3,6 +3,7 @@ from torch import nn
3
  import torch.nn.functional as F
4
  from torch import Tensor
5
  from typing import Optional, List
 
6
 
7
 
8
  class LayerNorm(nn.Module):
@@ -303,6 +304,11 @@ class UNetEncoder(nn.Module):
303
  x = self.final(x)
304
 
305
  return x
 
 
 
 
 
306
 
307
 
308
  class Decoder(nn.Module):
 
3
  import torch.nn.functional as F
4
  from torch import Tensor
5
  from typing import Optional, List
6
+ from timm.models.layers import trunc_normal_
7
 
8
 
9
  class LayerNorm(nn.Module):
 
304
  x = self.final(x)
305
 
306
  return x
307
+
308
+ def _init_weights(self, m):
309
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
310
+ trunc_normal_(m.weight, std=.02)
311
+ nn.init.constant_(m.bias, 0)
312
 
313
 
314
  class Decoder(nn.Module):
requirements.txt CHANGED
@@ -6,3 +6,4 @@ torch
6
  numpy
7
  shiny
8
  shinyswatch
 
 
6
  numpy
7
  shiny
8
  shinyswatch
9
+ timm
utils.py CHANGED
@@ -1,5 +1,11 @@
 
1
  import numpy as np
2
  import pickle
 
 
 
 
 
3
 
4
 
5
  def load_training_data(
@@ -50,3 +56,55 @@ def load_training_data(
50
  return C, names, Y, M
51
  else:
52
  return C, names, Y, M, data["pp_locs"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
  import numpy as np
3
  import pickle
4
+ import os
5
+ import yaml
6
+ import torch
7
+ import torch.nn as nn
8
+ from models import UNetEncoder, Decoder
9
 
10
 
11
  def load_training_data(
 
56
  return C, names, Y, M
57
  else:
58
  return C, names, Y, M, data["pp_locs"]
59
+
60
+
61
+ def radius_from_dir(s: str, prefix: str):
62
+ return int(s.split("/")[-1].split("_")[0].replace(prefix, ""))
63
+
64
+
65
+ def load_models(dirs: dict, prefix="h", nd=5):
66
+ D = {}
67
+ for name, datadir in dirs.items():
68
+ radius = radius_from_dir(datadir, prefix)
69
+ args = argparse.Namespace()
70
+ with open(os.path.join(datadir, "args.yaml"), "r") as io:
71
+ for k, v in yaml.load(io, Loader=yaml.FullLoader).items():
72
+ setattr(args, k, v)
73
+ if k == "nbrs_av":
74
+ setattr(args, "av_nbrs", v)
75
+ elif k == "av_nbrs":
76
+ setattr(args, "nbrs_av", v)
77
+
78
+ bn_type ="frn" if not hasattr(args, "bn_type") else args.bn_type
79
+ mkw = dict(
80
+ n_hidden=args.nhidden,
81
+ depth=args.depth,
82
+ num_res=args.nres,
83
+ ksize=args.ksize,
84
+ groups=args.groups,
85
+ batchnorm=True,
86
+ batchnorm_type=bn_type,
87
+ )
88
+
89
+ dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)
90
+ dev = "cuda" if torch.cuda.is_available() else "cpu"
91
+ if not args.local and args.nbrs_av == 0:
92
+ enc = UNetEncoder(nd, args.nhidden, **mkw).to(dev)
93
+ dec = Decoder(args.nhidden, nd, args.nhidden, **dkw).to(dev)
94
+ else:
95
+ enc = nn.Identity()
96
+ dec = Decoder(nd, nd, args.nhidden, **dkw).to(dev)
97
+ mod = nn.ModuleDict({"enc": enc, "dec": dec})
98
+ objs = dict(
99
+ mod=mod,
100
+ args=args,
101
+ radius=radius,
102
+ nbrs_av=args.nbrs_av,
103
+ local=args.local,
104
+ )
105
+ mod.eval()
106
+ for p in mod.parameters():
107
+ p.requires_grad = False
108
+ mod = mod.to(dev)
109
+ D[datadir] = objs
110
+ return D