Spaces:
Running
on
Zero
Running
on
Zero
fix recursion gamma
Browse files
app.py
CHANGED
|
@@ -265,6 +265,7 @@ def ncut_run(
|
|
| 265 |
|
| 266 |
if recursion:
|
| 267 |
rgbs = []
|
|
|
|
| 268 |
inp = features
|
| 269 |
for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
|
| 270 |
logging_str += f"Recursion #{i+1}\n"
|
|
@@ -272,7 +273,7 @@ def ncut_run(
|
|
| 272 |
inp,
|
| 273 |
num_eig=n_eigs,
|
| 274 |
num_sample_ncut=num_sample_ncut,
|
| 275 |
-
affinity_focal_gamma=
|
| 276 |
knn_ncut=knn_ncut,
|
| 277 |
knn_tsne=knn_tsne,
|
| 278 |
num_sample_tsne=num_sample_tsne,
|
|
@@ -352,9 +353,14 @@ def ncut_run(
|
|
| 352 |
|
| 353 |
def _ncut_run(*args, **kwargs):
|
| 354 |
try:
|
|
|
|
|
|
|
|
|
|
| 355 |
ret = ncut_run(*args, **kwargs)
|
|
|
|
| 356 |
if torch.cuda.is_available():
|
| 357 |
torch.cuda.empty_cache()
|
|
|
|
| 358 |
return ret
|
| 359 |
except Exception as e:
|
| 360 |
gr.Error(str(e))
|
|
|
|
| 265 |
|
| 266 |
if recursion:
|
| 267 |
rgbs = []
|
| 268 |
+
recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
|
| 269 |
inp = features
|
| 270 |
for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
|
| 271 |
logging_str += f"Recursion #{i+1}\n"
|
|
|
|
| 273 |
inp,
|
| 274 |
num_eig=n_eigs,
|
| 275 |
num_sample_ncut=num_sample_ncut,
|
| 276 |
+
affinity_focal_gamma=recursion_gammas[i],
|
| 277 |
knn_ncut=knn_ncut,
|
| 278 |
knn_tsne=knn_tsne,
|
| 279 |
num_sample_tsne=num_sample_tsne,
|
|
|
|
| 353 |
|
| 354 |
def _ncut_run(*args, **kwargs):
|
| 355 |
try:
|
| 356 |
+
if torch.cuda.is_available():
|
| 357 |
+
torch.cuda.empty_cache()
|
| 358 |
+
|
| 359 |
ret = ncut_run(*args, **kwargs)
|
| 360 |
+
|
| 361 |
if torch.cuda.is_available():
|
| 362 |
torch.cuda.empty_cache()
|
| 363 |
+
|
| 364 |
return ret
|
| 365 |
except Exception as e:
|
| 366 |
gr.Error(str(e))
|