ar
Browse files
app.py
CHANGED
|
@@ -121,13 +121,13 @@ state_dict = torch.load(checkpoint_path, map_location=device)
|
|
| 121 |
nnet_1.load_state_dict(state_dict)
|
| 122 |
nnet_1.eval()
|
| 123 |
|
| 124 |
-
filename = "pretrained_models/t2i_512px_clip_dimr.pth"
|
| 125 |
-
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 126 |
-
nnet_2 = utils.get_nnet(**config_2.nnet)
|
| 127 |
-
nnet_2 = nnet_2.to(device)
|
| 128 |
-
state_dict = torch.load(checkpoint_path, map_location=device)
|
| 129 |
-
nnet_2.load_state_dict(state_dict)
|
| 130 |
-
nnet_2.eval()
|
| 131 |
|
| 132 |
# Initialize text model.
|
| 133 |
llm = "clip"
|
|
@@ -181,10 +181,11 @@ def infer(
|
|
| 181 |
else:
|
| 182 |
assert num_of_interpolation == 3, "For arithmetic, please sample three images."
|
| 183 |
|
| 184 |
-
if num_of_interpolation == 3:
|
| 185 |
-
|
| 186 |
-
else:
|
| 187 |
-
|
|
|
|
| 188 |
|
| 189 |
# Get text embeddings and tokens.
|
| 190 |
_context, _token_mask, _token, _caption = get_caption(
|
|
@@ -301,7 +302,7 @@ examples_1 = [
|
|
| 301 |
]
|
| 302 |
|
| 303 |
examples_2 = [
|
| 304 |
-
["A
|
| 305 |
]
|
| 306 |
|
| 307 |
css = """
|
|
|
|
| 121 |
nnet_1.load_state_dict(state_dict)
|
| 122 |
nnet_1.eval()
|
| 123 |
|
| 124 |
+
# filename = "pretrained_models/t2i_512px_clip_dimr.pth"
|
| 125 |
+
# checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 126 |
+
# nnet_2 = utils.get_nnet(**config_2.nnet)
|
| 127 |
+
# nnet_2 = nnet_2.to(device)
|
| 128 |
+
# state_dict = torch.load(checkpoint_path, map_location=device)
|
| 129 |
+
# nnet_2.load_state_dict(state_dict)
|
| 130 |
+
# nnet_2.eval()
|
| 131 |
|
| 132 |
# Initialize text model.
|
| 133 |
llm = "clip"
|
|
|
|
| 181 |
else:
|
| 182 |
assert num_of_interpolation == 3, "For arithmetic, please sample three images."
|
| 183 |
|
| 184 |
+
# if num_of_interpolation == 3:
|
| 185 |
+
# nnet = nnet_2
|
| 186 |
+
# else:
|
| 187 |
+
# nnet = nnet_1
|
| 188 |
+
nnet = nnet_1
|
| 189 |
|
| 190 |
# Get text embeddings and tokens.
|
| 191 |
_context, _token_mask, _token, _caption = get_caption(
|
|
|
|
| 302 |
]
|
| 303 |
|
| 304 |
examples_2 = [
|
| 305 |
+
["A dog wearing sunglasses", "red hat"],
|
| 306 |
]
|
| 307 |
|
| 308 |
css = """
|