Spaces:
Paused
Paused
HanningChen
commited on
Commit
·
e4172fe
1
Parent(s):
1a6c72f
Fix cross attention
Browse files- webui/runner.py +1 -10
webui/runner.py
CHANGED
|
@@ -317,14 +317,6 @@ class ModelRunner:
|
|
| 317 |
if score_function != "HDC" and ckpt_is_hdc:
|
| 318 |
raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")
|
| 319 |
|
| 320 |
-
# Validate cross_attention against checkpoint (if we can infer it)
|
| 321 |
-
# If your checkpoints don't contain cross-attn keys, ckpt_has_cross may be False even when the arch uses cross-attn.
|
| 322 |
-
# In that case, either update inference or remove this validation.
|
| 323 |
-
if bool(cross_attention) != bool(ckpt_has_cross):
|
| 324 |
-
raise RuntimeError(
|
| 325 |
-
f"cross_attention mismatch: runtime={cross_attention} but checkpoint has_cross_attention={ckpt_has_cross}. ckpt={ckpt_abs}"
|
| 326 |
-
)
|
| 327 |
-
|
| 328 |
# Validate d_model against checkpoint (if inferred)
|
| 329 |
if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
|
| 330 |
raise RuntimeError(
|
|
@@ -421,8 +413,7 @@ class ModelRunner:
|
|
| 421 |
# - default => cross_attention True
|
| 422 |
# - HDC => cross_attention False
|
| 423 |
# If your actual training differs, change this rule OR pass it from app.py.
|
| 424 |
-
|
| 425 |
-
cross_attention = True
|
| 426 |
|
| 427 |
with self._lock:
|
| 428 |
img = Image.open(image_path).convert("RGB")
|
|
|
|
| 317 |
if score_function != "HDC" and ckpt_is_hdc:
|
| 318 |
raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")
|
| 319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
# Validate d_model against checkpoint (if inferred)
|
| 321 |
if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
|
| 322 |
raise RuntimeError(
|
|
|
|
| 413 |
# - default => cross_attention True
|
| 414 |
# - HDC => cross_attention False
|
| 415 |
# If your actual training differs, change this rule OR pass it from app.py.
|
| 416 |
+
cross_attention = (score_function != "HDC")
|
|
|
|
| 417 |
|
| 418 |
with self._lock:
|
| 419 |
img = Image.open(image_path).convert("RGB")
|