HanningChen commited on
Commit
e4172fe
·
1 Parent(s): 1a6c72f

Fix cross attention

Browse files
Files changed (1) hide show
  1. webui/runner.py +1 -10
webui/runner.py CHANGED
@@ -317,14 +317,6 @@ class ModelRunner:
317
  if score_function != "HDC" and ckpt_is_hdc:
318
  raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")
319
 
320
- # Validate cross_attention against checkpoint (if we can infer it)
321
- # If your checkpoints don't contain cross-attn keys, ckpt_has_cross may be False even when the arch uses cross-attn.
322
- # In that case, either update inference or remove this validation.
323
- if bool(cross_attention) != bool(ckpt_has_cross):
324
- raise RuntimeError(
325
- f"cross_attention mismatch: runtime={cross_attention} but checkpoint has_cross_attention={ckpt_has_cross}. ckpt={ckpt_abs}"
326
- )
327
-
328
  # Validate d_model against checkpoint (if inferred)
329
  if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
330
  raise RuntimeError(
@@ -421,8 +413,7 @@ class ModelRunner:
421
  # - default => cross_attention True
422
  # - HDC => cross_attention False
423
  # If your actual training differs, change this rule OR pass it from app.py.
424
- # cross_attention = (score_function != "HDC")
425
- cross_attention = True
426
 
427
  with self._lock:
428
  img = Image.open(image_path).convert("RGB")
 
317
  if score_function != "HDC" and ckpt_is_hdc:
318
  raise RuntimeError(f"Checkpoint IS HDC but score_function=default was selected. ckpt={ckpt_abs}")
319
 
 
 
 
 
 
 
 
 
320
  # Validate d_model against checkpoint (if inferred)
321
  if ckpt_d_model != -1 and int(d_model) != int(ckpt_d_model):
322
  raise RuntimeError(
 
413
  # - default => cross_attention True
414
  # - HDC => cross_attention False
415
  # If your actual training differs, change this rule OR pass it from app.py.
416
+ cross_attention = (score_function != "HDC")
 
417
 
418
  with self._lock:
419
  img = Image.open(image_path).convert("RGB")