Spaces:
Runtime error
Runtime error
Commit
Β·
d872920
1
Parent(s):
d6b8e04
Add model weights as submodule
Browse files- .gitmodules +3 -0
- galaxy-zoo-generation +1 -0
- src/app/compare_models.py +1 -16
- src/app/explore_biggan.py +1 -10
- src/app/explore_cvae.py +1 -9
- src/app/explore_infoscc_gan.py +1 -8
- src/app/interpolate_labels.py +1 -18
- src/app/params.py +4 -11
.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "galaxy-zoo-generation"]
|
| 2 |
+
path = galaxy-zoo-generation
|
| 3 |
+
url = https://huggingface.co/vitaliykinakh/galaxy-zoo-generation
|
galaxy-zoo-generation
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit af017a826b2be3dec5f364ac4e232ede6cc0e04f
|
src/app/compare_models.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
import math
|
| 3 |
|
| 4 |
import streamlit as st
|
|
@@ -14,7 +13,7 @@ from src.models import ConditionalGenerator as InfoSCC_GAN
|
|
| 14 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
| 15 |
from src.models import ConditionalDecoder as cVAE
|
| 16 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 17 |
-
from src.utils import
|
| 18 |
|
| 19 |
|
| 20 |
device = params.device
|
|
@@ -76,25 +75,14 @@ def load_model(model_type: str):
|
|
| 76 |
y_size=params.shape_label,
|
| 77 |
z_size=params.noise_dim)
|
| 78 |
|
| 79 |
-
if not Path(params.path_infoscc_gan).exists():
|
| 80 |
-
download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
|
| 81 |
-
|
| 82 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
| 83 |
g.load_state_dict(ckpt['g_ema'])
|
| 84 |
elif model_type == 'BigGAN':
|
| 85 |
g = BigGAN2Generator()
|
| 86 |
-
|
| 87 |
-
if not Path(params.path_biggan).exists():
|
| 88 |
-
download_file(params.drive_id_biggan, params.path_biggan)
|
| 89 |
-
|
| 90 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
| 91 |
g.load_state_dict(ckpt)
|
| 92 |
elif model_type == 'cVAE':
|
| 93 |
g = cVAE()
|
| 94 |
-
|
| 95 |
-
if not Path(params.path_cvae).exists():
|
| 96 |
-
download_file(params.drive_id_cvae, params.path_cvae)
|
| 97 |
-
|
| 98 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
| 99 |
g.load_state_dict(ckpt)
|
| 100 |
else:
|
|
@@ -107,9 +95,6 @@ def load_model(model_type: str):
|
|
| 107 |
def get_labels() -> torch.Tensor:
|
| 108 |
path_labels = params.path_labels
|
| 109 |
|
| 110 |
-
if not Path(path_labels).exists():
|
| 111 |
-
download_file(params.drive_id_labels, path_labels)
|
| 112 |
-
|
| 113 |
labels_train = get_labels_train(path_labels)
|
| 114 |
return labels_train
|
| 115 |
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
|
| 3 |
import streamlit as st
|
|
|
|
| 13 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
| 14 |
from src.models import ConditionalDecoder as cVAE
|
| 15 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 16 |
+
from src.utils import sample_labels
|
| 17 |
|
| 18 |
|
| 19 |
device = params.device
|
|
|
|
| 75 |
y_size=params.shape_label,
|
| 76 |
z_size=params.noise_dim)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
| 79 |
g.load_state_dict(ckpt['g_ema'])
|
| 80 |
elif model_type == 'BigGAN':
|
| 81 |
g = BigGAN2Generator()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
| 83 |
g.load_state_dict(ckpt)
|
| 84 |
elif model_type == 'cVAE':
|
| 85 |
g = cVAE()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
| 87 |
g.load_state_dict(ckpt)
|
| 88 |
else:
|
|
|
|
| 95 |
def get_labels() -> torch.Tensor:
|
| 96 |
path_labels = params.path_labels
|
| 97 |
|
|
|
|
|
|
|
|
|
|
| 98 |
labels_train = get_labels_train(path_labels)
|
| 99 |
return labels_train
|
| 100 |
|
src/app/explore_biggan.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import math
|
| 2 |
-
from pathlib import Path
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
import numpy as np
|
|
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
|
|
| 12 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
| 13 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
| 14 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 15 |
-
from src.utils import
|
| 16 |
|
| 17 |
|
| 18 |
# global parameters
|
|
@@ -25,7 +24,6 @@ dim_z = params.dim_z
|
|
| 25 |
bs = 16 # number of samples to generate
|
| 26 |
n_cols = int(math.sqrt(bs))
|
| 27 |
model_path = params.path_biggan
|
| 28 |
-
drive_id = params.drive_id_biggan
|
| 29 |
path_labels = params.path_labels
|
| 30 |
|
| 31 |
# manual labels
|
|
@@ -90,9 +88,6 @@ def get_eps(n: int) -> torch.Tensor:
|
|
| 90 |
|
| 91 |
@st.cache
|
| 92 |
def get_labels() -> torch.Tensor:
|
| 93 |
-
if not Path(path_labels).exists():
|
| 94 |
-
download_file(params.drive_id_labels, path_labels)
|
| 95 |
-
|
| 96 |
labels_train = get_labels_train(path_labels)
|
| 97 |
return labels_train
|
| 98 |
|
|
@@ -102,10 +97,6 @@ def app():
|
|
| 102 |
|
| 103 |
st.title('Explore BigGAN')
|
| 104 |
st.markdown('This demo shows BigGAN for conditional galaxy generation')
|
| 105 |
-
|
| 106 |
-
if not Path(model_path).exists():
|
| 107 |
-
download_file(drive_id, model_path)
|
| 108 |
-
|
| 109 |
model = load_model(model_path)
|
| 110 |
eps = get_eps(bs)
|
| 111 |
labels_train = get_labels()
|
|
|
|
| 1 |
import math
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
|
|
| 11 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
| 12 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
| 13 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 14 |
+
from src.utils import sample_labels
|
| 15 |
|
| 16 |
|
| 17 |
# global parameters
|
|
|
|
| 24 |
bs = 16 # number of samples to generate
|
| 25 |
n_cols = int(math.sqrt(bs))
|
| 26 |
model_path = params.path_biggan
|
|
|
|
| 27 |
path_labels = params.path_labels
|
| 28 |
|
| 29 |
# manual labels
|
|
|
|
| 88 |
|
| 89 |
@st.cache
|
| 90 |
def get_labels() -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
| 91 |
labels_train = get_labels_train(path_labels)
|
| 92 |
return labels_train
|
| 93 |
|
|
|
|
| 97 |
|
| 98 |
st.title('Explore BigGAN')
|
| 99 |
st.markdown('This demo shows BigGAN for conditional galaxy generation')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
model = load_model(model_path)
|
| 101 |
eps = get_eps(bs)
|
| 102 |
labels_train = get_labels()
|
src/app/explore_cvae.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import math
|
| 2 |
-
from pathlib import Path
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
import numpy as np
|
|
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
|
|
| 12 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
| 13 |
from src.models import ConditionalDecoder
|
| 14 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 15 |
-
from src.utils import
|
| 16 |
|
| 17 |
|
| 18 |
# global parameters
|
|
@@ -25,7 +24,6 @@ dim_z = params.dim_z
|
|
| 25 |
bs = 16 # number of samples to generate
|
| 26 |
n_cols = int(math.sqrt(bs))
|
| 27 |
model_path = params.path_cvae
|
| 28 |
-
drive_id = params.drive_id_cvae
|
| 29 |
path_labels = params.path_labels
|
| 30 |
|
| 31 |
# manual labels
|
|
@@ -90,9 +88,6 @@ def get_eps(n: int) -> torch.Tensor:
|
|
| 90 |
|
| 91 |
@st.cache
|
| 92 |
def get_labels() -> torch.Tensor:
|
| 93 |
-
if not Path(path_labels).exists():
|
| 94 |
-
download_file(params.drive_id_labels, path_labels)
|
| 95 |
-
|
| 96 |
labels_train = get_labels_train(path_labels)
|
| 97 |
return labels_train
|
| 98 |
|
|
@@ -103,9 +98,6 @@ def app():
|
|
| 103 |
st.title('Explore cVAE')
|
| 104 |
st.markdown('This demo shows cVAE for conditional galaxy generation')
|
| 105 |
|
| 106 |
-
if not Path(model_path).exists():
|
| 107 |
-
download_file(drive_id, model_path)
|
| 108 |
-
|
| 109 |
model = load_model(model_path)
|
| 110 |
eps = get_eps(bs)
|
| 111 |
labels_train = get_labels()
|
|
|
|
| 1 |
import math
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
|
|
| 11 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
| 12 |
from src.models import ConditionalDecoder
|
| 13 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 14 |
+
from src.utils import sample_labels
|
| 15 |
|
| 16 |
|
| 17 |
# global parameters
|
|
|
|
| 24 |
bs = 16 # number of samples to generate
|
| 25 |
n_cols = int(math.sqrt(bs))
|
| 26 |
model_path = params.path_cvae
|
|
|
|
| 27 |
path_labels = params.path_labels
|
| 28 |
|
| 29 |
# manual labels
|
|
|
|
| 88 |
|
| 89 |
@st.cache
|
| 90 |
def get_labels() -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
| 91 |
labels_train = get_labels_train(path_labels)
|
| 92 |
return labels_train
|
| 93 |
|
|
|
|
| 98 |
st.title('Explore cVAE')
|
| 99 |
st.markdown('This demo shows cVAE for conditional galaxy generation')
|
| 100 |
|
|
|
|
|
|
|
|
|
|
| 101 |
model = load_model(model_path)
|
| 102 |
eps = get_eps(bs)
|
| 103 |
labels_train = get_labels()
|
src/app/explore_infoscc_gan.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
import math
|
| 3 |
|
| 4 |
import numpy as np
|
|
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
|
|
| 12 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
| 13 |
from src.models import ConditionalGenerator
|
| 14 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 15 |
-
from src.utils import
|
| 16 |
|
| 17 |
# global parameters
|
| 18 |
device = params.device
|
|
@@ -27,7 +26,6 @@ y_type = params.y_type
|
|
| 27 |
bs = 16 # number of samples to generate
|
| 28 |
n_cols = int(math.sqrt(bs))
|
| 29 |
model_path = params.path_infoscc_gan # path to the model
|
| 30 |
-
drive_id = params.drive_id_infoscc_gan # google drive id of the model
|
| 31 |
path_labels = params.path_labels
|
| 32 |
|
| 33 |
# manual labels
|
|
@@ -87,8 +85,6 @@ def load_model(model_path: str) -> ConditionalGenerator:
|
|
| 87 |
|
| 88 |
@st.cache
|
| 89 |
def get_labels() -> torch.Tensor:
|
| 90 |
-
if not Path(path_labels).exists():
|
| 91 |
-
download_file(params.drive_id_labels, path_labels)
|
| 92 |
labels_train = get_labels_train(path_labels)
|
| 93 |
return labels_train
|
| 94 |
|
|
@@ -100,9 +96,6 @@ def app():
|
|
| 100 |
st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
|
| 101 |
st.subheader(r'<- Use sidebar to explore $z_1, ..., z_k$ latent variables')
|
| 102 |
|
| 103 |
-
if not Path(model_path).exists():
|
| 104 |
-
download_file(drive_id, model_path)
|
| 105 |
-
|
| 106 |
model = load_model(model_path)
|
| 107 |
eps = model.sample_eps(bs).to(device)
|
| 108 |
labels_train = get_labels()
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 11 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
| 12 |
from src.models import ConditionalGenerator
|
| 13 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
| 14 |
+
from src.utils import sample_labels
|
| 15 |
|
| 16 |
# global parameters
|
| 17 |
device = params.device
|
|
|
|
| 26 |
bs = 16 # number of samples to generate
|
| 27 |
n_cols = int(math.sqrt(bs))
|
| 28 |
model_path = params.path_infoscc_gan # path to the model
|
|
|
|
| 29 |
path_labels = params.path_labels
|
| 30 |
|
| 31 |
# manual labels
|
|
|
|
| 85 |
|
| 86 |
@st.cache
|
| 87 |
def get_labels() -> torch.Tensor:
|
|
|
|
|
|
|
| 88 |
labels_train = get_labels_train(path_labels)
|
| 89 |
return labels_train
|
| 90 |
|
|
|
|
| 96 |
st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
|
| 97 |
st.subheader(r'<- Use sidebar to explore $z_1, ..., z_k$ latent variables')
|
| 98 |
|
|
|
|
|
|
|
|
|
|
| 99 |
model = load_model(model_path)
|
| 100 |
eps = model.sample_eps(bs).to(device)
|
| 101 |
labels_train = get_labels()
|
src/app/interpolate_labels.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
import math
|
| 3 |
|
| 4 |
import numpy as np
|
|
@@ -12,7 +11,7 @@ from src.models import ConditionalGenerator as InfoSCC_GAN
|
|
| 12 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
| 13 |
from src.models import ConditionalDecoder as cVAE
|
| 14 |
from src.data import get_labels_train
|
| 15 |
-
from src.utils import
|
| 16 |
|
| 17 |
|
| 18 |
device = params.device
|
|
@@ -31,26 +30,14 @@ def load_model(model_type: str):
|
|
| 31 |
g = InfoSCC_GAN(size=params.size,
|
| 32 |
y_size=params.shape_label,
|
| 33 |
z_size=params.noise_dim)
|
| 34 |
-
|
| 35 |
-
if not Path(params.path_infoscc_gan).exists():
|
| 36 |
-
download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
|
| 37 |
-
|
| 38 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
| 39 |
g.load_state_dict(ckpt['g_ema'])
|
| 40 |
elif model_type == 'BigGAN':
|
| 41 |
g = BigGAN2Generator()
|
| 42 |
-
|
| 43 |
-
if not Path(params.path_biggan).exists():
|
| 44 |
-
download_file(params.drive_id_biggan, params.path_biggan)
|
| 45 |
-
|
| 46 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
| 47 |
g.load_state_dict(ckpt)
|
| 48 |
elif model_type == 'cVAE':
|
| 49 |
g = cVAE()
|
| 50 |
-
|
| 51 |
-
if not Path(params.path_cvae).exists():
|
| 52 |
-
download_file(params.drive_id_cvae, params.path_cvae)
|
| 53 |
-
|
| 54 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
| 55 |
g.load_state_dict(ckpt)
|
| 56 |
else:
|
|
@@ -62,10 +49,6 @@ def load_model(model_type: str):
|
|
| 62 |
@st.cache
|
| 63 |
def get_labels() -> torch.Tensor:
|
| 64 |
path_labels = params.path_labels
|
| 65 |
-
|
| 66 |
-
if not Path(path_labels).exists():
|
| 67 |
-
download_file(params.drive_id_labels, path_labels)
|
| 68 |
-
|
| 69 |
labels_train = get_labels_train(path_labels)
|
| 70 |
return labels_train
|
| 71 |
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 11 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
| 12 |
from src.models import ConditionalDecoder as cVAE
|
| 13 |
from src.data import get_labels_train
|
| 14 |
+
from src.utils import sample_labels
|
| 15 |
|
| 16 |
|
| 17 |
device = params.device
|
|
|
|
| 30 |
g = InfoSCC_GAN(size=params.size,
|
| 31 |
y_size=params.shape_label,
|
| 32 |
z_size=params.noise_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
| 34 |
g.load_state_dict(ckpt['g_ema'])
|
| 35 |
elif model_type == 'BigGAN':
|
| 36 |
g = BigGAN2Generator()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
| 38 |
g.load_state_dict(ckpt)
|
| 39 |
elif model_type == 'cVAE':
|
| 40 |
g = cVAE()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
| 42 |
g.load_state_dict(ckpt)
|
| 43 |
else:
|
|
|
|
| 49 |
@st.cache
|
| 50 |
def get_labels() -> torch.Tensor:
|
| 51 |
path_labels = params.path_labels
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
labels_train = get_labels_train(path_labels)
|
| 53 |
return labels_train
|
| 54 |
|
src/app/params.py
CHANGED
|
@@ -12,14 +12,7 @@ n_basis = 6 # size of additional z vectors in InfoSCC-GAN
|
|
| 12 |
y_type = 'real' # type of labels in InfoSCC-GAN
|
| 13 |
dim_z = 128 # z vector size in BigGAN and cVAE
|
| 14 |
|
| 15 |
-
path_infoscc_gan = './models/InfoSCC-GAN/generator.pt'
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
drive_id_biggan = '1sMSDdnQ5GjHcno5knHTDSKAKhhoHh_4z'
|
| 20 |
-
|
| 21 |
-
path_cvae = './models/CVAE/generator.pth'
|
| 22 |
-
drive_id_cvae = '17FmLvhwXq8PQMrD1CtjqyoAy5BobYMTE'
|
| 23 |
-
|
| 24 |
-
path_labels = './data/training_solutions_rev1.csv'
|
| 25 |
-
drive_id_labels = '1dzsB_HdGtmSHE4pCppamISpFaJBfPF7E'
|
|
|
|
| 12 |
y_type = 'real' # type of labels in InfoSCC-GAN
|
| 13 |
dim_z = 128 # z vector size in BigGAN and cVAE
|
| 14 |
|
| 15 |
+
path_infoscc_gan = './galaxy-zoo-generation/models/InfoSCC-GAN/generator.pt'
|
| 16 |
+
path_biggan = './galaxy-zoo-generation/models/BigGAN/generator.pth'
|
| 17 |
+
path_cvae = './galaxy-zoo-generation/models/CVAE/generator.pth'
|
| 18 |
+
path_labels = './galaxy-zoo-generation/data/training_solutions_rev1.csv'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|