Spaces:
Running
on
Zero
Running
on
Zero
fix eig norm
Browse files
app.py
CHANGED
|
@@ -870,6 +870,7 @@ def ncut_run(
|
|
| 870 |
if not directed:
|
| 871 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
| 872 |
return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
|
|
|
|
| 873 |
|
| 874 |
rgb, _logging_str, eigvecs = compute_ncut(
|
| 875 |
features,
|
|
@@ -893,12 +894,16 @@ def ncut_run(
|
|
| 893 |
|
| 894 |
|
| 895 |
if only_eigvecs:
|
|
|
|
|
|
|
| 896 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
| 897 |
eigvecs = eigvecs.detach().numpy()
|
| 898 |
logging_str += _logging_str
|
| 899 |
return eigvecs, logging_str
|
| 900 |
|
| 901 |
if return_eigvec_and_rgb:
|
|
|
|
|
|
|
| 902 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
| 903 |
eigvecs = eigvecs.detach().numpy()
|
| 904 |
rgb = rgb.cpu().numpy()
|
|
@@ -1249,6 +1254,7 @@ def run_fn(
|
|
| 1249 |
directed=False,
|
| 1250 |
only_eigvecs=False,
|
| 1251 |
return_eigvec_and_rgb=False,
|
|
|
|
| 1252 |
):
|
| 1253 |
# print(node_type2, head_index_text, make_symmetric)
|
| 1254 |
progress=gr.Progress()
|
|
@@ -1391,6 +1397,7 @@ def run_fn(
|
|
| 1391 |
"make_symmetric": make_symmetric,
|
| 1392 |
"only_eigvecs": only_eigvecs,
|
| 1393 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
|
|
|
| 1394 |
}
|
| 1395 |
# print(kwargs)
|
| 1396 |
|
|
@@ -2232,16 +2239,16 @@ with demo:
|
|
| 2232 |
def __run_fn(*args, **kwargs):
|
| 2233 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
| 2234 |
rgb_gallery = to_pil_images(rgb)
|
| 2235 |
-
# normalize the eigvecs
|
| 2236 |
-
eigvecs = torch.tensor(eigvecs)
|
| 2237 |
-
if torch.cuda.is_available():
|
| 2238 |
-
|
| 2239 |
-
eigvecs = F.normalize(eigvecs, p=2, dim=-1)
|
| 2240 |
-
eigvecs = eigvecs.cpu().numpy()
|
| 2241 |
return eigvecs, rgb, rgb_gallery, logging_str
|
| 2242 |
|
| 2243 |
submit_button.click(
|
| 2244 |
-
partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True),
|
| 2245 |
inputs=[
|
| 2246 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 2247 |
positive_prompt, negative_prompt,
|
|
|
|
| 870 |
if not directed:
|
| 871 |
only_eigvecs = kwargs.get("only_eigvecs", False)
|
| 872 |
return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
|
| 873 |
+
normalize_eigvec_return = kwargs.get("normalize_eigvec_return", False)
|
| 874 |
|
| 875 |
rgb, _logging_str, eigvecs = compute_ncut(
|
| 876 |
features,
|
|
|
|
| 894 |
|
| 895 |
|
| 896 |
if only_eigvecs:
|
| 897 |
+
if normalize_eigvec_return:
|
| 898 |
+
eigvecs = F.normalize(eigvecs, dim=-1)
|
| 899 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
| 900 |
eigvecs = eigvecs.detach().numpy()
|
| 901 |
logging_str += _logging_str
|
| 902 |
return eigvecs, logging_str
|
| 903 |
|
| 904 |
if return_eigvec_and_rgb:
|
| 905 |
+
if normalize_eigvec_return:
|
| 906 |
+
eigvecs = F.normalize(eigvecs, dim=-1)
|
| 907 |
eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
|
| 908 |
eigvecs = eigvecs.detach().numpy()
|
| 909 |
rgb = rgb.cpu().numpy()
|
|
|
|
| 1254 |
directed=False,
|
| 1255 |
only_eigvecs=False,
|
| 1256 |
return_eigvec_and_rgb=False,
|
| 1257 |
+
normalize_eigvec_return=False,
|
| 1258 |
):
|
| 1259 |
# print(node_type2, head_index_text, make_symmetric)
|
| 1260 |
progress=gr.Progress()
|
|
|
|
| 1397 |
"make_symmetric": make_symmetric,
|
| 1398 |
"only_eigvecs": only_eigvecs,
|
| 1399 |
"return_eigvec_and_rgb": return_eigvec_and_rgb,
|
| 1400 |
+
"normalize_eigvec_return": normalize_eigvec_return,
|
| 1401 |
}
|
| 1402 |
# print(kwargs)
|
| 1403 |
|
|
|
|
| 2239 |
def __run_fn(*args, **kwargs):
|
| 2240 |
eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
|
| 2241 |
rgb_gallery = to_pil_images(rgb)
|
| 2242 |
+
# # normalize the eigvecs
|
| 2243 |
+
# eigvecs = torch.tensor(eigvecs)
|
| 2244 |
+
# if torch.cuda.is_available():
|
| 2245 |
+
# eigvecs = eigvecs.cuda()
|
| 2246 |
+
# eigvecs = F.normalize(eigvecs, p=2, dim=-1)
|
| 2247 |
+
# eigvecs = eigvecs.cpu().numpy()
|
| 2248 |
return eigvecs, rgb, rgb_gallery, logging_str
|
| 2249 |
|
| 2250 |
submit_button.click(
|
| 2251 |
+
partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True, normalize_eigvec_return=True),
|
| 2252 |
inputs=[
|
| 2253 |
input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
|
| 2254 |
positive_prompt, negative_prompt,
|