LTX-2.3-FP4 / comfy_bathroom.py
MrReclusive's picture
Upload comfy_bathroom.py
c7655a6 verified
"""
Comfy Bathroom - LoRA Loading Suite for FP4 Quantized Models
A complete LoRA loading system designed for use with FP4ME/FP4MEL quantized LTX-2.3 models.
Author: Super Z
"""
import re
from typing import Dict, List, Optional, Tuple, Any

import torch

import comfy.lora
import comfy.utils
import folder_paths
# =============================================================================
# PRESET CURVES
# =============================================================================
def generate_ramp_up(start_block: int, end_block: int, start_val: float, end_val: float) -> Dict[int, float]:
    """Linearly interpolate weights for every block from ``start_block`` to ``end_block``.

    Both end points are included: ``start_block`` gets ``start_val`` and
    ``end_block`` gets ``end_val``. The same helper is used for falling ramps
    (``end_val < start_val``).

    Returns:
        ``{block_index: weight}``, or an empty dict when ``end_block <= start_block``.
    """
    span = end_block - start_block
    if span <= 0:
        return {}
    delta = end_val - start_val
    return {
        start_block + offset: start_val + delta * (offset / span)
        for offset in range(span + 1)
    }
def get_fp4me_light_weights() -> Dict[int, float]:
    """FP4ME Light preset: per-block LoRA weights for transformer blocks 0-47.

    Block 0 and block 47 stay at full strength and block 1 is switched off.
    Blocks 2-10 ramp linearly from 0.10 up to 1.0, blocks 11-39 are at 1.0,
    and blocks 40-46 taper from 0.95 down to 0.50.
    """
    return {
        0: 1.0,
        1: 0.0,
        **generate_ramp_up(2, 10, 0.10, 1.0),
        **dict.fromkeys(range(11, 40), 1.0),
        **generate_ramp_up(40, 46, 0.95, 0.50),
        47: 1.0,
    }
def get_fp4me_heavy_weights() -> Dict[int, float]:
    """FP4ME Heavy preset: per-block LoRA weights for transformer blocks 0-47.

    Same head as FP4ME Light (block 0 full, block 1 off, blocks 2-10 ramping
    0.10 -> 1.0, blocks 11-39 full), but blocks 40-46 fade all the way from
    1.0 to 0.0, so block 46 is effectively disabled. Block 47 stays at 1.0.
    """
    return {
        0: 1.0,
        1: 0.0,
        **generate_ramp_up(2, 10, 0.10, 1.0),
        **dict.fromkeys(range(11, 40), 1.0),
        **generate_ramp_up(40, 46, 1.0, 0.0),
        47: 1.0,
    }
def get_fp4mel_light_weights() -> Dict[int, float]:
    """FP4MEL Light preset: per-block LoRA weights for transformer blocks 0-47.

    Blocks 0, 1, 46 and 47 are at full strength. Blocks 2-10 ramp from 0.10 up
    to 1.0, blocks 11-40 are at 1.0, and blocks 41-45 taper from 0.95 to 0.60.
    """
    return {
        0: 1.0,
        1: 1.0,
        **generate_ramp_up(2, 10, 0.10, 1.0),
        **dict.fromkeys(range(11, 41), 1.0),
        **generate_ramp_up(41, 45, 0.95, 0.60),
        46: 1.0,
        47: 1.0,
    }
def get_fp4mel_heavy_weights() -> Dict[int, float]:
    """FP4MEL Heavy preset: per-block LoRA weights for transformer blocks 0-47.

    Blocks 0, 1, 46 and 47 are at full strength. Blocks 2-10 ramp from 0.10 up
    to 1.0, blocks 11-39 are at 1.0, and blocks 40-45 fade from 0.95 down to
    0.0 (block 45 is effectively disabled).
    """
    return {
        0: 1.0,
        1: 1.0,
        **generate_ramp_up(2, 10, 0.10, 1.0),
        **dict.fromkeys(range(11, 40), 1.0),
        **generate_ramp_up(40, 45, 0.95, 0.0),
        46: 1.0,
        47: 1.0,
    }
# Finds the transformer block index in both dotted keys
# ("...transformer_blocks.12.attn1...") and kohya-style underscored keys
# ("lora_unet_transformer_blocks_12_attn1...").
_BLOCK_INDEX_RE = re.compile(r"transformer_blocks[._](\d+)[._]")

# Key suffixes whose tensors are NOT multiplied by the block weight.
# A LoRA delta is a product of factors (e.g. up @ down, times alpha / rank).
# Scaling every factor would scale the delta by weight**2 (weight**3 with alpha).
# So only one factor per delta is scaled: the up/B matrix, a ``.diff`` tensor,
# ``hada_w1_a`` (LoHa), or ``lokr_w1``/``lokr_w1_a`` (LoKr).
# The down/A-side factors, alpha and DoRA scale pass through unchanged.
_UNSCALED_KEY_SUFFIXES = (
    ".alpha",
    ".dora_scale",
    ".lora_A.weight",
    ".lora_down.weight",
    "_lora.down.weight",
    ".lora.down.weight",
    ".lora_linear_layer.down.weight",
    ".lora_mid.weight",
    ".hada_w1_b",
    ".hada_w2_a",
    ".hada_w2_b",
    ".hada_t1",
    ".hada_t2",
    ".lokr_w1_b",
    ".lokr_w2",
    ".lokr_w2_a",
    ".lokr_w2_b",
    ".lokr_t2",
)


def apply_block_weights_to_lora(lora_data: dict, block_weights: Dict[int, float]) -> dict:
    """Return a copy of ``lora_data`` with per-transformer-block weights applied.

    For every key containing a transformer block index, the weight is looked up
    in ``block_weights`` (blocks not listed default to 1.0):

    * weight <= 0.0 -> all tensors of that block are dropped, so the block is
      not patched at all;
    * weight == 1.0 -> tensors pass through unchanged;
    * otherwise     -> exactly one factor of each LoRA delta is multiplied by
      the weight (see ``_UNSCALED_KEY_SUFFIXES``), so the block's contribution
      scales linearly with the weight, including weights above 1.0.

    Keys without a block index pass through unchanged. LoRA formats whose
    factors are not listed in ``_UNSCALED_KEY_SUFFIXES`` still have every
    tensor scaled, as before.

    Args:
        lora_data: LoRA state dict (key -> tensor).
        block_weights: Mapping of transformer block index -> weight.

    Returns:
        A new dict; ``lora_data`` itself is not modified.
    """
    filtered = {}
    for key, value in lora_data.items():
        block_match = _BLOCK_INDEX_RE.search(key)
        if block_match is None:
            filtered[key] = value
            continue
        weight = block_weights.get(int(block_match.group(1)), 1.0)
        if not (weight > 0.0):
            # Zero/negative weight disables the block: drop all of its tensors.
            continue
        if weight == 1.0 or key.endswith(_UNSCALED_KEY_SUFFIXES):
            filtered[key] = value
        else:
            filtered[key] = value * weight
    return filtered
# =============================================================================
# TOOTHBRUSH - LoRA Loader
# =============================================================================
class ToothbrushLoRALoader:
    """🪥 Toothbrush - load a LoRA file and bundle it into a LORA_PACKET.

    The packet is a plain dict with keys ``lora_data`` (the loaded state dict),
    ``preset``, ``strength``, ``block_weights`` (``Dict[int, float]`` or ``None``)
    and ``lora_name``; the other bathroom nodes read and copy it.
    """

    PRESET_OPTIONS = ["default", "FP4ME Light", "FP4ME Heavy", "FP4MEL Light", "FP4MEL Heavy", "Custom"]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lora_name": (folder_paths.get_filename_list("loras"),),
                "preset": (s.PRESET_OPTIONS, {"default": "default"}),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05}),
            },
            "optional": {
                "custom_weights": ("LORA_BLOCK_WEIGHTS",),
            }
        }

    RETURN_TYPES = ("LORA_PACKET",)
    RETURN_NAMES = ("lora",)
    FUNCTION = "load_lora"
    CATEGORY = "bathroom"
    DESCRIPTION = "🪥 Toothbrush - LoRA loader with FP4 presets"

    def load_lora(self, lora_name, preset, strength, custom_weights=None):
        """Load ``lora_name`` from the ``loras`` folder and wrap it in a packet.

        Args:
            lora_name: File name as listed by ``folder_paths`` for ``"loras"``.
            preset: One of ``PRESET_OPTIONS``. "default" means no block weights.
                "Custom" uses ``custom_weights``. The FP4 presets build their
                weights from the module's preset functions.
            strength: LoRA strength stored in the packet.
            custom_weights: Optional ``{block_index: weight}`` used only for "Custom".

        Returns:
            ``(packet,)`` - a single LORA_PACKET dict.

        Raises:
            FileNotFoundError: if ``lora_name`` cannot be resolved to a file.
        """
        lora_path = folder_paths.get_full_path("loras", lora_name)
        if lora_path is None:
            raise FileNotFoundError(f"🪥 Toothbrush: LoRA file not found: '{lora_name}'")
        # safe_load=True: never unpickle arbitrary objects from .ckpt/.pt files
        # (matches ComfyUI's built-in LoraLoader).
        lora_data = comfy.utils.load_torch_file(lora_path, safe_load=True)

        # Looked up at call time so the preset functions only need to exist by then.
        preset_builders = {
            "FP4ME Light": get_fp4me_light_weights,
            "FP4ME Heavy": get_fp4me_heavy_weights,
            "FP4MEL Light": get_fp4mel_light_weights,
            "FP4MEL Heavy": get_fp4mel_heavy_weights,
        }
        if preset == "Custom":
            block_weights = custom_weights
            if block_weights is None:
                print("🪥 Toothbrush: preset 'Custom' selected but no custom_weights connected; "
                      "LoRA will be applied without block weights")
        elif preset in preset_builders:
            block_weights = preset_builders[preset]()
        else:
            block_weights = None

        packet = {
            "lora_data": lora_data,
            "preset": preset,
            "strength": strength,
            "block_weights": block_weights,
            "lora_name": lora_name,
        }
        print(f"🪥 Toothbrush: '{lora_name}' | {preset} | {strength:.2f}")
        return (packet,)
# =============================================================================
# MIRROR SIMPLE - Binary On/Off
# =============================================================================
class MirrorSimple:
    """🪞 Mirror (Simple) - switch the LoRA on or off per transformer block (0-47).

    Each ``block_<i>`` BOOLEAN input becomes weight 1.0 (on) or 0.0 (off).
    The resulting ``{block_index: weight}`` dict is always returned. When a
    ``lora_packet`` is connected, a shallow copy of it with those weights is
    also returned; otherwise the packet output is ``None``.
    """

    @classmethod
    def INPUT_TYPES(s):
        toggles = {}
        for index in range(48):
            toggles[f"block_{index}"] = ("BOOLEAN", {"default": True})
        optional = {"lora_packet": ("LORA_PACKET",)}
        return {"required": toggles, "optional": optional}

    RETURN_TYPES = ("LORA_BLOCK_WEIGHTS", "LORA_PACKET")
    RETURN_NAMES = ("block_weights", "lora_out")
    FUNCTION = "configure"
    CATEGORY = "bathroom"
    DESCRIPTION = "🪞 Mirror (Simple) - Per-block on/off"

    def configure(self, lora_packet=None, **kwargs):
        """Turn the ``block_<i>`` toggles into weights and optionally attach them to a packet."""
        block_weights = {}
        for index in range(48):
            enabled = kwargs.get(f"block_{index}", True)
            block_weights[index] = 1.0 if enabled else 0.0
        off_count = sum(1 for weight in block_weights.values() if weight == 0.0)
        print(f"🪞 Mirror (Simple): {48-off_count} ON, {off_count} OFF")
        if not lora_packet:
            return (block_weights, None)
        # Shallow copy: the upstream packet is left untouched.
        out_packet = lora_packet.copy()
        out_packet["block_weights"] = block_weights
        return (block_weights, out_packet)
# =============================================================================
# MIRROR FANCY - Per-Block Strength
# =============================================================================
class MirrorFancy:
    """🪞 Mirror (Fancy) - set a LoRA strength in [0, 1] for each transformer block (0-47).

    The ``block_<i>`` FLOAT inputs are returned as ``{block_index: weight}``.
    When a ``lora_packet`` is connected, a shallow copy of it carrying those
    weights is returned too; otherwise the packet output is ``None``.
    """

    @classmethod
    def INPUT_TYPES(s):
        sliders = {}
        for index in range(48):
            sliders[f"block_{index}"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05})
        optional = {"lora_packet": ("LORA_PACKET",)}
        return {"required": sliders, "optional": optional}

    RETURN_TYPES = ("LORA_BLOCK_WEIGHTS", "LORA_PACKET")
    RETURN_NAMES = ("block_weights", "lora_out")
    FUNCTION = "configure"
    CATEGORY = "bathroom"
    DESCRIPTION = "🪞 Mirror (Fancy) - Per-block strength"

    def configure(self, lora_packet=None, **kwargs):
        """Collect the ``block_<i>`` strengths and optionally attach them to a packet."""
        block_weights = {}
        for index in range(48):
            block_weights[index] = kwargs.get(f"block_{index}", 1.0)
        active = len([weight for weight in block_weights.values() if weight > 0])
        print(f"🪞 Mirror (Fancy): {active} blocks active")
        if not lora_packet:
            return (block_weights, None)
        # Shallow copy: the upstream packet is left untouched.
        out_packet = lora_packet.copy()
        out_packet["block_weights"] = block_weights
        return (block_weights, out_packet)
# =============================================================================
# BATHROOM SINK - LoRA Stacker
# =============================================================================
class BathroomSink:
    """🚰 Bathroom Sink - apply up to eight LoRA packets to a clone of a model.

    Each packet's own strength is multiplied by ``global_strength``. Its block
    weights, if any, are applied with ``apply_block_weights_to_lora`` before
    the LoRA is turned into model patches.
    """

    # (down/A suffix, up/B suffix) low-rank pairs understood by the manual fallback.
    _LOW_RANK_SUFFIXES = (
        (".lora_A.weight", ".lora_B.weight"),
        (".lora_down.weight", ".lora_up.weight"),
    )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "lora_1": ("LORA_PACKET",),
                "global_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05}),
            },
            "optional": {f"lora_{i}": ("LORA_PACKET",) for i in range(2, 9)}
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "apply_loras"
    CATEGORY = "bathroom"
    DESCRIPTION = "🚰 Bathroom Sink - Stack multiple LoRAs"

    def apply_loras(self, model, lora_1, global_strength, **kwargs):
        """Patch a clone of ``model`` with ``lora_1`` and any connected ``lora_2``..``lora_8``.

        Args:
            model: MODEL input; this code uses its ``clone()``, ``.model`` and
                (on the clone) ``add_patches``.
            lora_1: Required LORA_PACKET.
            global_strength: Multiplier applied on top of every packet's strength.
            **kwargs: Optional ``lora_2``..``lora_8`` packets; unset ones are skipped.

        Returns:
            ``(model_out,)`` - the patched clone; ``model`` itself is not modified.

        A LoRA that fails to load is reported and, after trying the manual
        fallback, skipped rather than aborting the whole stack.
        """
        lora_packets = [lora_1]
        for slot in range(2, 9):
            extra = kwargs.get(f"lora_{slot}")
            if extra:
                lora_packets.append(extra)

        print(f"\n{'='*60}")
        print(f"🚰 Bathroom Sink - {len(lora_packets)} LoRAs, global: {global_strength:.2f}")
        print(f"{'='*60}")

        model_out = model.clone()
        # The key map depends only on the model, so build it once. Pass a fresh
        # dict: ComfyUI declares ``key_map={}`` as a default argument and fills it
        # in place, so relying on the default would accumulate keys across calls.
        key_map = comfy.lora.model_lora_keys_unet(model_out.model, {})

        for idx, packet in enumerate(lora_packets):
            lora_data = packet["lora_data"]
            strength = packet["strength"] * global_strength
            block_weights = packet.get("block_weights")
            lora_name = packet.get("lora_name", f"LoRA_{idx+1}")
            preset = packet.get("preset", "default")
            print(f" [{idx+1}] {lora_name} | {preset} | {strength:.2f}")

            if strength == 0:
                # A zero-strength patch changes nothing; don't pay for it.
                print(" ⏭️ Skipped (strength 0)")
                continue

            if block_weights:
                processed_data = apply_block_weights_to_lora(lora_data, block_weights)
            else:
                processed_data = lora_data

            try:
                result = comfy.lora.load_lora(processed_data, key_map)
                # Accept either a patch dict or an object exposing one as ``.patches``.
                patches = getattr(result, "patches", result)
                model_out.add_patches(patches, strength)
                if isinstance(patches, dict) and not patches:
                    print(" ⚠️ No LoRA keys matched this model; nothing applied")
            except Exception as e:
                print(f" ⚠️ LoRA load error: {e}")
                try:
                    patches = self._build_patches(processed_data, key_map)
                    if patches:
                        model_out.add_patches(patches, strength)
                    else:
                        print(" ⚠️ Fallback found no usable LoRA keys")
                except Exception as e2:
                    print(f" ⚠️ Fallback failed: {e2}")

        print(f"{'='*60}\n")
        return (model_out,)

    def _build_patches(self, lora_data, key_map):
        """Build a patch dict by hand; used only when ``comfy.lora.load_lora`` fails.

        ``key_map`` maps LoRA key *prefixes* (the part before ``.lora_A.weight``,
        ``.diff``, ...) to model state-dict keys.

        Supported entries:
          * ``<prefix>.lora_A.weight``/``.lora_B.weight`` and
            ``<prefix>.lora_down.weight``/``.lora_up.weight`` pairs of 2-D
            matrices -> ``up @ down``, scaled by ``alpha / rank`` when
            ``<prefix>.alpha`` exists;
          * ``<prefix>.diff`` full-weight differences.

        Returns:
            ``{model_key: ("diff", (delta,))}``; unmatched or unsupported keys are skipped.
        """
        patches = {}
        for lora_key, tensor in lora_data.items():
            if lora_key.endswith(".diff"):
                model_key = key_map.get(lora_key[:-len(".diff")])
                if model_key is not None:
                    patches[model_key] = ("diff", (tensor,))
                continue
            for down_suffix, up_suffix in self._LOW_RANK_SUFFIXES:
                if not lora_key.endswith(down_suffix):
                    continue
                prefix = lora_key[:-len(down_suffix)]
                model_key = key_map.get(prefix)
                up = lora_data.get(prefix + up_suffix)
                if model_key is not None and up is not None and tensor.dim() == 2 and up.dim() == 2:
                    # Multiply in float32 rather than the file's (possibly half-precision) dtype.
                    delta = torch.mm(up.float(), tensor.float())
                    alpha = lora_data.get(prefix + ".alpha")
                    if alpha is not None:
                        # rank = number of rows of the down/A matrix
                        delta = delta * (float(alpha) / tensor.shape[0])
                    patches[model_key] = ("diff", (delta,))
                break
        return patches
# =============================================================================
# SHOWER - Quick Preset
# =============================================================================
class ShowerPreset:
    """🚿 Shower - overwrite a packet's block weights with one of the FP4 presets.

    Returns a shallow copy of the incoming packet whose ``block_weights`` and
    ``preset`` entries are replaced; the incoming packet is not modified.
    """

    PRESET_OPTIONS = ["FP4ME Light", "FP4ME Heavy", "FP4MEL Light", "FP4MEL Heavy"]

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "lora_packet": ("LORA_PACKET",),
            "preset": (s.PRESET_OPTIONS, {"default": "FP4ME Light"}),
        }
        return {"required": required}

    RETURN_TYPES = ("LORA_PACKET",)
    RETURN_NAMES = ("lora_out",)
    FUNCTION = "apply_preset"
    CATEGORY = "bathroom"
    DESCRIPTION = "🚿 Shower - Quick preset"

    def apply_preset(self, lora_packet, preset):
        """Return ``(copy_of_packet,)`` with the weights of ``preset`` attached."""
        build_weights = {
            "FP4ME Light": get_fp4me_light_weights,
            "FP4ME Heavy": get_fp4me_heavy_weights,
            "FP4MEL Light": get_fp4mel_light_weights,
            "FP4MEL Heavy": get_fp4mel_heavy_weights,
        }[preset]
        out = lora_packet.copy()
        out["block_weights"] = build_weights()
        out["preset"] = preset
        print(f"🚿 Shower: {preset}")
        return (out,)
# =============================================================================
# TOWEL - Info Display
# =============================================================================
class TowelInfo:
    """🧾 Towel - summarize a LORA_PACKET as text and pass the packet through unchanged.

    The summary lists the LoRA name, preset and strength, plus the indices of
    blocks whose weight is below 0.01 when the packet carries block weights.
    """

    @classmethod
    def INPUT_TYPES(s):
        required = {"lora_packet": ("LORA_PACKET",)}
        return {"required": required}

    RETURN_TYPES = ("LORA_PACKET", "STRING")
    RETURN_NAMES = ("lora_out", "info")
    FUNCTION = "display_info"
    CATEGORY = "bathroom"
    OUTPUT_NODE = True
    DESCRIPTION = "🧾 Towel - Info"

    def display_info(self, lora_packet):
        """Return ``(lora_packet, summary_text)``; missing fields show as ``?`` / ``1.00``."""
        name = lora_packet.get("lora_name", "?")
        preset = lora_packet.get("preset", "?")
        strength = lora_packet.get("strength", 1)
        report = [f"LoRA: {name}", f"Preset: {preset}", f"Strength: {strength:.2f}"]
        block_weights = lora_packet.get("block_weights")
        if block_weights:
            disabled = [index for index, weight in block_weights.items() if weight < 0.01]
            report.append(f"Disabled: {disabled}")
        return (lora_packet, "\n".join(report))
# =============================================================================
# NODE MAPPINGS
# =============================================================================
# One row per exported node: (registry key, node class, display name).
_BATHROOM_NODES = (
    ("Toothbrush LoRA Loader", ToothbrushLoRALoader, "🪥 Toothbrush"),
    ("Mirror (Simple)", MirrorSimple, "🪞 Mirror (Simple)"),
    ("Mirror (Fancy)", MirrorFancy, "🪞 Mirror (Fancy)"),
    ("Bathroom Sink", BathroomSink, "🚰 Bathroom Sink"),
    ("Shower Preset", ShowerPreset, "🚿 Shower"),
    ("Towel Info", TowelInfo, "🧾 Towel"),
)

# Module-level node registries (ComfyUI custom-node convention), both derived
# from the single table above so keys can never drift apart.
NODE_CLASS_MAPPINGS = {key: node_cls for key, node_cls, _ in _BATHROOM_NODES}
NODE_DISPLAY_NAME_MAPPINGS = {key: label for key, _, label in _BATHROOM_NODES}