File size: 12,026 Bytes
cad34e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244c0fa
cad34e4
 
 
 
 
 
 
 
 
 
 
 
 
 
244c0fa
 
 
 
 
 
 
 
cad34e4
244c0fa
cad34e4
 
244c0fa
cad34e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""

NAIA-WEB Prompt Processor

Pipeline-based prompt processing with hooks



Reference: NAIA2.0/core/prompt_processor.py, NAIA2.0/modules/prompt_engineering_module.py

"""

import re
from dataclasses import dataclass, field
from typing import List, Set, Tuple

from utils.constants import QUALITY_TAGS_POSITIVE, QUALITY_TAGS_NEGATIVE


@dataclass
class PromptContext:
    """
    Mutable state carried through the prompt processing pipeline.

    ``PromptProcessor.process`` and each of its steps assign to the fields of
    this object in place and return the same instance, so the caller's object
    holds the processed prompts afterwards.
    """
    # Main positive prompt as a comma-separated tag string; replaced by the
    # fully assembled and cleaned prompt after processing.
    positive_prompt: str
    # Negative prompt as a comma-separated tag string; may receive the
    # negative quality tags and is cleaned up by the pipeline.
    negative_prompt: str

    # Processing flags
    # When True, the quality-tag injection step runs.
    use_quality_tags: bool = True

    # Pre/Post prompt additions
    # Inserted (stripped) after the extracted person tags and before the main
    # prompt tags.
    pre_prompt: str = ""
    # Appended (stripped) after the main prompt tags. If it contains the word
    # "quality" (case-insensitive), positive quality tags are not appended.
    post_prompt: str = ""

    # Auto hide tags (tags to remove) - supports patterns.
    # Entries starting with "~" are protected keywords instead of removal
    # patterns; see PromptProcessor._remove_auto_hide_tags for the syntax.
    auto_hide_tags: Set[str] = field(default_factory=set)

    # Removed tags tracking: overwritten by the auto-hide step with the tags it
    # removed from positive_prompt (left untouched when auto_hide_tags is empty).
    removed_tags: List[str] = field(default_factory=list)

    # Processing log for debugging: each pipeline step appends a human-readable
    # message describing what it did.
    processing_log: List[str] = field(default_factory=list)


class PromptProcessor:
    """
    Pipeline-based prompt processor.

    Processing order (see ``process``):

    1. Move person tags (1girl, 2boys, ...) from the main prompt to the front
    2. Add pre-prompt
    3. Main prompt (remaining tags)
    4. Add post-prompt
    5. Inject quality tags (if enabled)
    6. Remove auto-hide tags
    7. Clean up formatting

    Every step mutates the given context in place and returns it.
    """

    def process(self, context: PromptContext) -> PromptContext:
        """
        Run the full processing pipeline on a prompt context.

        Args:
            context: Initial prompt context. It is modified in place.

        Returns:
            The same context object, after processing.
        """
        # Step 1: Build positive prompt (person tags, pre, main, post)
        context = self._build_positive_prompt(context)

        # Step 2: Inject quality tags
        if context.use_quality_tags:
            context = self._inject_quality_tags(context)

        # Step 3: Remove auto-hide tags. This runs after quality injection, so
        # auto-hide patterns can also strip injected quality tags.
        if context.auto_hide_tags:
            context = self._remove_auto_hide_tags(context)

        # Step 4: Clean up formatting
        context = self._cleanup_prompt(context)

        return context

    # Person tag sets for reordering (from NAIA2.0)
    PERSON_TAGS = {
        "boys": {"1boy", "2boys", "3boys", "4boys", "5boys", "6+boys"},
        "girls": {"1girl", "2girls", "3girls", "4girls", "5girls", "6+girls"},
        "others": {"1other", "2others", "3others", "4others", "5others", "6+others"}
    }
    ALL_PERSON_TAGS = PERSON_TAGS["boys"] | PERSON_TAGS["girls"] | PERSON_TAGS["others"]

    def _build_positive_prompt(self, context: PromptContext) -> PromptContext:
        """
        Combine pre-prompt, main prompt, and post-prompt.

        Person tags (1girl, 2boys, etc.) are extracted from the main prompt
        and moved to the front in order: boys -> girls -> others. Matching is
        case-insensitive and the original spelling of each tag is kept; the
        sort is stable, so tags of the same group keep their relative order.
        Person tags inside pre/post-prompt are left where they are.

        Final order: [person tags], [pre-prompt], [main prompt], [post-prompt]
        """
        # Build the lowercase lookup sets once per call so each per-tag check
        # below is a single O(1) membership test.
        all_person = {pt.lower() for pt in self.ALL_PERSON_TAGS}
        boys = {pt.lower() for pt in self.PERSON_TAGS["boys"]}
        girls = {pt.lower() for pt in self.PERSON_TAGS["girls"]}

        def person_rank(tag: str) -> int:
            """Sort key: 0 for boys, 1 for girls, 2 for everything else."""
            lowered = tag.lower()
            if lowered in boys:
                return 0
            if lowered in girls:
                return 1
            return 2

        # Parse main prompt into tags
        main_tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]

        # Extract person tags from main prompt
        person_tags_found = []
        other_main_tags = []
        for tag in main_tags:
            if tag.lower() in all_person:
                person_tags_found.append(tag)
            else:
                other_main_tags.append(tag)

        # Sort person tags: boys -> girls -> others
        sorted_person_tags = sorted(person_tags_found, key=person_rank)

        if sorted_person_tags:
            context.processing_log.append(f"Person tags moved to front: {', '.join(sorted_person_tags)}")

        # Build final prompt: [person tags], [pre-prompt], [main prompt], [post-prompt]
        parts = []

        # 1. Person tags (extracted from main prompt)
        if sorted_person_tags:
            parts.append(", ".join(sorted_person_tags))

        # 2. Pre-prompt
        if context.pre_prompt.strip():
            parts.append(context.pre_prompt.strip())
            context.processing_log.append("Added pre-prompt")

        # 3. Main prompt (without person tags)
        if other_main_tags:
            parts.append(", ".join(other_main_tags))

        # 4. Post-prompt
        if context.post_prompt.strip():
            parts.append(context.post_prompt.strip())
            context.processing_log.append("Added post-prompt")

        context.positive_prompt = ", ".join(parts)
        return context

    def _inject_quality_tags(self, context: PromptContext) -> PromptContext:
        """
        Inject quality tags if enabled.

        Positive quality tags are only appended to the END of the prompt
        if the user's post_prompt does NOT contain "quality".
        This allows users to customize quality tags via post_prompt.

        Negative quality tags are appended only if not already present.
        """
        # Check if post_prompt contains "quality" (case-insensitive)
        has_quality_in_post = "quality" in context.post_prompt.lower()

        # Append positive quality tags only if post_prompt doesn't have "quality"
        if not has_quality_in_post:
            if context.positive_prompt:
                context.positive_prompt = f"{context.positive_prompt}, {QUALITY_TAGS_POSITIVE}"
            else:
                context.positive_prompt = QUALITY_TAGS_POSITIVE
            context.processing_log.append("Appended positive quality tags (post_prompt has no 'quality')")
        else:
            context.processing_log.append("Skipped positive quality tags (post_prompt has 'quality')")

        # Append quality tags to negative prompt (only if not already present).
        # Check for signature pattern "lowres, {bad}" to detect existing quality tags.
        # NOTE(review): this is an exact substring check run before cleanup, so
        # it assumes QUALITY_TAGS_NEGATIVE contains "lowres, {bad}" verbatim and
        # the user typed it with that exact spacing — confirm in utils.constants.
        negative_lower = context.negative_prompt.lower() if context.negative_prompt else ""
        has_quality_tags = "lowres, {bad}" in negative_lower

        if has_quality_tags:
            context.processing_log.append("Skipped negative quality tags (already present)")
        elif context.negative_prompt:
            context.negative_prompt = f"{context.negative_prompt}, {QUALITY_TAGS_NEGATIVE}"
            context.processing_log.append("Injected negative quality tags")
        else:
            context.negative_prompt = QUALITY_TAGS_NEGATIVE
            context.processing_log.append("Injected negative quality tags")

        return context

    @staticmethod
    def _matches_hide_pattern(pattern: str, tag: str) -> bool:
        """
        Return True if ``tag`` matches a single auto-hide ``pattern``.

        Matching is case-insensitive. Branches are checked in this order:

        - ``__pattern__`` (length > 4): tag, with spaces removed, contains
          the pattern with all underscores removed
        - ``_pattern_`` (length > 2): tag contains the pattern; inner
          underscores stand for spaces
        - ``_pattern``: tag ends with the pattern (underscores -> spaces)
        - ``pattern_``: tag starts with the pattern (underscores -> spaces)
        - anything else: exact match
        """
        tag_lower = tag.lower()
        if pattern.startswith('__') and pattern.endswith('__') and len(pattern) > 4:
            search_term = pattern[2:-2].replace('_', '').lower()
            return search_term in tag_lower.replace(' ', '')
        if pattern.startswith('_') and pattern.endswith('_') and len(pattern) > 2:
            search_term = pattern[1:-1].replace('_', ' ').lower()
            return search_term in tag_lower
        if pattern.startswith('_') and not pattern.endswith('_'):
            return tag_lower.endswith(pattern[1:].replace('_', ' ').lower())
        if pattern.endswith('_') and not pattern.startswith('_'):
            return tag_lower.startswith(pattern[:-1].replace('_', ' ').lower())
        return tag_lower == pattern.lower()

    def _remove_auto_hide_tags(self, context: PromptContext) -> PromptContext:
        """
        Remove auto-hide tags from the prompt with pattern support.

        Pattern syntax (from NAIA2.0):

        - `tag`: Exact match removal
        - `__pattern__`: Remove tags containing 'pattern', ignoring
          underscores in the pattern and spaces in the tag
        - `_pattern_`: Remove tags containing 'pattern' (e.g., _hair_ → blonde hair)
        - `_pattern`: Remove tags ending with 'pattern'
        - `pattern_`: Remove tags starting with 'pattern'
        - `~keyword`: Protect tags containing keyword from removal
          (a bare "~" with no keyword is ignored)

        ``context.removed_tags`` is set to the removed tags in prompt order.
        """
        if not context.auto_hide_tags:
            return context

        # Parse tags from positive prompt
        tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]

        # Separate protected keywords (starting with ~) and removal patterns
        protected_keywords = set()
        auto_hide_patterns = []
        for item in context.auto_hide_tags:
            item = item.strip()
            if not item:
                continue
            if item.startswith('~'):
                keyword = item[1:].strip().lower()
                # A bare "~" has no keyword; adding "" would match, and thus
                # protect, every tag and silently disable auto-hide.
                if keyword:
                    protected_keywords.add(keyword)
            else:
                auto_hide_patterns.append(item)

        # Build removal set
        to_remove = {
            tag
            for pattern in auto_hide_patterns
            for tag in tags
            if self._matches_hide_pattern(pattern, tag)
        }

        # Remove protected tags from the removal set
        if protected_keywords:
            protected_to_keep = {
                tag for tag in to_remove
                if any(keyword in tag.lower() for keyword in protected_keywords)
            }
            to_remove -= protected_to_keep

            if protected_to_keep:
                context.processing_log.append(f"Protected tags: {', '.join(sorted(protected_to_keep))}")

        # Apply removal
        filtered = [t for t in tags if t not in to_remove]
        # Report removed tags in prompt order (deduplicated) so the result is
        # deterministic; iterating the set would depend on hash randomization.
        context.removed_tags = list(dict.fromkeys(t for t in tags if t in to_remove))

        context.positive_prompt = ", ".join(filtered)

        if to_remove:
            context.processing_log.append(f"Auto-hide removed {len(to_remove)} tags: {', '.join(sorted(to_remove))}")
        else:
            context.processing_log.append("Auto-hide: no tags matched")

        return context

    def _cleanup_prompt(self, context: PromptContext) -> PromptContext:
        """Normalize formatting of both the positive and negative prompts."""
        # Process positive prompt
        context.positive_prompt = self._clean_text(context.positive_prompt)

        # Process negative prompt
        context.negative_prompt = self._clean_text(context.negative_prompt)

        context.processing_log.append("Cleaned up formatting")
        return context

    def _clean_text(self, text: str) -> str:
        """
        Normalize a single comma-separated tag string.

        Collapses whitespace runs to single spaces, merges runs of commas
        (including ones separated only by whitespace, e.g. "a, , , b"),
        formats separators as ", " and strips leading/trailing commas and
        spaces. Returns "" for empty/None input.
        """
        if not text:
            return ""

        # Remove extra whitespace
        text = ' '.join(text.split())

        # Collapse runs of commas, even when whitespace separates them, so no
        # empty tags survive between separators.
        text = re.sub(r',(?:\s*,)+', ',', text)

        # Normalize spacing around commas
        text = re.sub(r'\s*,\s*', ', ', text)

        # Strip leading/trailing commas and whitespace
        text = text.strip(' ,')

        return text


def parse_tags_from_text(text: str) -> List[str]:
    """
    Split a comma-separated tag string into individual tags.

    Args:
        text: Comma-separated tag string (empty or None yields no tags).

    Returns:
        The tags in their original order, with surrounding whitespace removed
        and empty entries dropped.
    """
    # Guard clause: nothing to split.
    if not text:
        return []

    # Strip each piece once, then keep only the non-empty results.
    pieces = (piece.strip() for piece in text.split(","))
    return [tag for tag in pieces if tag]