File size: 13,402 Bytes
23b413b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
"""
Edit flags parser for advanced edit controls.

This module allows the frontend to pass edit parameters via hidden flags
appended to the edit instruction, without modifying the /chat API schema.

Example:
    User input: "Make the sky dramatic --steps 30 --cfg 6.5 --denoise 0.55"

    Parsed result:
        - clean_text: "Make the sky dramatic"
        - flags: EditFlags(steps=30, cfg=6.5, denoise=0.55, ...)
"""

import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Dict, Any

# Pattern to match flags like --key value.
# The value group accepts either a negative number ("-1", "-2.5") or any token
# not starting with "-"; the lazy tail plus the (?=\s+--|$) lookahead lets a
# value contain spaces but stops it from swallowing the next --flag.
# The negative-number alternative is needed so "--seed -1" actually parses
# (build_edit_workflow_vars treats seed == -1 as "randomize"); previously the
# flag went unmatched and "--seed -1" leaked into the cleaned prompt.
FLAG_RE = re.compile(r"(--[a-zA-Z0-9_-]+)\s+(-?\d[^\n\r]*?|[^\s-][^\n\r]*?)(?=\s+--|$)")

# Pattern to match URLs (http/https); scheme matched case-insensitively
URL_RE = re.compile(r"https?://\S+", re.I)

# Pattern to match edit command keywords at the start
EDIT_CMD_RE = re.compile(r"^\s*(edit|inpaint|modify)\s+", re.I)


@dataclass
class EditFlags:
    """
    Parsed edit-flag values with quality-oriented defaults.

    Instances are produced by parse_edit_flags() and later converted into
    ComfyUI workflow variables; every field has a sensible default so a
    prompt with no flags at all still yields a usable configuration.
    """
    # Edit mode: "auto", "global", or "inpaint"
    mode: str = "auto"

    # Sampling / generation parameters
    steps: int = 30
    cfg: float = 5.5
    denoise: float = 0.55
    seed: int = 0
    sampler_name: str = "euler"
    scheduler: str = "normal"

    # Model selection and ControlNet settings
    ckpt_name: Optional[str] = None
    controlnet_name: Optional[str] = None
    controlnet_strength: float = 1.0
    cn_enabled: bool = False

    # Optional mask (URL or filename; backend will preprocess)
    mask_url: Optional[str] = None

    # LoRA adapters: list of {"id": str, "weight": float}
    loras: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the flags as a plain dict of workflow variables."""
        # NOTE: `loras` is shared by reference, not copied, matching how the
        # rest of the module passes the list around.
        return dict(
            mode=self.mode,
            steps=self.steps,
            cfg=self.cfg,
            denoise=self.denoise,
            seed=self.seed,
            sampler_name=self.sampler_name,
            scheduler=self.scheduler,
            ckpt_name=self.ckpt_name,
            controlnet_name=self.controlnet_name,
            controlnet_strength=self.controlnet_strength,
            cn_enabled=self.cn_enabled,
            mask_url=self.mask_url,
            loras=self.loras,
        )


def parse_edit_flags(text: str) -> Tuple[str, EditFlags]:
    """
    Extract flags from user text and return (clean_text, flags).

    Cleans the prompt by removing:
    - Edit command prefix (edit, inpaint, modify)
    - URLs (image URLs that are handled separately)
    - Flags like --steps 30, --cfg 6.5

    Malformed flag values (e.g. "--steps abc") are silently ignored and the
    corresponding default is kept.

    Args:
        text: User's edit instruction potentially containing flags and URLs

    Returns:
        Tuple of (cleaned text without flags/URLs, parsed EditFlags object)
    """
    flags = EditFlags()

    def _add_loras(raw: str) -> None:
        # Tokens are "lora_id:weight" or bare "lora_id" (default weight 0.8).
        for token in raw.split():
            name, sep, weight_str = token.rpartition(":")
            if sep:
                try:
                    weight = float(weight_str)
                except ValueError:
                    weight = 0.8
            else:
                name = token
                weight = 0.8
            name = name.strip()
            if name:
                flags.loras.append({"id": name, "weight": weight, "enabled": True})

    # Dispatch table: flag key -> setter taking the stripped value string.
    setters = {
        "--mode": lambda v: setattr(flags, "mode", v.lower()),
        "--steps": lambda v: setattr(flags, "steps", int(float(v))),
        "--cfg": lambda v: setattr(flags, "cfg", float(v)),
        "--denoise": lambda v: setattr(flags, "denoise", float(v)),
        "--seed": lambda v: setattr(flags, "seed", int(float(v))),
        "--sampler": lambda v: setattr(flags, "sampler_name", v),
        "--scheduler": lambda v: setattr(flags, "scheduler", v),
        "--ckpt": lambda v: setattr(flags, "ckpt_name", v),
        "--cn": lambda v: setattr(flags, "cn_enabled", v.lower() in ("1", "true", "yes", "on")),
        "--cn-strength": lambda v: setattr(flags, "controlnet_strength", float(v)),
        "--controlnet": lambda v: setattr(flags, "controlnet_name", v),
        "--mask": lambda v: setattr(flags, "mask_url", v),
        "--lora": _add_loras,
    }

    found = list(FLAG_RE.finditer(text))
    for match in found:
        key = (match.group(1) or "").strip().lower()
        raw_value = (match.group(2) or "").strip()
        handler = setters.get(key)
        if handler is None:
            continue  # unknown flag: ignore, but its span is still stripped below
        try:
            handler(raw_value)
        except (ValueError, TypeError):
            pass  # invalid value: keep the default

    # Strip the matched flag spans right-to-left so earlier offsets stay valid.
    clean = text
    for match in reversed(found):
        clean = clean[:match.start()] + clean[match.end():]

    # Drop image URLs (the image is passed separately) and any leading
    # command word such as "edit " or "inpaint ".
    clean = URL_RE.sub("", clean)
    clean = EDIT_CMD_RE.sub("", clean)

    # Collapse runs of whitespace left behind by the removals.
    clean = " ".join(clean.split()).strip()

    return clean, flags


def infer_edit_mode(user_text: str) -> str:
    """
    Guess the edit mode from the instruction text.

    Vocabulary naming concrete objects or localized operations maps to
    "inpaint"; style/mood vocabulary maps to "global"; anything else falls
    through to "auto" so the workflow can decide. Localized keywords win
    when both kinds appear.

    Args:
        user_text: Cleaned user edit instruction

    Returns:
        Inferred edit mode: "inpaint", "global", or "auto"
    """
    lowered = user_text.lower()

    # Localized edit keywords suggest inpainting
    localized_words = (
        "remove", "erase", "delete", "replace", "change", "add", "swap",
        "object", "logo", "text", "person", "background", "hat", "shirt",
        "hair", "face", "eye", "nose", "mouth", "hand", "arm", "leg",
        "building", "car", "tree", "sky", "water", "cloud"
    )

    # Global style keywords suggest full image regeneration
    style_words = (
        "cinematic", "anime", "oil painting", "watercolor", "cartoon",
        "night", "sunset", "dramatic lighting", "color grade", "style",
        "filter", "tone", "mood", "atmosphere", "aesthetic", "artistic",
        "vintage", "retro", "modern", "futuristic", "cyberpunk", "noir"
    )

    # Localized keywords take precedence over style keywords.
    if any(word in lowered for word in localized_words):
        return "inpaint"
    if any(word in lowered for word in style_words):
        return "global"
    return "auto"


def build_edit_workflow_vars(
    image_url: str,
    prompt: str,
    flags: EditFlags,
    negative_prompt: str = "",
    img_model: str = "",
) -> Dict[str, Any]:
    """
    Build workflow variables dictionary from parsed flags.

    Side effects:
        - May append auto-detected LoRAs to ``flags.loras`` (and cap the list
          at 4 entries) when no explicit --lora flags were given.
        - May prepend LoRA trigger words to the returned prompt.

    Args:
        image_url: URL of the source image
        prompt: Cleaned edit instruction
        flags: Parsed edit flags
        negative_prompt: Negative prompt (optional)
        img_model: User-selected image model from frontend (e.g. "dreamshaper_8.safetensors")

    Returns:
        Dictionary of variables for ComfyUI workflow. ``mask_path`` is only
        present when a mask was actually supplied; ``_loras`` is only present
        when at least one LoRA is active.
    """
    import random

    # Default checkpoints based on workflow type
    # If the user selected a specific model in the frontend, use that as
    # the global default — this ensures SD1.5 LoRAs are not silently
    # skipped when the user has a SD1.5 checkpoint selected.
    ckpt_default_global = img_model or "sd_xl_base_1.0.safetensors"
    ckpt_default_inpaint = "sd_xl_base_1.0_inpainting_0.1.safetensors"
    ckpt_default_sd15_inpaint = "sd-v1-5-inpainting.ckpt"
    cn_default = "control_v11p_sd15_inpaint.safetensors"

    # Generate random seed if not explicitly set (prevents ComfyUI caching).
    # Both 0 and -1 are treated as "randomize" sentinels.
    seed = flags.seed
    if seed in (0, -1):
        seed = random.randint(1, 2147483647)

    vars_dict: Dict[str, Any] = {
        "image_path": image_url,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "steps": flags.steps,
        "cfg": flags.cfg,
        "seed": seed,
        "denoise": flags.denoise,
        "sampler_name": flags.sampler_name,
        "scheduler": flags.scheduler,
        "filename_prefix": "homepilot_edit",
    }

    # Set checkpoint and mask based on whether mask is provided
    # IMPORTANT: Only set mask_path when there's actually a mask
    if flags.mask_url:
        # Inpaint mode - mask is provided
        vars_dict["mask_path"] = flags.mask_url
        vars_dict["ckpt_name"] = flags.ckpt_name or ckpt_default_inpaint

        # Optional ControlNet (recommended with SD1.5 inpaint).
        # Note: enabling ControlNet also swaps the default checkpoint to the
        # SD1.5 inpaint model, since the default ControlNet is SD1.5-only.
        if flags.cn_enabled:
            vars_dict["ckpt_name"] = flags.ckpt_name or ckpt_default_sd15_inpaint
            vars_dict["controlnet_name"] = flags.controlnet_name or cn_default
            vars_dict["controlnet_strength"] = flags.controlnet_strength
    else:
        # Global edit (img2img) - no mask needed
        vars_dict["ckpt_name"] = flags.ckpt_name or ckpt_default_global

    # Pass LoRA list (will be consumed by comfy.py to inject LoraLoader nodes)
    # Auto-detect: if no --lora flags given, scan installed LoRAs for trigger word matches
    # Only selects LoRAs compatible with the current checkpoint architecture.
    if not flags.loras:
        try:
            from .models.lora_loader import scan_installed_loras, get_lora_dir, is_lora_compatible
            from .models.lora_registry import get_lora_by_id
            from .model_config import get_architecture

            lora_dir = get_lora_dir()
            if lora_dir.exists():
                prompt_lower = vars_dict["prompt"].lower()
                ckpt_name = vars_dict.get("ckpt_name", "")
                ckpt_arch = get_architecture(ckpt_name) if ckpt_name else ""
                installed = scan_installed_loras()
                # Track which trigger words already matched (pick best arch variant)
                matched_triggers: set = set()
                for lora_info in installed:
                    if not lora_info.get("healthy", True):
                        continue
                    lid = lora_info.get("id", "")
                    entry = get_lora_by_id(lid)
                    if not entry or not entry.trigger_words:
                        continue
                    # Check if any trigger word appears in the prompt
                    # (short triggers < 3 chars are skipped — too many false hits)
                    matched = any(tw.lower() in prompt_lower for tw in entry.trigger_words if len(tw) >= 3)
                    if not matched:
                        continue
                    # Skip if incompatible with current checkpoint
                    if ckpt_arch and entry.base:
                        if not is_lora_compatible(entry.base, ckpt_arch):
                            print(f"[LORA] Skipping auto-detect '{lid}' (base={entry.base}) — incompatible with checkpoint '{ckpt_arch}'")
                            continue
                    # Avoid duplicates: if another LoRA with the same trigger words
                    # was already added, skip (e.g. undressing_sd15 vs undressing_sdxl)
                    tw_key = frozenset(tw.lower() for tw in entry.trigger_words)
                    if tw_key in matched_triggers:
                        continue
                    matched_triggers.add(tw_key)
                    flags.loras.append({"id": lid, "weight": 0.8, "enabled": True})
                    print(f"[LORA] Auto-detected '{lid}' (base={entry.base}) from trigger words in prompt")

                # Cap auto-detected LoRAs at 4
                flags.loras = flags.loras[:4]
        except Exception as e:
            # Auto-detect is best-effort: missing modules (ImportError is a
            # subclass of Exception, so listing it separately was redundant)
            # or any scan failure just means no LoRAs are auto-attached.
            print(f"[LORA] Auto-detect skipped: {e}")

    if flags.loras:
        vars_dict["_loras"] = flags.loras

        # Inject trigger words into prompt (industry best practice for LoRA activation)
        try:
            from .models.lora_registry import get_lora_by_id
            trigger_parts: list[str] = []
            for lr in flags.loras:
                entry = get_lora_by_id(lr["id"])
                if entry and entry.trigger_words:
                    for tw in entry.trigger_words:
                        # Only add triggers not already present in the prompt
                        if tw.lower() not in vars_dict["prompt"].lower():
                            trigger_parts.append(tw)
            if trigger_parts:
                vars_dict["prompt"] = ", ".join(trigger_parts) + ", " + vars_dict["prompt"]
        except ImportError:
            # Registry unavailable: keep the prompt untouched.
            pass

    return vars_dict


def determine_workflow(flags: EditFlags, prompt: str) -> str:
    """
    Determine which workflow to use based on flags and prompt.

    IMPORTANT: Inpaint workflows REQUIRE a mask (the InpaintModelConditioning
    node needs a noise_mask input). Without a mask we MUST fall back to the
    standard img2img edit workflow, even when keywords hint at inpainting —
    keyword-based mode detection is informational only.

    Args:
        flags: Parsed edit flags
        prompt: Cleaned edit instruction (currently unused; kept for
            interface stability)

    Returns:
        Workflow name: "edit_inpaint_cn", "edit_inpaint", or "edit"
    """
    # Guard clause: no mask means img2img, regardless of detected keywords.
    if not flags.mask_url:
        return "edit"

    # A mask is present: pick the ControlNet variant when requested.
    return "edit_inpaint_cn" if flags.cn_enabled else "edit_inpaint"