# comfyui-stanno / nodes.py
# oldman-dev's picture
# Up-to-date with original repo
# 8a0f449 verified
"""
STANNO custom nodes for ComfyUI.

Nodes:
  - STANNOLoad                Load or create a STANNO model
  - STANNOTrainImages         Train STANNO as autoencoder on a batch of images
  - STANNOScoreImages         Score and filter images by reconstruction error
  - STANNODreamCond           Inject dream-mode creativity into CONDITIONING
  - STANNODynamicLoRA         Patch MODEL attention weights using STANNO dream output
  - STANNOCompositeCheck      Score images against two STANNOs and route by winner
  - STANNOScan                Scan a batch for the images that best match a STANNO (DSANNO)
  - STANNOCascadeLoad         Load or create a CascadeSTANNO (chain of STANNO stages)
  - STANNOCascadeTrainImages  Train a CascadeSTANNO on a batch of images

Full integration guide:
  /mnt/juegos/proyectos/especiales/stanno/comfyui-stanno-integration.md
"""
from __future__ import annotations
import os
import json
import pickle
import sys
import numpy as np
import torch
from comfy_api.latest import ComfyExtension, io
# ─── helpers ────────────────────────────────────────────────────────────────
def _flatten_images(image_tensor: torch.Tensor, target_dim: int) -> np.ndarray:
    """
    Flatten a ComfyUI IMAGE batch into fixed-length feature vectors.

    Each image is bilinearly resized to a ``side x side`` square, where
    ``side = max(1, isqrt(target_dim // C))``, flattened in (H, W, C) order,
    and then truncated or zero-padded so every row holds exactly
    ``target_dim`` values.

    Args:
        image_tensor: Batch laid out as (B, H, W, C). ComfyUI IMAGE tensors
            are float32 in [0, 1].
        target_dim: Number of features per image in the result (>= 1).

    Returns:
        CPU array of shape (B, target_dim) with the same dtype as
        ``image_tensor``.

    Raises:
        ValueError: If ``image_tensor`` is not 4-D or ``target_dim`` < 1.
    """
    import math

    import torch.nn.functional as F

    if image_tensor.ndim != 4:
        raise ValueError(
            f"expected a (B, H, W, C) image batch, got shape {tuple(image_tensor.shape)}"
        )
    if target_dim < 1:
        raise ValueError(f"target_dim must be >= 1, got {target_dim}")

    batch, _, _, channels = image_tensor.shape
    # Largest square whose side*side*channels does not exceed target_dim;
    # never smaller than 1x1 so interpolate always has a valid size.
    side = max(1, math.isqrt(target_dim // channels))

    x = image_tensor.permute(0, 3, 1, 2)  # B C H W
    x = F.interpolate(x, size=(side, side), mode="bilinear", align_corners=False)
    x = x.permute(0, 2, 3, 1).reshape(batch, -1)  # B (side²·C)

    # Trim or pad to exactly target_dim.
    width = x.shape[1]
    if width > target_dim:
        x = x[:, :target_dim]
    elif width < target_dim:
        # new_zeros inherits dtype and device from x, so the padded result
        # has the same dtype as the trimmed one.
        pad = x.new_zeros((batch, target_dim - width))
        x = torch.cat([x, pad], dim=1)
    return x.detach().cpu().numpy()
# ─── Node 1: Load / Create STANNO ───────────────────────────────────────────
class STANNOLoad(io.ComfyNode):
    """
    Load a saved STANNO model from disk, or create a new untrained one.

    If `model_path` points to an existing file it is unpickled and returned
    unchanged; `layers_json`, `trainer_type` and `learning_rate` are ignored
    in that case. Otherwise a new STANNO is built from those three inputs.
    The returned STANNO object can be passed to any other STANNO node.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the path/architecture inputs and the STANNO + info outputs."""
        return io.Schema(
            node_id="STANNOLoad",
            display_name="STANNO Loader",
            category="STANNO",
            inputs=[
                io.String.Input(
                    "model_path",
                    default="stanno_model.pkl",
                    multiline=False,
                    tooltip="Path to a saved STANNO .pkl file, or a new filename to create.",
                ),
                io.String.Input(
                    "layers_json",
                    default="[1, 32, 1]",
                    multiline=False,
                    tooltip=(
                        "JSON list of layer sizes. Examples:\n"
                        " [1, 32, 1] sin regression (poc)\n"
                        " [768, 256, 768] CLIP-embedding autoencoder (SD 1.5)\n"
                        " [3072, 512, 3072] 32×32 pixel autoencoder\n"
                        " [784, 256, 128, 10] classifier"
                    ),
                ),
                io.Combo.Input(
                    "trainer_type",
                    options=["fixed", "local_rule", "evolutionary"],
                ),
                io.Float.Input(
                    "learning_rate",
                    default=0.01,
                    min=1e-5,
                    max=1.0,
                    step=0.001,
                    display_mode=io.NumberDisplay.number,
                ),
            ],
            outputs=[
                io.Custom.Output("STANNO"),
                io.String.Output("info"),
            ],
        )

    @classmethod
    def execute(cls, model_path, layers_json, trainer_type, learning_rate) -> io.NodeOutput:
        """
        Return the loaded or newly created STANNO plus a one-line summary.

        Raises:
            ValueError: If a new model is requested and `layers_json` is not a
                JSON list of at least two positive integers.
        """
        from stanno.config.schema import STANNOConfig
        from stanno.core.stanno import STANNO

        if os.path.isfile(model_path):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only load .pkl files that come from a trusted source.
            with open(model_path, "rb") as f:
                stanno_obj = pickle.load(f)
            info = f"Loaded: {model_path} | layers={stanno_obj.config.layers}"
        else:
            try:
                layers = json.loads(layers_json)
            except json.JSONDecodeError as err:
                raise ValueError(f"layers_json is not valid JSON: {err}") from err
            # bool is a subclass of int, so exclude it explicitly.
            layers_ok = (
                isinstance(layers, list)
                and len(layers) >= 2
                and all(
                    isinstance(n, int) and not isinstance(n, bool) and n > 0
                    for n in layers
                )
            )
            if not layers_ok:
                raise ValueError(
                    "layers_json must be a JSON list of at least two positive "
                    f"integers, got: {layers_json!r}"
                )
            config = STANNOConfig(
                layers=layers,
                trainer_type=trainer_type,
                learning_rate=learning_rate,
            )
            stanno_obj = STANNO(config)
            info = f"Created new STANNO | layers={layers} trainer={trainer_type} lr={learning_rate}"

        print(f"[STANNO Loader] {info}")
        return io.NodeOutput(stanno_obj, info)
# ─── Node 2: Train on Images ─────────────────────────────────────────────────
class STANNOTrainImages(io.ComfyNode):
    """
    Train a STANNO as an autoencoder on a batch of images.

    Images are resized to match the STANNO's input dimension, normalized to
    [-1, 1], and used as both input and target (autoencoder). Training runs
    on a deep copy, so the STANNO object passed in is left untouched.

    Tip: connect the output STANNO to STANNOScoreImages to filter later
    generated images against this learned distribution.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the image/STANNO/hyper-parameter inputs and the two outputs."""
        return io.Schema(
            node_id="STANNOTrainImages",
            display_name="STANNO Train from Images",
            category="STANNO",
            inputs=[
                io.Image.Input("images"),
                io.Custom.Input("STANNO", "stanno"),
                io.Int.Input("epochs", default=100, min=1, max=10000, step=10,
                             display_mode=io.NumberDisplay.number),
                io.Int.Input("batch_size", default=16, min=1, max=512, step=8,
                             display_mode=io.NumberDisplay.number),
                io.String.Input(
                    "save_path",
                    default="",
                    multiline=False,
                    tooltip="Optional: absolute path to save the trained STANNO as .pkl. Leave empty to skip.",
                ),
            ],
            outputs=[
                io.Custom.Output("STANNO"),
                io.String.Output("training_log"),
            ],
        )

    @classmethod
    def execute(cls, images, stanno, epochs, batch_size, save_path) -> io.NodeOutput:
        """
        Fit a copy of `stanno` on `images` and return it with a text log.

        The log holds roughly five evenly spaced progress lines plus, when
        `save_path` is non-empty, a line naming the pickle that was written.
        """
        import copy

        stanno_copy = copy.deepcopy(stanno)
        input_dim = stanno_copy.config.layers[0]
        x = _flatten_images(images, input_dim).astype(np.float32)
        x = x * 2.0 - 1.0  # normalize to [-1, 1]

        log_lines: list[str] = []
        report_every = max(1, epochs // 5)

        def log_cb(epoch: int, loss: float) -> None:
            # `epoch` is treated as 0-based, hence the +1 when reporting.
            if (epoch + 1) % report_every == 0:
                line = f"epoch {epoch + 1:5d} loss={loss:.5f}"
                log_lines.append(line)
                print(f"[STANNO Train] {line}")

        stanno_copy.fit(x, x, epochs=epochs, batch_size=batch_size, callback=log_cb)

        save = save_path.strip()
        if save:
            # abspath first so a bare filename still yields a real directory.
            os.makedirs(os.path.dirname(os.path.abspath(save)), exist_ok=True)
            with open(save, "wb") as f:
                pickle.dump(stanno_copy, f)
            log_lines.append(f"Saved → {save}")
            print(f"[STANNO Train] Saved → {save}")
        return io.NodeOutput(stanno_copy, "\n".join(log_lines))
# ─── Node 3: Score & Filter Images ───────────────────────────────────────────
class STANNOScoreImages(io.ComfyNode):
    """
    Rank and filter an image batch by how well a STANNO reconstructs it.

    Every image receives a score from DSANNO in "reconstruction" mode; a
    lower score means a closer match. The node emits the batch reordered per
    `sort_order`, the sub-batch whose score is strictly below `threshold`
    (or the first input image when nothing passes), and a JSON report.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's sockets and widgets to ComfyUI."""
        threshold_input = io.Float.Input(
            "threshold",
            default=0.10,
            min=0.0,
            max=2.0,
            step=0.005,
            display_mode=io.NumberDisplay.slider,
            tooltip="MSE above this value is flagged as anomaly / style mismatch.",
        )
        order_input = io.Combo.Input(
            "sort_order",
            options=["best_first", "worst_first", "original"],
        )
        node_inputs = [
            io.Image.Input("images"),
            io.Custom.Input("STANNO", "stanno"),
            threshold_input,
            order_input,
        ]
        node_outputs = [
            io.Image.Output(),  # batch in the requested order
            io.Image.Output(),  # images scoring below the threshold
            io.String.Output("scores_json"),
        ]
        return io.Schema(
            node_id="STANNOScoreImages",
            display_name="STANNO Image Scorer",
            category="STANNO",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images, stanno, threshold, sort_order) -> io.NodeOutput:
        """Score `images`, then return (ordered batch, passing batch, JSON report)."""
        from stanno.integration.dsanno import DSANNO

        features = _flatten_images(images, stanno.config.layers[0]).astype(np.float32)
        features = features * 2.0 - 1.0  # map [0, 1] pixels to [-1, 1]

        scores_arr, _ = DSANNO(stanno, mode="reconstruction").score_batch(features)
        scores = scores_arr.tolist()

        # Normalise against the worst score; fall back to 1.0 when it is not
        # positive so the division below stays well-defined.
        peak = max(scores)
        divisor = peak if peak > 0 else 1.0
        normalised = [value / divisor for value in scores]

        count = len(scores)
        if sort_order == "best_first":
            order = sorted(range(count), key=lambda pos: scores[pos])
        elif sort_order == "worst_first":
            order = sorted(range(count), key=lambda pos: -scores[pos])
        else:
            order = list(range(count))

        device = images.device
        sorted_images = images[torch.tensor(order, device=device)]

        passing = [pos for pos in order if scores[pos] < threshold]
        if passing:
            filtered_images = images[torch.tensor(passing, device=device)]
        else:
            # Nothing passed: emit the first input image instead of an empty batch.
            filtered_images = images[:1]

        report = []
        for pos in range(count):
            report.append(
                {
                    "index": pos,
                    "mse": round(scores[pos], 5),
                    "norm": round(normalised[pos], 4),
                    "pass": scores[pos] < threshold,
                }
            )
        return io.NodeOutput(sorted_images, filtered_images, json.dumps(report, indent=2))
# ─── Node 4: Dream Conditioning ──────────────────────────────────────────────
class STANNODreamCond(io.ComfyNode):
    """
    Modify a CLIP CONDITIONING tensor using STANNO dream mode.

    The STANNO must have been trained on CLIP embeddings (768-dim per token for
    SD 1.5). Each token in the conditioning is fed as an input seed to dream(),
    perturbed by noise, and the result is blended back with the original.

    noise_sigma controls the creativity spectrum:
        0.00–0.02  almost identical to original prompt
        0.05–0.15  subtle but noticeable style shift (recommended starting point)
        0.20–0.40  creative variations, may drift from original prompt meaning
        0.50+      chaotic, unpredictable (useful for pure exploration)

    blend_strength controls how much of the dream replaces the original:
        0.0 = original conditioning unchanged
        1.0 = full dream output (ignore original)
        0.1–0.3 recommended for most workflows
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the conditioning/STANNO inputs, the dream controls and outputs."""
        return io.Schema(
            node_id="STANNODreamCond",
            display_name="STANNO Dream Conditioning",
            category="STANNO",
            inputs=[
                io.Conditioning.Input("conditioning"),
                io.Custom.Input("STANNO", "stanno"),
                io.Float.Input(
                    "noise_sigma",
                    default=0.05,
                    min=0.0,
                    max=2.0,
                    step=0.01,
                    display_mode=io.NumberDisplay.slider,
                ),
                io.Float.Input(
                    "blend_strength",
                    default=0.20,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=io.NumberDisplay.slider,
                ),
                io.Int.Input(
                    "seed",
                    default=42,
                    min=0,
                    max=2 ** 31,
                    display_mode=io.NumberDisplay.number,
                ),
                io.Combo.Input(
                    "feedback_projection",
                    options=["repeat", "linear", "zeros"],
                ),
            ],
            outputs=[
                io.Conditioning.Output(),  # modified conditioning
                io.Conditioning.Output(),  # original pass-through
            ],
        )

    @classmethod
    def execute(
        cls, conditioning, stanno, noise_sigma, blend_strength, seed, feedback_projection
    ) -> io.NodeOutput:
        """
        Blend every conditioning token with its STANNO dream.

        Returns the modified conditioning and, as a second output, the
        original conditioning object unchanged. `feedback_projection` is
        accepted for the schema but not used by this implementation.
        """
        import copy

        rng = np.random.default_rng(seed)
        result = []
        for cond_tensor, cond_meta in conditioning:
            # cond_tensor: (batch, seq_len, embed_dim), e.g. (1, 77, 768) for SD 1.5.
            # Cast to float32 on the torch side: numpy cannot hold bfloat16.
            orig_np = cond_tensor.detach().float().cpu().numpy()
            batch, seq, _ = orig_np.shape

            # Dream every token of every batch item. Iterating batch-major
            # keeps the rng draw order unchanged for the usual batch == 1 case.
            dream_rows: list[np.ndarray] = []
            for batch_idx in range(batch):
                dream_tokens: list[np.ndarray] = []
                for token_idx in range(seq):
                    seed_vec = orig_np[batch_idx, token_idx, :].reshape(1, -1)  # (1, dim)
                    dream_out = stanno.dream(
                        num_steps=1,
                        input_seed=seed_vec,
                        noise_sigma=noise_sigma,
                        blind_inputs=False,
                        rng=rng,
                    )
                    dream_tokens.append(dream_out[0])  # assumed (dim,) — TODO confirm
                dream_rows.append(np.stack(dream_tokens, axis=0))  # (seq, dim)
            dream_cond = np.stack(dream_rows, axis=0)  # (batch, seq, dim)

            blended = (1.0 - blend_strength) * orig_np + blend_strength * dream_cond
            blended_t = torch.from_numpy(blended).to(
                device=cond_tensor.device, dtype=cond_tensor.dtype
            )
            # Emit [tensor, metadata] lists, the layout ComfyUI's own
            # conditioning nodes appear to use — TODO confirm.
            result.append([blended_t, copy.deepcopy(cond_meta)])
        return io.NodeOutput(result, conditioning)
# ─── Node 5: Dynamic LoRA (weight-space patching) ────────────────────────────
class STANNODynamicLoRA(io.ComfyNode):
    """
    Inject STANNO dream output as LoRA-style weight patches into a MODEL.

    STANNO generates `lora_rank` dream vectors. These are L2-normalised and
    stacked into a (rank, dim) matrix and its (dim, rank) transpose, which are
    handed to the model's add_patches() for a fixed list of SD 1.5
    cross-attention weights.

    Requirements:
      - STANNO layers[0] == layers[-1] == 768 (SD 1.5 cross-attention dim)
      - Recommended: train STANNO on CLIP embeddings of your target style first

    Parameter guide:
      lora_rank    1–2 → subtle and stable; 4–8 → stronger but may cause drift
      alpha        0.3–0.5 is a good starting point for SD 1.5
      noise_sigma  0.0 → deterministic style from STANNO weights
                   0.1–0.2 → creative variations per run
    """

    # Cross-attention projections in SD 1.5 UNet (12 representative layers;
    # add more from the full model key list for stronger effect).
    _ATTN_KEYS = [
        "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight",
        "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight",
        "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight",
        "diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight",
        "diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        "diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight",
        "diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight",
        "diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight",
        "diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight",
        "diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight",
        "diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight",
        "diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight",
    ]

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the model/STANNO inputs, the patch controls and the outputs."""
        return io.Schema(
            node_id="STANNODynamicLoRA",
            display_name="STANNO Dynamic LoRA",
            category="STANNO",
            inputs=[
                io.Model.Input("model"),
                io.Custom.Input("STANNO", "stanno"),
                io.Float.Input(
                    "alpha",
                    default=0.5,
                    min=0.0,
                    max=2.0,
                    step=0.05,
                    display_mode=io.NumberDisplay.slider,
                    tooltip="LoRA scaling factor. Start at 0.3–0.5 for SD 1.5.",
                ),
                io.Int.Input(
                    "lora_rank",
                    default=2,
                    min=1,
                    max=16,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Rank of the injected A×B matrices. Lower = more stable.",
                ),
                io.Float.Input(
                    "noise_sigma",
                    default=0.10,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=io.NumberDisplay.slider,
                ),
                io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2 ** 31,
                    display_mode=io.NumberDisplay.number,
                ),
            ],
            outputs=[
                io.Model.Output(),
                io.String.Output("patch_info"),
            ],
        )

    @classmethod
    def execute(cls, model, stanno, alpha, lora_rank, noise_sigma, seed) -> io.NodeOutput:
        """
        Clone `model`, register the dream-derived patches and describe them.

        Returns the patched clone and an info string that reports how many of
        the requested keys add_patches() actually accepted.
        """
        rng = np.random.default_rng(seed)
        dim = stanno.config.layers[0]
        rank = min(lora_rank, dim)

        # Generate `rank` dream vectors; each is used as one style direction.
        basis: list[np.ndarray] = []
        for _ in range(rank):
            seed_vec = rng.normal(0.0, 0.1, (1, dim)).astype(np.float32)
            dream_out = stanno.dream(
                num_steps=1,
                input_seed=seed_vec,
                noise_sigma=noise_sigma,
                blind_inputs=False,
                rng=rng,
            )
            basis.append(dream_out[0])  # assumed (dim,) — TODO confirm

        # A: (rank, dim) with unit-norm rows; B = A.T: (dim, rank).
        A = np.stack(basis, axis=0)
        norms = np.linalg.norm(A, axis=1, keepdims=True).clip(min=1e-8)
        A_norm = (A / norms).astype(np.float32)
        B = A_norm.T.astype(np.float32)
        A_t = torch.from_numpy(A_norm)
        B_t = torch.from_numpy(B)

        # The tuple is ordered (B_t, A_t), i.e. ((dim, rank), (rank, dim)), so
        # the implied product B_t @ A_t is (dim, dim).
        # NOTE(review): confirm against the installed ComfyUI version that a
        # "lora" patch accepts a 2-tuple in this order, and that a (dim, dim)
        # delta matches the shape of every weight in _ATTN_KEYS — TODO verify.
        patches = {key: ("lora", (B_t, A_t)) for key in cls._ATTN_KEYS}
        patched_model = model.clone()
        accepted = patched_model.add_patches(patches, alpha)
        # add_patches is expected to return the keys it matched in the model;
        # fall back to the requested count if it returns nothing.
        applied = len(accepted) if accepted is not None else len(patches)

        info = (
            f"Patched {applied}/{len(patches)} attention layers | "
            f"rank={rank} | alpha={alpha:.2f} | noise={noise_sigma:.3f} | seed={seed}"
        )
        print(f"[STANNO DynamicLoRA] {info}")
        return io.NodeOutput(patched_model, info)
# ─── Node 6: Composite Style Checker ─────────────────────────────────────────
class STANNOCompositeCheck(io.ComfyNode):
    """
    Score a batch of images against two STANNOs and route by the closer match.

    In a composite / inpainting workflow different image zones should each match
    a particular trained style. This node splits a batch into two sub-batches
    based on which STANNO has lower reconstruction error, and reports the margin
    so you can identify ambiguous images. Ties go to STANNO A.

    Typical use: connect two STANNOs trained on 'background style' and
    'foreground style'; route each generated image to the right inpaint layer.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the image, two STANNO and two label inputs plus the outputs."""
        return io.Schema(
            node_id="STANNOCompositeCheck",
            display_name="STANNO Composite Style Checker",
            category="STANNO",
            inputs=[
                io.Image.Input("images"),
                io.Custom.Input("STANNO", "stanno_a"),
                io.Custom.Input("STANNO", "stanno_b"),
                io.String.Input("label_a", default="Style A", multiline=False),
                io.String.Input("label_b", default="Style B", multiline=False),
            ],
            outputs=[
                io.Image.Output(),  # images closest to Style A
                io.Image.Output(),  # images closest to Style B
                io.String.Output("report_json"),
            ],
        )

    @classmethod
    def execute(cls, images, stanno_a, stanno_b, label_a, label_b) -> io.NodeOutput:
        """
        Split `images` by the lower per-image reconstruction MSE.

        Returns (images won by A, images won by B, JSON report). When one
        side wins nothing, its output falls back to the first input image so
        the batch is never empty.
        """
        dim_a = stanno_a.config.layers[0]
        dim_b = stanno_b.config.layers[0]
        xa = _flatten_images(images, dim_a).astype(np.float32) * 2.0 - 1.0
        xb = _flatten_images(images, dim_b).astype(np.float32) * 2.0 - 1.0
        scores_a = np.mean((stanno_a.predict(xa) - xa) ** 2, axis=1)
        scores_b = np.mean((stanno_b.predict(xb) - xb) ** 2, axis=1)

        # One mask drives both routing and the report, so every image lands
        # in exactly one bucket (a NaN comparison is False and goes to B).
        a_wins = np.asarray(scores_a <= scores_b)
        idx_a = np.flatnonzero(a_wins).tolist()
        idx_b = np.flatnonzero(~a_wins).tolist()

        imgs_a = (
            images[torch.tensor(idx_a, device=images.device)]
            if idx_a else images[:1]
        )
        imgs_b = (
            images[torch.tensor(idx_b, device=images.device)]
            if idx_b else images[:1]
        )

        # Labels become JSON keys. If they are identical, or clash with the
        # fixed report keys, one entry would overwrite another — disambiguate.
        key_a, key_b = label_a, label_b
        reserved = {"index", "winner", "margin"}
        if key_a == key_b or key_a in reserved or key_b in reserved:
            key_a, key_b = f"{label_a} (A)", f"{label_b} (B)"

        report = [
            {
                "index": i,
                key_a: round(float(score_a), 5),
                key_b: round(float(score_b), 5),
                "winner": key_a if wins else key_b,
                "margin": round(abs(float(score_a) - float(score_b)), 5),
            }
            for i, (score_a, score_b, wins) in enumerate(zip(scores_a, scores_b, a_wins))
        ]
        return io.NodeOutput(imgs_a, imgs_b, json.dumps(report, indent=2))
# ─── Node 7: DSANNO Scan ──────────────────────────────────────────────────────
class STANNOScan(io.ComfyNode):
    """
    DSANNO — Data Scanning Artificial Neural Network Object.

    Scans a batch of images and finds the ones that best match what the STANNO
    has learned, implementing the patent's DSANNO concept: "scan large regions
    of the data space looking for patterns that match the learned representation."

    Two modes
    ─────────
    auto_calibrate = ON (recommended)
        The threshold is computed automatically from this very batch at the
        given percentile. E.g. percentile=20 keeps the best-matching 20 %.
    auto_calibrate = OFF
        Use the manually supplied ``threshold`` value directly.

    Outputs
    ───────
    top_k_images   — the k images with lowest reconstruction error
    matched_images — all images below the threshold
    scores_json    — per-image scores and match flags
    threshold      — the threshold that was applied (useful for display/routing)
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the image/STANNO inputs, the threshold controls and outputs."""
        return io.Schema(
            node_id="STANNOScan",
            display_name="STANNO Scan (DSANNO)",
            category="STANNO",
            inputs=[
                io.Image.Input("images"),
                io.Custom.Input("STANNO", "stanno"),
                io.Int.Input(
                    "top_k",
                    default=4,
                    min=1,
                    max=64,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Return this many best-matching images regardless of threshold.",
                ),
                io.Combo.Input(
                    "auto_calibrate",
                    options=["on", "off"],
                    tooltip=(
                        "on: compute threshold automatically from this batch at the "
                        "given percentile.\n"
                        "off: use the manual threshold value."
                    ),
                ),
                io.Float.Input(
                    "percentile",
                    default=30.0,
                    min=1.0,
                    max=99.0,
                    step=1.0,
                    display_mode=io.NumberDisplay.slider,
                    tooltip=(
                        "Used when auto_calibrate=on. "
                        "30 = keep the best-matching 30 % of the batch."
                    ),
                ),
                io.Float.Input(
                    "threshold",
                    default=0.10,
                    min=0.0,
                    max=5.0,
                    step=0.005,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Manual threshold (used only when auto_calibrate=off).",
                ),
            ],
            outputs=[
                io.Image.Output(),  # top_k_images
                io.Image.Output(),  # matched_images
                io.String.Output("scores_json"),
                io.Float.Output("threshold"),
            ],
        )

    @classmethod
    def execute(cls, images, stanno, top_k, auto_calibrate, percentile, threshold) -> io.NodeOutput:
        """
        Scan `images` and return (top-k batch, matched batch, JSON, threshold).

        When no image matches, the matched output falls back to the first
        input image so the batch is never empty.
        """
        from stanno.integration.dsanno import DSANNO

        input_dim = stanno.config.layers[0]
        x = _flatten_images(images, input_dim).astype(np.float32) * 2.0 - 1.0
        scanner = DSANNO(stanno, mode="reconstruction")
        result = scanner.scan(x)

        # Determine threshold
        if auto_calibrate == "on":
            used_threshold = float(np.percentile(result.scores, percentile))
        else:
            used_threshold = float(threshold)
        result.set_threshold(used_threshold)

        # top_k images (k cannot exceed the batch size)
        k = min(int(top_k), len(images))
        top_indices, _, _ = scanner.top_k(x, k=k)
        top_images = images[torch.tensor(top_indices.tolist(), device=images.device)]

        # matched images, as selected by the scan result after set_threshold
        matched_idx = result.matched_indices().tolist()
        matched_images = (
            images[torch.tensor(matched_idx, device=images.device)]
            if matched_idx else images[:1]
        )

        # 1-based rank of every image by ascending score. Sort once and
        # invert the permutation rather than re-sorting for each image.
        order = np.argsort(result.scores)
        ranks = np.empty(len(order), dtype=np.int64)
        ranks[order] = np.arange(1, len(order) + 1)

        scores_data = [
            {
                "index": int(i),
                "mse": round(float(result.scores[i]), 5),
                "matched": bool(result.matched_mask[i]),
                "rank": int(ranks[i]),
            }
            for i in range(len(result.scores))
        ]
        print(
            f"[STANNO Scan] threshold={used_threshold:.4f} | "
            f"matched={len(matched_idx)}/{len(images)} | top_k={k}"
        )
        return io.NodeOutput(
            top_images,
            matched_images,
            json.dumps(scores_data, indent=2),
            used_threshold,
        )
# ─── Node 8: Cascade Load / Create ───────────────────────────────────────────
class STANNOCascadeLoad(io.ComfyNode):
    """
    Load or create a CascadeSTANNO — a chain of STANNO stages.

    Implements the patent's "cascading networks to form system models":
    the output of stage k feeds the input of stage k+1, and each stage
    can be independently frozen or adapted.

    Typical uses
    ────────────
    Encoder + Decoder autoencoder:
        stages_json = [{"layers": [3072, 512]}, {"layers": [512, 3072]}]
    Progressive compression pipeline:
        stages_json = [{"layers":[768,256]}, {"layers":[256,64]}, {"layers":[64,256]}, {"layers":[256,768]}]
    Frozen pre-processor + adaptive head:
        stages_json = [{"layers":[768,256]}, {"layers":[256,10]}]
        frozen_json = [true, false]

    stages_json format
    ──────────────────
    JSON array of objects. Keys:
        "layers"         required — e.g. [768, 256, 768]
        "trainer_type"   optional — "fixed"|"local_rule"|"evolutionary"
                         (default: the node's trainer_type input)
        "learning_rate"  optional — per-stage lr (default: the node's
                         learning_rate input)
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the path, stage/frozen JSON and default-trainer inputs."""
        return io.Schema(
            node_id="STANNOCascadeLoad",
            display_name="STANNO Cascade Loader",
            category="STANNO",
            inputs=[
                io.String.Input(
                    "model_path",
                    default="cascade_model.pkl",
                    multiline=False,
                    tooltip="Path to a saved CascadeSTANNO .pkl, or a new filename to create.",
                ),
                io.String.Input(
                    "stages_json",
                    default='[{"layers": [3072, 512]}, {"layers": [512, 3072]}]',
                    multiline=True,
                    tooltip=(
                        "JSON array of stage configs. Each stage needs at minimum "
                        '{"layers": [in, ..., out]}. Used only when creating a new cascade.'
                    ),
                ),
                io.String.Input(
                    "frozen_json",
                    default="[]",
                    multiline=False,
                    tooltip=(
                        "JSON bool array of frozen flags per stage. "
                        "[] = all trainable. Example: [true, false] = freeze stage 0."
                    ),
                ),
                io.Combo.Input(
                    "trainer_type",
                    options=["fixed", "local_rule", "evolutionary"],
                    tooltip="Default trainer type applied to stages that don't override it.",
                ),
                io.Float.Input(
                    "learning_rate",
                    default=0.01,
                    min=1e-5,
                    max=1.0,
                    step=0.001,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Default learning rate applied to stages that don't override it.",
                ),
            ],
            outputs=[
                io.Custom.Output("CASCADE"),
                io.String.Output("info"),
            ],
        )

    @classmethod
    def execute(
        cls, model_path, stages_json, frozen_json, trainer_type, learning_rate
    ) -> io.NodeOutput:
        """
        Return the loaded or newly built cascade plus a one-line summary.

        An existing file at `model_path` is loaded via CascadeSTANNO.load and
        every other input is ignored. Otherwise one STANNO is created per
        entry of `stages_json`.

        Raises:
            ValueError: If `stages_json` is not a non-empty array of objects
                that each contain "layers", or if a non-empty `frozen_json`
                does not have one flag per stage.
        """
        from stanno.config.schema import STANNOConfig
        from stanno.core.stanno import STANNO
        from stanno.integration.cascade import CascadeSTANNO

        if os.path.isfile(model_path):
            # NOTE(review): the file is described as a .pkl; if
            # CascadeSTANNO.load unpickles it, only load trusted files.
            cascade = CascadeSTANNO.load(model_path)
            info = (
                f"Loaded: {model_path} | "
                f"{len(cascade.stages)} stages | "
                f"frozen={cascade.frozen}"
            )
        else:
            stage_defs = json.loads(stages_json)
            if not isinstance(stage_defs, list) or not stage_defs:
                raise ValueError(
                    "stages_json must be a non-empty JSON array of stage objects"
                )
            for idx, stage_def in enumerate(stage_defs):
                if not isinstance(stage_def, dict) or "layers" not in stage_def:
                    raise ValueError(
                        f'stages_json[{idx}] must be an object with a "layers" key'
                    )

            frozen = json.loads(frozen_json) if frozen_json.strip() else []
            if not frozen:
                # Empty list means "every stage is trainable".
                frozen = [False] * len(stage_defs)
            elif len(frozen) != len(stage_defs):
                raise ValueError(
                    f"frozen_json has {len(frozen)} flags but stages_json "
                    f"defines {len(stage_defs)} stages"
                )

            stages = []
            for stage_def in stage_defs:
                # Per-stage values win over the node-level defaults.
                scfg = STANNOConfig(
                    layers=stage_def["layers"],
                    trainer_type=stage_def.get("trainer_type", trainer_type),
                    learning_rate=stage_def.get("learning_rate", learning_rate),
                )
                stages.append(STANNO(scfg))
            cascade = CascadeSTANNO(stages, frozen=frozen)
            topology = " → ".join(
                "×".join(str(d) for d in s.config.layers) for s in stages
            )
            info = f"Created CascadeSTANNO: {topology} | frozen={frozen}"

        print(f"[STANNO Cascade Loader] {info}")
        return io.NodeOutput(cascade, info)
# ─── Node 9: Cascade Train on Images ─────────────────────────────────────────
class STANNOCascadeTrainImages(io.ComfyNode):
    """
    Train a CascadeSTANNO end-to-end on a batch of images.

    Implements the patent's "self-training within cascaded systems":
    per the library's design notes, gradient flows from the final stage back
    through every non-frozen stage via the cascade mechanism in
    FixedTrainerNet.

    Autoencoder use-case (most common)
    ───────────────────────────────────
    Set up a CascadeSTANNO with an encoder stage and a decoder stage. This
    node trains the whole chain with input == target so the bottleneck is
    forced to compress image content.

    Partial training (frozen stages)
    ─────────────────────────────────
    Freeze the encoder in STANNOCascadeLoad, then connect here. Only the
    unfrozen decoder receives weight updates — useful for domain adaptation.

    Training runs on a deep copy, so the cascade passed in is left untouched.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the image/cascade/hyper-parameter inputs and the outputs."""
        return io.Schema(
            node_id="STANNOCascadeTrainImages",
            display_name="STANNO Cascade Train from Images",
            category="STANNO",
            inputs=[
                io.Image.Input("images"),
                io.Custom.Input("CASCADE", "cascade"),
                io.Int.Input(
                    "epochs",
                    default=100,
                    min=1,
                    max=5000,
                    step=10,
                    display_mode=io.NumberDisplay.number,
                ),
                io.Int.Input(
                    "batch_size",
                    default=16,
                    min=1,
                    max=256,
                    step=8,
                    display_mode=io.NumberDisplay.number,
                ),
                io.Int.Input(
                    "patience",
                    default=30,
                    min=0,
                    max=500,
                    step=5,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Early stopping patience in epochs. 0 = disabled.",
                ),
                io.String.Input(
                    "save_path",
                    default="",
                    multiline=False,
                    tooltip="Optional path to save the trained cascade as .pkl.",
                ),
            ],
            outputs=[
                io.Custom.Output("CASCADE"),
                io.String.Output("training_log"),
            ],
        )

    @classmethod
    def execute(cls, images, cascade, epochs, batch_size, patience, save_path) -> io.NodeOutput:
        """
        Fit a copy of `cascade` on `images` and return it with a text log.

        The target equals the input when the cascade's first input size and
        last output size agree; otherwise the input is trimmed or zero-padded
        to the output size.
        """
        import copy

        cascade_copy = copy.deepcopy(cascade)

        # Use the first stage's input_dim for flattening.
        input_dim = cascade_copy.stages[0].config.layers[0]
        output_dim = cascade_copy.stages[-1].config.layers[-1]
        x = _flatten_images(images, input_dim).astype(np.float32) * 2.0 - 1.0

        # Autoencoder: target is the image itself (needs matching output dim).
        if output_dim == input_dim:
            y = x
        elif x.shape[1] >= output_dim:
            # Output narrower than input: keep the leading features.
            y = x[:, :output_dim]
        else:
            # Output wider than input: zero-pad on the right.
            pad = np.zeros((x.shape[0], output_dim - x.shape[1]), dtype=np.float32)
            y = np.hstack([x, pad])

        log_lines: list[str] = []
        report_every = max(1, epochs // 5)

        def log_cb(epoch: int, loss: float) -> None:
            # `epoch` is treated as 0-based, hence the +1 when reporting.
            if (epoch + 1) % report_every == 0:
                line = f"epoch {epoch + 1:5d} loss={loss:.5f}"
                log_lines.append(line)
                print(f"[STANNO Cascade Train] {line}")

        cascade_copy.fit(
            x, y,
            epochs=epochs,
            batch_size=batch_size,
            patience=patience,
            log_every=0,  # use callback instead
            callback=log_cb,
        )

        save = save_path.strip()
        if save:
            # abspath first so a bare filename still yields a real directory.
            os.makedirs(os.path.dirname(os.path.abspath(save)), exist_ok=True)
            cascade_copy.save(save)
            log_lines.append(f"Saved → {save}")
            print(f"[STANNO Cascade Train] Saved → {save}")
        return io.NodeOutput(cascade_copy, "\n".join(log_lines))
# ─── Extension registration ───────────────────────────────────────────────────
class STANNOExtension(ComfyExtension):
    """ComfyUI extension that exposes every STANNO node class in this module."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes of this extension, in declaration order."""
        node_classes: list[type[io.ComfyNode]] = [
            # single-model nodes
            STANNOLoad,
            STANNOTrainImages,
            STANNOScoreImages,
            STANNODreamCond,
            STANNODynamicLoRA,
            STANNOCompositeCheck,
            STANNOScan,
            # cascade nodes
            STANNOCascadeLoad,
            STANNOCascadeTrainImages,
        ]
        return node_classes
async def comfy_entrypoint() -> STANNOExtension:
    """Build and return the STANNO extension instance."""
    extension = STANNOExtension()
    return extension