"""
NAIA-WEB Prompt Processor
Pipeline-based prompt processing with hooks

Reference: NAIA2.0/core/prompt_processor.py, NAIA2.0/modules/prompt_engineering_module.py
"""

import re
from dataclasses import dataclass, field
from typing import List, Set, Tuple

from utils.constants import QUALITY_TAGS_POSITIVE, QUALITY_TAGS_NEGATIVE


@dataclass
class PromptContext:
    """
    Context passed through the prompt processing pipeline.

    Carries all prompt-related data and settings through each stage.
    The pipeline mutates this object in place.
    """
    # Comma-separated positive / negative prompt text.
    positive_prompt: str
    negative_prompt: str

    # Processing flags
    use_quality_tags: bool = True

    # Pre/Post prompt additions
    pre_prompt: str = ""
    post_prompt: str = ""

    # Auto hide tags (tags to remove) - supports patterns, see
    # PromptProcessor._remove_auto_hide_tags for the syntax.
    auto_hide_tags: Set[str] = field(default_factory=set)

    # Tags removed by the auto-hide step (filled in by the pipeline).
    removed_tags: List[str] = field(default_factory=list)

    # Processing log for debugging
    processing_log: List[str] = field(default_factory=list)


class PromptProcessor:
    """
    Pipeline-based prompt processor.

    Processing order:
    1. Move person tags (1girl, 2boys, ...) from the main prompt to the front
    2. Add pre-prompt, the remaining main prompt, then post-prompt
    3. Inject quality tags (if enabled)
    4. Remove auto-hide tags
    5. Clean up formatting
    """

    def process(self, context: PromptContext) -> PromptContext:
        """
        Run the full processing pipeline on a prompt context.

        Args:
            context: Initial prompt context (mutated in place)

        Returns:
            The same context object, processed
        """
        # Step 1: Build positive prompt with pre/post
        context = self._build_positive_prompt(context)

        # Step 2: Inject quality tags
        if context.use_quality_tags:
            context = self._inject_quality_tags(context)

        # Step 3: Remove auto-hide tags
        if context.auto_hide_tags:
            context = self._remove_auto_hide_tags(context)

        # Step 4: Clean up formatting
        context = self._cleanup_prompt(context)

        return context

    # Person tag sets for reordering (from NAIA2.0)
    PERSON_TAGS = {
        "boys": {"1boy", "2boys", "3boys", "4boys", "5boys", "6+boys"},
        "girls": {"1girl", "2girls", "3girls", "4girls", "5girls", "6+girls"},
        "others": {"1other", "2others", "3others", "4others", "5others", "6+others"}
    }
    ALL_PERSON_TAGS = PERSON_TAGS["boys"] | PERSON_TAGS["girls"] | PERSON_TAGS["others"]

    def _build_positive_prompt(self, context: PromptContext) -> PromptContext:
        """
        Combine pre-prompt, main prompt, and post-prompt.

        Person tags (1girl, 2boys, etc.) are extracted from the main prompt only
        (not from pre/post prompts) and moved to the front, ordered
        boys -> girls -> others. Matching is case-insensitive; the original
        spelling of each tag is kept.

        Final order: [person tags], [pre-prompt], [main prompt], [post-prompt]
        """
        # Parse main prompt into tags
        main_tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]

        # Lower-cased lookup sets, built once per call instead of once per tag.
        boys_lower = {pt.lower() for pt in self.PERSON_TAGS["boys"]}
        girls_lower = {pt.lower() for pt in self.PERSON_TAGS["girls"]}
        all_person_lower = {pt.lower() for pt in self.ALL_PERSON_TAGS}

        # Extract person tags from main prompt
        person_tags_found = []
        other_main_tags = []
        for tag in main_tags:
            if tag.lower() in all_person_lower:
                person_tags_found.append(tag)
            else:
                other_main_tags.append(tag)

        def _person_rank(tag: str) -> int:
            # boys -> girls -> others; sorted() is stable, so ties keep input order.
            lowered = tag.lower()
            if lowered in boys_lower:
                return 0
            if lowered in girls_lower:
                return 1
            return 2

        sorted_person_tags = sorted(person_tags_found, key=_person_rank)

        if sorted_person_tags:
            context.processing_log.append(f"Person tags moved to front: {', '.join(sorted_person_tags)}")

        # Build final prompt: [person tags], [pre-prompt], [main prompt], [post-prompt]
        parts = []

        # 1. Person tags (extracted from main prompt)
        if sorted_person_tags:
            parts.append(", ".join(sorted_person_tags))

        # 2. Pre-prompt
        if context.pre_prompt.strip():
            parts.append(context.pre_prompt.strip())
            context.processing_log.append("Added pre-prompt")

        # 3. Main prompt (without person tags)
        if other_main_tags:
            parts.append(", ".join(other_main_tags))

        # 4. Post-prompt
        if context.post_prompt.strip():
            parts.append(context.post_prompt.strip())
            context.processing_log.append("Added post-prompt")

        context.positive_prompt = ", ".join(parts)
        return context

    def _inject_quality_tags(self, context: PromptContext) -> PromptContext:
        """
        Inject quality tags if enabled.

        Positive quality tags are only appended to the END of the prompt if the
        user's post_prompt does NOT contain "quality" (case-insensitive). This
        allows users to customize quality tags via post_prompt.

        Negative quality tags are appended only if the negative prompt does not
        already contain the signature substring "lowres, {bad}".
        """
        # Check if post_prompt contains "quality" (case-insensitive)
        has_quality_in_post = "quality" in context.post_prompt.lower()

        # Append positive quality tags only if post_prompt doesn't have "quality"
        if not has_quality_in_post:
            if context.positive_prompt:
                context.positive_prompt = f"{context.positive_prompt}, {QUALITY_TAGS_POSITIVE}"
            else:
                context.positive_prompt = QUALITY_TAGS_POSITIVE
            context.processing_log.append("Appended positive quality tags (post_prompt has no 'quality')")
        else:
            context.processing_log.append("Skipped positive quality tags (post_prompt has 'quality')")

        # Append quality tags to negative prompt (only if not already present).
        # NOTE(review): "lowres, {bad}" is presumably the start of
        # QUALITY_TAGS_NEGATIVE — confirm against utils.constants.
        negative_lower = context.negative_prompt.lower() if context.negative_prompt else ""
        has_quality_tags = "lowres, {bad}" in negative_lower

        if has_quality_tags:
            context.processing_log.append("Skipped negative quality tags (already present)")
        elif context.negative_prompt:
            context.negative_prompt = f"{context.negative_prompt}, {QUALITY_TAGS_NEGATIVE}"
            context.processing_log.append("Injected negative quality tags")
        else:
            context.negative_prompt = QUALITY_TAGS_NEGATIVE
            context.processing_log.append("Injected negative quality tags")

        return context

    @staticmethod
    def _matches_auto_hide_pattern(pattern: str, tag: str) -> bool:
        """
        Return True if `tag` matches one auto-hide `pattern` (case-insensitive).

        Underscores inside the pattern stand for spaces in the tag, except in the
        `__pattern__` form, where they are dropped and spaces in the tag are
        ignored.
        """
        tag_lower = tag.lower()
        if pattern.startswith('__') and pattern.endswith('__') and len(pattern) > 4:
            # __pattern__: contains match, ignoring underscores/spaces
            search_term = pattern[2:-2].replace('_', '').lower()
            # An all-underscore pattern collapses to "", which would match every tag.
            return bool(search_term) and search_term in tag_lower.replace(' ', '')
        if pattern.startswith('_') and pattern.endswith('_') and len(pattern) > 2:
            # _pattern_: contains match
            return pattern[1:-1].replace('_', ' ').lower() in tag_lower
        if pattern.startswith('_') and not pattern.endswith('_'):
            # _pattern: ends-with match
            return tag_lower.endswith(pattern[1:].replace('_', ' ').lower())
        if pattern.endswith('_') and not pattern.startswith('_'):
            # pattern_: starts-with match
            return tag_lower.startswith(pattern[:-1].replace('_', ' ').lower())
        # Exact match
        return tag_lower == pattern.lower()

    def _remove_auto_hide_tags(self, context: PromptContext) -> PromptContext:
        """
        Remove auto-hide tags from the prompt with pattern support.

        Pattern syntax (from NAIA2.0), all case-insensitive:
        - `tag`: Exact match removal
        - `__pattern__`: Remove tags containing 'pattern', ignoring underscores
          in the pattern and spaces in the tag
        - `_pattern_`: Remove tags containing 'pattern' (e.g., _hair_ → blonde hair)
        - `_pattern`: Remove tags ending with 'pattern'
        - `pattern_`: Remove tags starting with 'pattern'
        - `~keyword`: Protect any tag containing 'keyword' from removal

        Sets `context.removed_tags` to the removed tags in prompt order.
        """
        if not context.auto_hide_tags:
            return context

        # Parse tags from positive prompt
        tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]

        # Separate protected keywords (starting with ~) and patterns
        protected_keywords = set()
        auto_hide_patterns = []

        for item in context.auto_hide_tags:
            item = item.strip()
            if not item:
                continue
            if item.startswith('~'):
                keyword = item[1:].strip().lower()
                # A bare "~" would yield "", which is a substring of every tag
                # and would silently protect everything.
                if keyword:
                    protected_keywords.add(keyword)
            else:
                auto_hide_patterns.append(item)

        # Build removal set
        to_remove = {
            tag
            for pattern in auto_hide_patterns
            for tag in tags
            if self._matches_auto_hide_pattern(pattern, tag)
        }

        # Remove protected keywords from removal set
        if protected_keywords:
            protected_to_keep = {
                tag for tag in to_remove
                if any(protected in tag.lower() for protected in protected_keywords)
            }
            to_remove -= protected_to_keep

            if protected_to_keep:
                context.processing_log.append(f"Protected tags: {', '.join(sorted(protected_to_keep))}")

        # Apply removal
        filtered = [t for t in tags if t not in to_remove]
        # Unique removed tags, in the order they appeared in the prompt (deterministic).
        context.removed_tags = list(dict.fromkeys(t for t in tags if t in to_remove))
        context.positive_prompt = ", ".join(filtered)

        if to_remove:
            context.processing_log.append(f"Auto-hide removed {len(to_remove)} tags: {', '.join(sorted(to_remove))}")
        else:
            context.processing_log.append("Auto-hide: no tags matched")

        return context

    def _cleanup_prompt(self, context: PromptContext) -> PromptContext:
        """Normalize the formatting of both the positive and negative prompt."""
        # Process positive prompt
        context.positive_prompt = self._clean_text(context.positive_prompt)

        # Process negative prompt
        context.negative_prompt = self._clean_text(context.negative_prompt)

        context.processing_log.append("Cleaned up formatting")
        return context

    def _clean_text(self, text: str) -> str:
        """
        Clean a single comma-separated prompt string.

        Collapses runs of whitespace, drops empty segments (any run of commas),
        normalizes separators to ", ", and strips leading/trailing commas and
        spaces. Returns "" for empty/None input.
        """
        if not text:
            return ""

        # Remove extra whitespace
        text = ' '.join(text.split())

        # Collapse any run of commas (with optional whitespace between them),
        # e.g. "a, , , b" -> "a, b". The previous r',\s*,+' only collapsed one
        # pair per match and left "a, , b".
        text = re.sub(r',(?:\s*,)+', ',', text)

        # Remove spaces around commas
        text = re.sub(r'\s*,\s*', ', ', text)

        # Strip leading/trailing commas and whitespace
        text = text.strip(' ,')

        return text


def parse_tags_from_text(text: str) -> List[str]:
    """
    Parse comma-separated tags from text.

    Args:
        text: Comma-separated tag string

    Returns:
        List of individual tags (stripped), empty segments dropped
    """
    if not text:
        return []
    return [t.strip() for t in text.split(',') if t.strip()]