File size: 5,923 Bytes
ed37502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Builds ComfyUI API-format workflow JSON from templates and parameters.

The workflow builder loads base workflow templates (JSON files representing
ComfyUI node graphs) and injects generation-specific values: checkpoint,
LoRAs, prompts, seeds, dimensions, and output filenames.
"""

from __future__ import annotations

import copy
import json
import logging
import os
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# Treat the process as running on Hugging Face Spaces when either an explicit
# HF_SPACES=1 flag is set or a SPACE_ID environment variable is present.
IS_HF_SPACES = os.environ.get("HF_SPACES") == "1" or os.environ.get("SPACE_ID") is not None

# Built-in default for the workflow template directory: the container path on
# HF Spaces, otherwise the original local development checkout.
_DEFAULT_WORKFLOWS_DIR = (
    Path("/app/config/templates/workflows")
    if IS_HF_SPACES
    else Path("D:/AI automation/content_engine/config/templates/workflows")
)

# A non-empty CONTENT_ENGINE_WORKFLOWS_DIR overrides the default so the engine
# is not tied to one machine's absolute path; unset/empty keeps the default.
WORKFLOWS_DIR = Path(os.environ.get("CONTENT_ENGINE_WORKFLOWS_DIR") or _DEFAULT_WORKFLOWS_DIR)


class WorkflowBuilder:
    """Constructs ComfyUI workflows from base templates + per-job parameters.

    Base templates are ``<name>.json`` files inside ``workflows_dir``. Each is
    parsed once and cached in memory; ``build`` deep-copies the cached template
    before injecting values, so successive builds never see each other's edits.
    """

    def __init__(self, workflows_dir: Path | str | None = None):
        """Create a builder that reads templates from *workflows_dir*.

        Args:
            workflows_dir: Directory holding ``<name>.json`` templates, given as
                a ``Path`` or a string. When omitted (or falsy) the module-level
                ``WORKFLOWS_DIR`` is used.
        """
        # Coerce strings to Path so the ``/`` join in _load_template works.
        self.workflows_dir = Path(workflows_dir) if workflows_dir else WORKFLOWS_DIR
        # Parsed templates keyed by template name; shared across builds.
        self._cache: dict[str, dict] = {}

    def _load_template(self, name: str) -> dict:
        """Load and cache a base workflow JSON template.

        Args:
            name: Template name; the file ``<workflows_dir>/<name>.json`` is read.

        Returns:
            The cached parsed template. This is the shared cache entry, so
            callers must copy it before modifying it (``build`` deep-copies).

        Raises:
            FileNotFoundError: If the template file does not exist (or the
                path is not a regular file).
        """
        if name not in self._cache:
            path = self.workflows_dir / f"{name}.json"
            if not path.is_file():
                raise FileNotFoundError(f"Workflow template not found: {path}")
            # JSON text is UTF-8; without an explicit encoding open() uses the
            # locale codec (e.g. cp1252 on Windows) and can mis-decode or reject
            # non-ASCII template text. "utf-8-sig" also tolerates a leading BOM,
            # which json.load would otherwise refuse.
            with open(path, encoding="utf-8-sig") as f:
                self._cache[name] = json.load(f)
        return self._cache[name]

    def build(
        self,
        template_name: str = "sd15_base_sfw",
        *,
        checkpoint: str = "realisticVisionV51_v51VAE.safetensors",
        positive_prompt: str = "",
        negative_prompt: str = "",
        loras: list[dict[str, Any]] | None = None,
        seed: int = -1,
        steps: int = 28,
        cfg: float = 7.0,
        sampler_name: str = "dpmpp_2m",
        scheduler: str = "karras",
        width: int = 832,
        height: int = 1216,
        batch_size: int = 1,
        filename_prefix: str = "content_engine",
        denoise: float | None = None,
        reference_image: str | None = None,
    ) -> dict:
        """Build a complete workflow dict ready for ComfyUI /prompt endpoint.

        The base template must have these node IDs (by convention):
        - "1": CheckpointLoaderSimple
        - "2": CLIPTextEncode (positive)
        - "3": CLIPTextEncode (negative)
        - "4": EmptyLatentImage (txt2img) or absent for img2img
        - "5": KSampler
        - "6": VAEDecode
        - "7": SaveImage
        - "8": LoadImage (img2img only)
        - "9": VAEEncode (img2img only)
        - "10", "11", ...: LoraLoader chain (optional, added dynamically)

        Any of nodes "1"-"5", "7" and "8" that the template lacks is skipped;
        nodes "6" and "9" are left exactly as in the template.

        Args:
            template_name: Template file name without the ``.json`` suffix.
            checkpoint: Written to node "1" ``ckpt_name``.
            positive_prompt: Written to node "2" ``text``.
            negative_prompt: Written to node "3" ``text``.
            loras: Optional LoRA specs; each dict needs ``"name"`` and may set
                ``"strength_model"`` / ``"strength_clip"`` (default 0.85 each).
            seed: KSampler seed; any negative value draws a random seed.
            steps: KSampler step count.
            cfg: KSampler CFG scale.
            sampler_name: KSampler sampler name.
            scheduler: KSampler scheduler name.
            width: Latent width for node "4".
            height: Latent height for node "4".
            batch_size: Latent batch size for node "4".
            filename_prefix: SaveImage (node "7") filename prefix.
            denoise: KSampler denoise; the template value is kept when ``None``.
            reference_image: Image name for LoadImage (node "8"); ignored when
                empty/``None``.

        Returns:
            A new workflow dict; the cached template is not modified.

        Raises:
            FileNotFoundError: If the template file does not exist.
            KeyError: If a LoRA spec has no ``"name"`` key.
        """
        # Work on a private copy so the cached template stays pristine.
        base = copy.deepcopy(self._load_template(template_name))

        # Checkpoint
        if "1" in base:
            base["1"]["inputs"]["ckpt_name"] = checkpoint

        # Prompts
        if "2" in base:
            base["2"]["inputs"]["text"] = positive_prompt
        if "3" in base:
            base["3"]["inputs"]["text"] = negative_prompt

        # Latent image dimensions (txt2img only)
        if "4" in base:
            latent = base["4"]["inputs"]
            latent["width"] = width
            latent["height"] = height
            latent["batch_size"] = batch_size

        # KSampler
        if "5" in base:
            sampler = base["5"]["inputs"]
            sampler["seed"] = seed if seed >= 0 else _random_seed()
            sampler["steps"] = steps
            sampler["cfg"] = cfg
            sampler["sampler_name"] = sampler_name
            sampler["scheduler"] = scheduler
            if denoise is not None:
                sampler["denoise"] = denoise

        # Reference image for img2img (LoadImage node)
        if "8" in base and reference_image:
            base["8"]["inputs"]["image"] = reference_image

        # SaveImage filename prefix
        if "7" in base:
            base["7"]["inputs"]["filename_prefix"] = filename_prefix

        # Inject LoRA chain
        if loras:
            base = self._inject_loras(base, loras)

        return base

    def _inject_loras(
        self, workflow: dict, loras: list[dict[str, Any]]
    ) -> dict:
        """Dynamically insert LoraLoader nodes into the workflow graph.

        Each LoRA gets a node ID starting at "10". The chain connects:
        checkpoint -> lora_10 -> lora_11 -> ... -> KSampler/CLIP nodes.

        NOTE: any node already stored under one of those IDs is overwritten,
        so templates should leave "10" onward free for the LoRA chain.

        Args:
            workflow: Workflow dict to modify in place.
            loras: LoRA specs; each needs ``"name"`` and may set
                ``"strength_model"`` / ``"strength_clip"`` (default 0.85).

        Returns:
            The same *workflow* object, mutated in place.

        Raises:
            KeyError: If a LoRA spec has no ``"name"`` key.
        """
        if not loras:
            return workflow

        # Determine where model and clip currently flow from
        # By default, KSampler (node 5) takes model from checkpoint (node 1, slot 0)
        # and CLIP encoders (nodes 2,3) take clip from checkpoint (node 1, slot 1)
        prev_model_ref = ["1", 0]  # checkpoint model output
        prev_clip_ref = ["1", 1]  # checkpoint clip output

        for i, lora_spec in enumerate(loras):
            node_id = str(10 + i)
            workflow[node_id] = {
                "class_type": "LoraLoader",
                "inputs": {
                    "lora_name": lora_spec["name"],
                    "strength_model": lora_spec.get("strength_model", 0.85),
                    "strength_clip": lora_spec.get("strength_clip", 0.85),
                    "model": prev_model_ref,
                    "clip": prev_clip_ref,
                },
            }
            # Fresh lists each step, so every LoRA node owns its input refs.
            prev_model_ref = [node_id, 0]
            prev_clip_ref = [node_id, 1]

        # Rewire KSampler and CLIP encoders to the last LoRA. Each consumer gets
        # its own list so no two nodes alias one mutable reference (an edit to
        # one node's input would otherwise silently change the other's).
        if "5" in workflow:
            workflow["5"]["inputs"]["model"] = list(prev_model_ref)
        for encoder_id in ("2", "3"):
            if encoder_id in workflow:
                workflow[encoder_id]["inputs"]["clip"] = list(prev_clip_ref)

        return workflow


def _random_seed() -> int:
    """Return a random seed drawn uniformly from 0 through 2**32 - 1 inclusive.

    ``WorkflowBuilder.build`` calls this when it is given a negative seed.
    """
    from random import randint

    upper_bound = (1 << 32) - 1  # largest unsigned 32-bit value (inclusive)
    return randint(0, upper_bound)