| import hashlib |
| from pathlib import Path |
| from typing import Any |
|
|
| import folder_paths |
| import safetensors |
| import torch |
| from comfy_api.latest import io |
|
|
| from .nodes_registry import comfy_node |
|
|
|
|
| @comfy_node(name="LTXVLoadConditioning") |
| class LTXVLoadConditioning(io.ComfyNode): |
| @classmethod |
| def define_schema(cls) -> io.Schema: |
| files = folder_paths.get_filename_list("embeddings") |
| if not files: |
| files = [""] |
| return io.Schema( |
| node_id="LTXVLoadConditioning", |
| display_name="๐
๐
ฃ๐
ง LTXV Load Conditioning", |
| category="lightricks/LTXV", |
| inputs=[ |
| io.Combo.Input("file_name", options=sorted(files)), |
| io.Combo.Input("device", options=["cpu", "gpu"]), |
| ], |
| outputs=[ |
| io.Conditioning.Output(), |
| ], |
| ) |
|
|
| @classmethod |
| def execute(cls, file_name: str, device: str) -> io.NodeOutput: |
| file_path = folder_paths.get_full_path("embeddings", file_name) |
| if not Path(file_path).exists(): |
| raise FileNotFoundError(f"Conditioning file not found: {file_path}") |
|
|
| target_device = "cpu" |
| if device == "gpu": |
| target_device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| conditioning: list[list[Any]] = [] |
|
|
| with safetensors.safe_open( |
| file_path, framework="pt", device=target_device |
| ) as f: |
| tensor_keys = [k for k in f.keys() if k.startswith("conditioning_data_")] |
|
|
| for tensor_key in sorted(tensor_keys): |
| idx = tensor_key.replace("conditioning_data_", "") |
| tensor = f.get_tensor(tensor_key) |
|
|
| options: dict[str, Any] = {} |
| mask_key = f"attention_mask_{idx}" |
| if mask_key in f.keys(): |
| options["attention_mask"] = f.get_tensor(mask_key) |
|
|
| conditioning.append([tensor, options]) |
|
|
| if not conditioning: |
| raise ValueError(f"No conditioning data found in file: {file_name}") |
|
|
| return io.NodeOutput(conditioning) |
|
|
| @classmethod |
| def fingerprint_inputs(cls, file_name: str, device: str) -> str: |
| file_path = folder_paths.get_full_path("embeddings", file_name) |
| with open(file_path, "rb") as f: |
| return hashlib.sha256(f.read()).hexdigest() |
|
|
| @classmethod |
| def validate_inputs(cls, file_name: str, device: str) -> bool | str: |
| if not file_name: |
| return "No files found. Please save a conditioning first." |
| try: |
| file_path = folder_paths.get_full_path("embeddings", file_name) |
| if not Path(file_path).exists(): |
| return f"File not found: {file_name}" |
| except Exception: |
| return f"Invalid file: {file_name}" |
| return True |
|
|