""" STANNO custom nodes for ComfyUI. Nodes: - STANNOLoad Load or create a STANNO model - STANNOTrainImages Train STANNO as autoencoder on a batch of images - STANNOScoreImages Score and filter images by reconstruction error - STANNODreamCond Inject dream-mode creativity into CONDITIONING - STANNODynamicLoRA Patch MODEL attention weights using STANNO dream output - STANNOCompositeCheck Score images against two STANNOs and route by winner Full integration guide: /mnt/juegos/proyectos/especiales/stanno/comfyui-stanno-integration.md """ from __future__ import annotations import os import json import pickle import sys import numpy as np import torch from comfy_api.latest import ComfyExtension, io # ─── helpers ──────────────────────────────────────────────────────────────── def _flatten_images(image_tensor: torch.Tensor, target_dim: int) -> np.ndarray: """ Resize each image in a ComfyUI IMAGE batch to produce a flat vector of exactly `target_dim` floats. IMAGE format: (B, H, W, C) float32 [0, 1]. """ import torch.nn.functional as F b, h, w, c = image_tensor.shape side = max(1, int(((target_dim // c) ** 0.5))) x = image_tensor.permute(0, 3, 1, 2) # B C H W x = F.interpolate(x, size=(side, side), mode="bilinear", align_corners=False) x = x.permute(0, 2, 3, 1).reshape(b, -1) # B (side²·C) # Trim or pad to exactly target_dim if x.shape[1] > target_dim: x = x[:, :target_dim] elif x.shape[1] < target_dim: pad = torch.zeros(b, target_dim - x.shape[1], device=x.device) x = torch.cat([x, pad], dim=1) return x.detach().cpu().numpy() # ─── Node 1: Load / Create STANNO ─────────────────────────────────────────── class STANNOLoad(io.ComfyNode): """ Load a saved STANNO model from disk, or create a new untrained one. If `model_path` points to an existing .pkl file it is loaded unchanged. Otherwise a new STANNO is created with the given architecture and trainer. The returned STANNO object can be passed to any other STANNO node. """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNOLoad", display_name="STANNO Loader", category="STANNO", inputs=[ io.String.Input( "model_path", default="stanno_model.pkl", multiline=False, tooltip="Path to a saved STANNO .pkl file, or a new filename to create.", ), io.String.Input( "layers_json", default="[1, 32, 1]", multiline=False, tooltip=( "JSON list of layer sizes. Examples:\n" " [1, 32, 1] sin regression (poc)\n" " [768, 256, 768] CLIP-embedding autoencoder (SD 1.5)\n" " [3072, 512, 3072] 32×32 pixel autoencoder\n" " [784, 256, 128, 10] classifier" ), ), io.Combo.Input( "trainer_type", options=["fixed", "local_rule", "evolutionary"], ), io.Float.Input( "learning_rate", default=0.01, min=1e-5, max=1.0, step=0.001, display_mode=io.NumberDisplay.number, ), ], outputs=[ io.Custom.Output("STANNO"), io.String.Output("info"), ], ) @classmethod def execute(cls, model_path, layers_json, trainer_type, learning_rate) -> io.NodeOutput: from stanno.config.schema import STANNOConfig from stanno.core.stanno import STANNO if os.path.isfile(model_path): with open(model_path, "rb") as f: stanno_obj = pickle.load(f) info = f"Loaded: {model_path} | layers={stanno_obj.config.layers}" else: layers = json.loads(layers_json) config = STANNOConfig( layers=layers, trainer_type=trainer_type, learning_rate=learning_rate, ) stanno_obj = STANNO(config) info = f"Created new STANNO | layers={layers} trainer={trainer_type} lr={learning_rate}" print(f"[STANNO Loader] {info}") return io.NodeOutput(stanno_obj, info) # ─── Node 2: Train on Images ───────────────────────────────────────────────── class STANNOTrainImages(io.ComfyNode): """ Train a STANNO as an autoencoder on a batch of images. Images are resized to match the STANNO's input dimension, normalized to [-1, 1], and used as both input and target (autoencoder). After training the STANNO 'remembers' the style/distribution of those images. Tip: connect the output STANNO to STANNOScoreImages to filter later generated images against this learned distribution. """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNOTrainImages", display_name="STANNO Train from Images", category="STANNO", inputs=[ io.Image.Input("images"), io.Custom.Input("STANNO", "stanno"), io.Int.Input("epochs", default=100, min=1, max=10000, step=10, display_mode=io.NumberDisplay.number), io.Int.Input("batch_size", default=16, min=1, max=512, step=8, display_mode=io.NumberDisplay.number), io.String.Input( "save_path", default="", multiline=False, tooltip="Optional: absolute path to save the trained STANNO as .pkl. Leave empty to skip.", ), ], outputs=[ io.Custom.Output("STANNO"), io.String.Output("training_log"), ], ) @classmethod def execute(cls, images, stanno, epochs, batch_size, save_path) -> io.NodeOutput: import copy stanno_copy = copy.deepcopy(stanno) input_dim = stanno_copy.config.layers[0] x = _flatten_images(images, input_dim).astype(np.float32) x = x * 2.0 - 1.0 # normalize to [-1, 1] log_lines: list[str] = [] report_every = max(1, epochs // 5) def log_cb(epoch: int, loss: float) -> None: if (epoch + 1) % report_every == 0: line = f"epoch {epoch + 1:5d} loss={loss:.5f}" log_lines.append(line) print(f"[STANNO Train] {line}") stanno_copy.fit(x, x, epochs=epochs, batch_size=batch_size, callback=log_cb) save = save_path.strip() if save: os.makedirs(os.path.dirname(os.path.abspath(save)), exist_ok=True) with open(save, "wb") as f: pickle.dump(stanno_copy, f) log_lines.append(f"Saved → {save}") return io.NodeOutput(stanno_copy, "\n".join(log_lines)) # ─── Node 3: Score & Filter Images ─────────────────────────────────────────── class STANNOScoreImages(io.ComfyNode): """ Score a batch of images using a trained STANNO autoencoder. Reconstruction MSE is the anomaly score: low = in-distribution (style match), high = outlier. Outputs the full batch sorted by score plus a filtered sub-batch containing only images below the threshold. """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNOScoreImages", display_name="STANNO Image Scorer", category="STANNO", inputs=[ io.Image.Input("images"), io.Custom.Input("STANNO", "stanno"), io.Float.Input( "threshold", default=0.10, min=0.0, max=2.0, step=0.005, display_mode=io.NumberDisplay.slider, tooltip="MSE above this value is flagged as anomaly / style mismatch.", ), io.Combo.Input( "sort_order", options=["best_first", "worst_first", "original"], ), ], outputs=[ io.Image.Output(), # sorted batch io.Image.Output(), # filtered batch (below threshold) io.String.Output("scores_json"), ], ) @classmethod def execute(cls, images, stanno, threshold, sort_order) -> io.NodeOutput: from stanno.integration.dsanno import DSANNO input_dim = stanno.config.layers[0] x = _flatten_images(images, input_dim).astype(np.float32) * 2.0 - 1.0 scanner = DSANNO(stanno, mode="reconstruction") scores_arr, preds = scanner.score_batch(x) scores = scores_arr.tolist() max_s = max(scores) if max(scores) > 0 else 1.0 norm_scores = [s / max_s for s in scores] indices = list(range(len(scores))) if sort_order == "best_first": indices.sort(key=lambda i: scores[i]) elif sort_order == "worst_first": indices.sort(key=lambda i: -scores[i]) sorted_images = images[torch.tensor(indices, device=images.device)] filtered_idx = [i for i in indices if scores[i] < threshold] filtered_images = ( images[torch.tensor(filtered_idx, device=images.device)] if filtered_idx else images[:1] ) scores_data = [ { "index": i, "mse": round(scores[i], 5), "norm": round(norm_scores[i], 4), "pass": scores[i] < threshold, } for i in range(len(scores)) ] return io.NodeOutput(sorted_images, filtered_images, json.dumps(scores_data, indent=2)) # ─── Node 4: Dream Conditioning ────────────────────────────────────────────── class STANNODreamCond(io.ComfyNode): """ Modify a CLIP CONDITIONING tensor using STANNO dream mode. The STANNO must have been trained on CLIP embeddings (768-dim per token for SD 1.5). Each token in the conditioning is fed as an input seed to dream(), perturbed by noise, and the result is blended back with the original. noise_sigma controls the creativity spectrum: 0.00–0.02 almost identical to original prompt 0.05–0.15 subtle but noticeable style shift (recommended starting point) 0.20–0.40 creative variations, may drift from original prompt meaning 0.50+ chaotic, unpredictable (useful for pure exploration) blend_strength controls how much of the dream replaces the original: 0.0 = original conditioning unchanged 1.0 = full dream output (ignore original) 0.1–0.3 recommended for most workflows """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNODreamCond", display_name="STANNO Dream Conditioning", category="STANNO", inputs=[ io.Conditioning.Input("conditioning"), io.Custom.Input("STANNO", "stanno"), io.Float.Input( "noise_sigma", default=0.05, min=0.0, max=2.0, step=0.01, display_mode=io.NumberDisplay.slider, ), io.Float.Input( "blend_strength", default=0.20, min=0.0, max=1.0, step=0.01, display_mode=io.NumberDisplay.slider, ), io.Int.Input( "seed", default=42, min=0, max=2 ** 31, display_mode=io.NumberDisplay.number, ), io.Combo.Input( "feedback_projection", options=["repeat", "linear", "zeros"], ), ], outputs=[ io.Conditioning.Output(), # modified conditioning io.Conditioning.Output(), # original pass-through ], ) @classmethod def execute( cls, conditioning, stanno, noise_sigma, blend_strength, seed, feedback_projection ) -> io.NodeOutput: import copy rng = np.random.default_rng(seed) result = [] for cond_tensor, cond_meta in conditioning: # cond_tensor: (1, seq_len, embed_dim) e.g. (1, 77, 768) for SD 1.5 orig_np = cond_tensor.detach().cpu().numpy().astype(np.float32) b, seq, dim = orig_np.shape dream_tokens: list[np.ndarray] = [] for token_idx in range(seq): seed_vec = orig_np[0, token_idx, :].reshape(1, -1) # (1, dim) dream_out = stanno.dream( num_steps=1, input_seed=seed_vec, noise_sigma=noise_sigma, blind_inputs=False, rng=rng, ) dream_tokens.append(dream_out[0]) # (dim,) dream_cond = np.stack(dream_tokens, axis=0)[np.newaxis] # (1, seq, dim) blended = (1.0 - blend_strength) * orig_np + blend_strength * dream_cond blended_t = torch.from_numpy(blended).to( device=cond_tensor.device, dtype=cond_tensor.dtype ) result.append((blended_t, copy.deepcopy(cond_meta))) return io.NodeOutput(result, conditioning) # ─── Node 5: Dynamic LoRA (weight-space patching) ──────────────────────────── class STANNODynamicLoRA(io.ComfyNode): """ Inject STANNO dream output as LoRA-equivalent weight patches into a MODEL. STANNO generates `lora_rank` dream vectors. These are stacked into A (up) and B (down) projection matrices and applied to the SD 1.5 cross-attention layers via ComfyUI's native add_patches() mechanism. Requirements: - STANNO layers[0] == layers[-1] == 768 (SD 1.5 cross-attention dim) - Recommended: train STANNO on CLIP embeddings of your target style first Parameter guide: lora_rank 1–2 → subtle and stable; 4–8 → stronger but may cause drift alpha 0.3–0.5 is a good starting point for SD 1.5 noise_sigma 0.0 → deterministic style from STANNO weights 0.1–0.2 → creative variations per run """ # Cross-attention projections in SD 1.5 UNet (12 representative layers; # add more from the full model key list for stronger effect). _ATTN_KEYS = [ "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight", "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight", "diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight", "diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight", "diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", "diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight", "diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight", "diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight", "diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight", "diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight", "diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight", "diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight", ] @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNODynamicLoRA", display_name="STANNO Dynamic LoRA", category="STANNO", inputs=[ io.Model.Input("model"), io.Custom.Input("STANNO", "stanno"), io.Float.Input( "alpha", default=0.5, min=0.0, max=2.0, step=0.05, display_mode=io.NumberDisplay.slider, tooltip="LoRA scaling factor. Start at 0.3–0.5 for SD 1.5.", ), io.Int.Input( "lora_rank", default=2, min=1, max=16, step=1, display_mode=io.NumberDisplay.number, tooltip="Rank of the injected A×B matrices. Lower = more stable.", ), io.Float.Input( "noise_sigma", default=0.10, min=0.0, max=1.0, step=0.01, display_mode=io.NumberDisplay.slider, ), io.Int.Input( "seed", default=0, min=0, max=2 ** 31, display_mode=io.NumberDisplay.number, ), ], outputs=[ io.Model.Output(), io.String.Output("patch_info"), ], ) @classmethod def execute(cls, model, stanno, alpha, lora_rank, noise_sigma, seed) -> io.NodeOutput: rng = np.random.default_rng(seed) dim = stanno.config.layers[0] rank = min(lora_rank, dim) # Generate `rank` dream vectors — each is a (dim,) style direction basis: list[np.ndarray] = [] for _ in range(rank): seed_vec = rng.normal(0.0, 0.1, (1, dim)).astype(np.float32) dream_out = stanno.dream( num_steps=1, input_seed=seed_vec, noise_sigma=noise_sigma, blind_inputs=False, rng=rng, ) basis.append(dream_out[0]) # (dim,) # A: (rank, dim) B: (dim, rank) A = np.stack(basis, axis=0) norms = np.linalg.norm(A, axis=1, keepdims=True).clip(min=1e-8) A_norm = (A / norms).astype(np.float32) B = A_norm.T.astype(np.float32) A_t = torch.from_numpy(A_norm) B_t = torch.from_numpy(B) # ComfyUI LoRA patch format: {key: ("lora", (down, up))} patches = {key: ("lora", (B_t, A_t)) for key in cls._ATTN_KEYS} patched_model = model.clone() patched_model.add_patches(patches, alpha) info = ( f"Patched {len(patches)} attention layers | " f"rank={rank} | alpha={alpha:.2f} | noise={noise_sigma:.3f} | seed={seed}" ) print(f"[STANNO DynamicLoRA] {info}") return io.NodeOutput(patched_model, info) # ─── Node 6: Composite Style Checker ───────────────────────────────────────── class STANNOCompositeCheck(io.ComfyNode): """ Score a batch of images against two STANNOs and route by the closer match. In a composite / inpainting workflow different image zones should each match a particular trained style. This node splits a batch into two sub-batches based on which STANNO has lower reconstruction error, and reports the margin so you can identify ambiguous images. Typical use: connect two STANNOs trained on 'background style' and 'foreground style'; route each generated image to the right inpaint layer. """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNOCompositeCheck", display_name="STANNO Composite Style Checker", category="STANNO", inputs=[ io.Image.Input("images"), io.Custom.Input("STANNO", "stanno_a"), io.Custom.Input("STANNO", "stanno_b"), io.String.Input("label_a", default="Style A", multiline=False), io.String.Input("label_b", default="Style B", multiline=False), ], outputs=[ io.Image.Output(), # images closest to Style A io.Image.Output(), # images closest to Style B io.String.Output("report_json"), ], ) @classmethod def execute(cls, images, stanno_a, stanno_b, label_a, label_b) -> io.NodeOutput: dim_a = stanno_a.config.layers[0] dim_b = stanno_b.config.layers[0] xa = _flatten_images(images, dim_a).astype(np.float32) * 2.0 - 1.0 xb = _flatten_images(images, dim_b).astype(np.float32) * 2.0 - 1.0 scores_a = np.mean((stanno_a.predict(xa) - xa) ** 2, axis=1) scores_b = np.mean((stanno_b.predict(xb) - xb) ** 2, axis=1) idx_a = [i for i in range(len(scores_a)) if scores_a[i] <= scores_b[i]] idx_b = [i for i in range(len(scores_a)) if scores_a[i] > scores_b[i]] imgs_a = ( images[torch.tensor(idx_a, device=images.device)] if idx_a else images[:1] ) imgs_b = ( images[torch.tensor(idx_b, device=images.device)] if idx_b else images[:1] ) report = [ { "index": i, label_a: round(float(scores_a[i]), 5), label_b: round(float(scores_b[i]), 5), "winner": label_a if scores_a[i] <= scores_b[i] else label_b, "margin": round(abs(float(scores_a[i]) - float(scores_b[i])), 5), } for i in range(len(scores_a)) ] return io.NodeOutput(imgs_a, imgs_b, json.dumps(report, indent=2)) # ─── Node 7: DSANNO Scan ────────────────────────────────────────────────────── class STANNOScan(io.ComfyNode): """ DSANNO — Data Scanning Artificial Neural Network Object. Scans a batch of images and finds the ones that best match what the STANNO has learned, implementing the patent's DSANNO concept: "scan large regions of the data space looking for patterns that match the learned representation." Two modes ───────── auto_calibrate = ON (recommended) The threshold is computed automatically from this very batch at the given percentile. E.g. percentile=20 keeps the best-matching 20 %. auto_calibrate = OFF Use the manually supplied ``threshold`` value directly. Outputs ─────── top_k_images — the k images with lowest reconstruction error matched_images — all images below the threshold scores_json — per-image scores and match flags threshold — the threshold that was applied (useful for display/routing) """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNOScan", display_name="STANNO Scan (DSANNO)", category="STANNO", inputs=[ io.Image.Input("images"), io.Custom.Input("STANNO", "stanno"), io.Int.Input( "top_k", default=4, min=1, max=64, step=1, display_mode=io.NumberDisplay.number, tooltip="Return this many best-matching images regardless of threshold.", ), io.Combo.Input( "auto_calibrate", options=["on", "off"], tooltip=( "on: compute threshold automatically from this batch at the " "given percentile.\n" "off: use the manual threshold value." ), ), io.Float.Input( "percentile", default=30.0, min=1.0, max=99.0, step=1.0, display_mode=io.NumberDisplay.slider, tooltip=( "Used when auto_calibrate=on. " "30 = keep the best-matching 30 % of the batch." ), ), io.Float.Input( "threshold", default=0.10, min=0.0, max=5.0, step=0.005, display_mode=io.NumberDisplay.number, tooltip="Manual threshold (used only when auto_calibrate=off).", ), ], outputs=[ io.Image.Output(), # top_k_images io.Image.Output(), # matched_images io.String.Output("scores_json"), io.Float.Output("threshold"), ], ) @classmethod def execute(cls, images, stanno, top_k, auto_calibrate, percentile, threshold) -> io.NodeOutput: from stanno.integration.dsanno import DSANNO input_dim = stanno.config.layers[0] x = _flatten_images(images, input_dim).astype(np.float32) * 2.0 - 1.0 scanner = DSANNO(stanno, mode="reconstruction") result = scanner.scan(x) # Determine threshold if auto_calibrate == "on": used_threshold = float(np.percentile(result.scores, percentile)) else: used_threshold = float(threshold) result.set_threshold(used_threshold) # top_k images k = min(int(top_k), len(images)) top_indices, top_scores, _ = scanner.top_k(x, k=k) top_images = images[torch.tensor(top_indices.tolist(), device=images.device)] # matched images (below threshold) matched_idx = result.matched_indices().tolist() matched_images = ( images[torch.tensor(matched_idx, device=images.device)] if matched_idx else images[:1] ) scores_data = [ { "index": int(i), "mse": round(float(result.scores[i]), 5), "matched": bool(result.matched_mask[i]), "rank": int(np.where(np.argsort(result.scores) == i)[0][0]) + 1, } for i in range(len(result.scores)) ] print( f"[STANNO Scan] threshold={used_threshold:.4f} | " f"matched={len(matched_idx)}/{len(images)} | top_k={k}" ) return io.NodeOutput( top_images, matched_images, json.dumps(scores_data, indent=2), used_threshold, ) # ─── Node 8: Cascade Load / Create ─────────────────────────────────────────── class STANNOCascadeLoad(io.ComfyNode): """ Load or create a CascadeSTANNO — a chain of STANNO stages. Implements the patent's "cascading networks to form system models": the output of stage k feeds the input of stage k+1, and each stage can be independently frozen or adapted. Typical uses ──────────── Encoder + Decoder autoencoder: stages_json = [{"layers": [3072, 512]}, {"layers": [512, 3072]}] Progressive compression pipeline: stages_json = [{"layers":[768,256]}, {"layers":[256,64]}, {"layers":[64,256]}, {"layers":[256,768]}] Frozen pre-processor + adaptive head: stages_json = [{"layers":[768,256]}, {"layers":[256,10]}] frozen_json = [true, false] stages_json format ────────────────── JSON array of objects. Keys: "layers" required — e.g. [768, 256, 768] "trainer_type" optional — "fixed"|"local_rule"|"evolutionary" (default "fixed") "learning_rate" optional — per-stage lr (default: uses the top-level lr) """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNOCascadeLoad", display_name="STANNO Cascade Loader", category="STANNO", inputs=[ io.String.Input( "model_path", default="cascade_model.pkl", multiline=False, tooltip="Path to a saved CascadeSTANNO .pkl, or a new filename to create.", ), io.String.Input( "stages_json", default='[{"layers": [3072, 512]}, {"layers": [512, 3072]}]', multiline=True, tooltip=( "JSON array of stage configs. Each stage needs at minimum " '{"layers": [in, ..., out]}. Used only when creating a new cascade.' ), ), io.String.Input( "frozen_json", default="[]", multiline=False, tooltip=( "JSON bool array of frozen flags per stage. " "[] = all trainable. Example: [true, false] = freeze stage 0." ), ), io.Combo.Input( "trainer_type", options=["fixed", "local_rule", "evolutionary"], tooltip="Default trainer type applied to stages that don't override it.", ), io.Float.Input( "learning_rate", default=0.01, min=1e-5, max=1.0, step=0.001, display_mode=io.NumberDisplay.number, tooltip="Default learning rate applied to stages that don't override it.", ), ], outputs=[ io.Custom.Output("CASCADE"), io.String.Output("info"), ], ) @classmethod def execute( cls, model_path, stages_json, frozen_json, trainer_type, learning_rate ) -> io.NodeOutput: import os from stanno.config.schema import STANNOConfig from stanno.core.stanno import STANNO from stanno.integration.cascade import CascadeSTANNO if os.path.isfile(model_path): cascade = CascadeSTANNO.load(model_path) info = ( f"Loaded: {model_path} | " f"{len(cascade.stages)} stages | " f"frozen={cascade.frozen}" ) else: stage_defs = json.loads(stages_json) frozen = json.loads(frozen_json) if frozen_json.strip() else [] if not frozen: frozen = [False] * len(stage_defs) stages = [] for sd in stage_defs: lr = sd.get("learning_rate", learning_rate) tt = sd.get("trainer_type", trainer_type) scfg = STANNOConfig( layers=sd["layers"], trainer_type=tt, learning_rate=lr, ) stages.append(STANNO(scfg)) cascade = CascadeSTANNO(stages, frozen=frozen) topology = " → ".join( "×".join(str(d) for d in s.config.layers) for s in stages ) info = f"Created CascadeSTANNO: {topology} | frozen={frozen}" print(f"[STANNO Cascade Loader] {info}") return io.NodeOutput(cascade, info) # ─── Node 9: Cascade Train on Images ───────────────────────────────────────── class STANNOCascadeTrainImages(io.ComfyNode): """ Train a CascadeSTANNO end-to-end on a batch of images. Implements the patent's "self-training within cascaded systems": gradient flows from the final stage back through every non-frozen stage via the cascade mechanism in FixedTrainerNet. Autoencoder use-case (most common) ─────────────────────────────────── Set up a CascadeSTANNO with an encoder stage and a decoder stage. This node trains the whole chain with input == target so the bottleneck is forced to compress image content. Partial training (frozen stages) ───────────────────────────────── Freeze the encoder in STANNOCascadeLoad, then connect here. Only the unfrozen decoder receives weight updates — useful for domain adaptation. """ @classmethod def define_schema(cls) -> io.Schema: return io.Schema( node_id="STANNOCascadeTrainImages", display_name="STANNO Cascade Train from Images", category="STANNO", inputs=[ io.Image.Input("images"), io.Custom.Input("CASCADE", "cascade"), io.Int.Input( "epochs", default=100, min=1, max=5000, step=10, display_mode=io.NumberDisplay.number, ), io.Int.Input( "batch_size", default=16, min=1, max=256, step=8, display_mode=io.NumberDisplay.number, ), io.Int.Input( "patience", default=30, min=0, max=500, step=5, display_mode=io.NumberDisplay.number, tooltip="Early stopping patience in epochs. 0 = disabled.", ), io.String.Input( "save_path", default="", multiline=False, tooltip="Optional path to save the trained cascade as .pkl.", ), ], outputs=[ io.Custom.Output("CASCADE"), io.String.Output("training_log"), ], ) @classmethod def execute(cls, images, cascade, epochs, batch_size, patience, save_path) -> io.NodeOutput: import copy cascade_copy = copy.deepcopy(cascade) # Use the first stage's input_dim for flattening input_dim = cascade_copy.stages[0].config.layers[0] output_dim = cascade_copy.stages[-1].config.layers[-1] x = _flatten_images(images, input_dim).astype(np.float32) * 2.0 - 1.0 # Autoencoder: target is the image itself (needs matching output dim) if output_dim == input_dim: y = x else: # If dims differ, pad/trim y to match output dim if x.shape[1] >= output_dim: y = x[:, :output_dim] else: pad = np.zeros((x.shape[0], output_dim - x.shape[1]), dtype=np.float32) y = np.hstack([x, pad]) log_lines: list[str] = [] report_every = max(1, epochs // 5) def log_cb(epoch: int, loss: float) -> None: if (epoch + 1) % report_every == 0: line = f"epoch {epoch + 1:5d} loss={loss:.5f}" log_lines.append(line) print(f"[STANNO Cascade Train] {line}") cascade_copy.fit( x, y, epochs=epochs, batch_size=batch_size, patience=patience, log_every=0, # use callback instead callback=log_cb, ) save = save_path.strip() if save: os.makedirs(os.path.dirname(os.path.abspath(save)), exist_ok=True) cascade_copy.save(save) log_lines.append(f"Saved → {save}") print(f"[STANNO Cascade Train] Saved → {save}") return io.NodeOutput(cascade_copy, "\n".join(log_lines)) # ─── Extension registration ─────────────────────────────────────────────────── class STANNOExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ STANNOLoad, STANNOTrainImages, STANNOScoreImages, STANNODreamCond, STANNODynamicLoRA, STANNOCompositeCheck, STANNOScan, STANNOCascadeLoad, STANNOCascadeTrainImages, ] async def comfy_entrypoint() -> STANNOExtension: return STANNOExtension()