Spaces:
Running
Running
| """ | |
| NAIA-WEB Prompt Processor | |
| Pipeline-based prompt processing with hooks | |
| Reference: NAIA2.0/core/prompt_processor.py, NAIA2.0/modules/prompt_engineering_module.py | |
| """ | |
| import re | |
| from dataclasses import dataclass, field | |
| from typing import List, Set, Tuple | |
| from utils.constants import QUALITY_TAGS_POSITIVE, QUALITY_TAGS_NEGATIVE | |
@dataclass
class PromptContext:
    """
    Context passed through the prompt processing pipeline.

    Carries all prompt-related data and settings through each stage.
    Declared as a dataclass so callers can construct it with keyword or
    positional arguments and so the mutable fields below get a fresh
    container per instance (via ``field(default_factory=...)``).
    """
    # Main (positive) prompt text; rewritten in place by the pipeline
    positive_prompt: str
    # Negative prompt text; quality tags may be appended to it
    negative_prompt: str
    # Processing flags
    use_quality_tags: bool = True
    # Pre/Post prompt additions
    pre_prompt: str = ""
    post_prompt: str = ""
    # Auto hide tags (tags to remove) - supports patterns
    auto_hide_tags: Set[str] = field(default_factory=set)
    # Removed tags tracking
    removed_tags: List[str] = field(default_factory=list)
    # Processing log for debugging
    processing_log: List[str] = field(default_factory=list)
class PromptProcessor:
    """
    Pipeline-based prompt processor.

    Processing order:
    1. Build the positive prompt: person tags (extracted from the main
       prompt), pre-prompt, remaining main prompt, post-prompt
    2. Inject quality tags (if enabled)
    3. Remove auto-hide tags (if any are configured)
    4. Clean up formatting of both positive and negative prompts
    """

    # Person tag sets for reordering (from NAIA2.0)
    PERSON_TAGS = {
        "boys": {"1boy", "2boys", "3boys", "4boys", "5boys", "6+boys"},
        "girls": {"1girl", "2girls", "3girls", "4girls", "5girls", "6+girls"},
        "others": {"1other", "2others", "3others", "4others", "5others", "6+others"},
    }
    ALL_PERSON_TAGS = PERSON_TAGS["boys"] | PERSON_TAGS["girls"] | PERSON_TAGS["others"]

    def process(self, context: PromptContext) -> PromptContext:
        """
        Run the full processing pipeline on a prompt context.

        Args:
            context: Initial prompt context (mutated in place)

        Returns:
            The same context object, processed
        """
        # Step 1: Build positive prompt with person tags and pre/post
        context = self._build_positive_prompt(context)
        # Step 2: Inject quality tags
        if context.use_quality_tags:
            context = self._inject_quality_tags(context)
        # Step 3: Remove auto-hide tags
        if context.auto_hide_tags:
            context = self._remove_auto_hide_tags(context)
        # Step 4: Clean up formatting
        context = self._cleanup_prompt(context)
        return context

    def _build_positive_prompt(self, context: PromptContext) -> PromptContext:
        """
        Combine pre-prompt, main prompt, and post-prompt.

        Person tags (1girl, 2boys, etc.) are extracted from the main prompt
        (case-insensitively) and moved to the front in order:
        boys -> girls -> others.

        Final order: [person tags], [pre-prompt], [main prompt], [post-prompt]
        """
        # Build the case-insensitive lookups once per call rather than once per tag.
        person_lookup = {pt.lower() for pt in self.ALL_PERSON_TAGS}
        boys_lookup = {pt.lower() for pt in self.PERSON_TAGS["boys"]}
        girls_lookup = {pt.lower() for pt in self.PERSON_TAGS["girls"]}

        # Parse main prompt into tags
        main_tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]

        # Extract person tags from main prompt
        person_tags_found = []
        other_main_tags = []
        for tag in main_tags:
            if tag.lower() in person_lookup:
                person_tags_found.append(tag)
            else:
                other_main_tags.append(tag)

        def person_rank(tag: str) -> int:
            """Sort key: 0 for boys tags, 1 for girls tags, 2 for anything else."""
            lowered = tag.lower()
            if lowered in boys_lookup:
                return 0
            if lowered in girls_lookup:
                return 1
            return 2

        # sorted() is stable, so tags within the same group keep their input order.
        sorted_person_tags = sorted(person_tags_found, key=person_rank)

        if sorted_person_tags:
            context.processing_log.append(f"Person tags moved to front: {', '.join(sorted_person_tags)}")

        # Build final prompt: [person tags], [pre-prompt], [main prompt], [post-prompt]
        parts = []
        # 1. Person tags (extracted from main prompt)
        if sorted_person_tags:
            parts.append(", ".join(sorted_person_tags))
        # 2. Pre-prompt
        if context.pre_prompt.strip():
            parts.append(context.pre_prompt.strip())
            context.processing_log.append("Added pre-prompt")
        # 3. Main prompt (without person tags)
        if other_main_tags:
            parts.append(", ".join(other_main_tags))
        # 4. Post-prompt
        if context.post_prompt.strip():
            parts.append(context.post_prompt.strip())
            context.processing_log.append("Added post-prompt")

        context.positive_prompt = ", ".join(parts)
        return context

    def _inject_quality_tags(self, context: PromptContext) -> PromptContext:
        """
        Inject quality tags if enabled.

        Positive quality tags are only appended to the END of the prompt
        if the user's post_prompt does NOT contain "quality"
        (case-insensitive). This allows users to customize quality tags
        via post_prompt.

        Negative quality tags are appended unless the negative prompt
        already contains the signature substring "lowres, {bad}".
        """
        # Check if post_prompt contains "quality" (case-insensitive)
        has_quality_in_post = "quality" in context.post_prompt.lower()

        # Append positive quality tags only if post_prompt doesn't have "quality"
        if not has_quality_in_post:
            if context.positive_prompt:
                context.positive_prompt = f"{context.positive_prompt}, {QUALITY_TAGS_POSITIVE}"
            else:
                context.positive_prompt = QUALITY_TAGS_POSITIVE
            context.processing_log.append("Appended positive quality tags (post_prompt has no 'quality')")
        else:
            context.processing_log.append("Skipped positive quality tags (post_prompt has 'quality')")

        # Append quality tags to negative prompt (only if not already present).
        # The signature pattern "lowres, {bad}" is used to detect existing quality tags.
        negative_lower = context.negative_prompt.lower() if context.negative_prompt else ""
        has_quality_tags = "lowres, {bad}" in negative_lower

        if has_quality_tags:
            context.processing_log.append("Skipped negative quality tags (already present)")
        elif context.negative_prompt:
            context.negative_prompt = f"{context.negative_prompt}, {QUALITY_TAGS_NEGATIVE}"
            context.processing_log.append("Injected negative quality tags")
        else:
            context.negative_prompt = QUALITY_TAGS_NEGATIVE
            context.processing_log.append("Injected negative quality tags")

        return context

    @staticmethod
    def _matches_hide_pattern(pattern: str, tag: str) -> bool:
        """
        Return True if auto-hide ``pattern`` selects ``tag`` (case-insensitive).

        See ``_remove_auto_hide_tags`` for the pattern syntax.
        """
        tag_lower = tag.lower()
        if not pattern.replace('_', '').strip():
            # Pattern is only underscores/spaces: every wildcard form would reduce
            # to an empty or blank search term matching (nearly) every tag, so
            # treat it as a literal exact match instead.
            return tag_lower == pattern.lower()
        if pattern.startswith('__') and pattern.endswith('__') and len(pattern) > 4:
            # __pattern__: contains match, ignoring underscores in the pattern
            # and spaces in the tag
            search_term = pattern[2:-2].replace('_', '').lower()
            return search_term in tag_lower.replace(' ', '')
        if pattern.startswith('_') and pattern.endswith('_') and len(pattern) > 2:
            # _pattern_: contains match (inner underscores stand for spaces)
            search_term = pattern[1:-1].replace('_', ' ').lower()
            return search_term in tag_lower
        if pattern.startswith('_') and not pattern.endswith('_'):
            # _pattern: ends-with match
            return tag_lower.endswith(pattern[1:].replace('_', ' ').lower())
        if pattern.endswith('_') and not pattern.startswith('_'):
            # pattern_: starts-with match
            return tag_lower.startswith(pattern[:-1].replace('_', ' ').lower())
        # Exact match
        return tag_lower == pattern.lower()

    def _remove_auto_hide_tags(self, context: PromptContext) -> PromptContext:
        """
        Remove auto-hide tags from the positive prompt with pattern support.

        Pattern syntax (from NAIA2.0), all case-insensitive:
        - `tag`: Exact match removal
        - `__pattern__`: Remove tags containing 'pattern', ignoring underscores
          in the pattern and spaces in the tag
        - `_pattern_`: Remove tags containing 'pattern'; underscores in the
          pattern stand for spaces (e.g., _hair_ → blonde hair)
        - `_pattern`: Remove tags ending with 'pattern'
        - `pattern_`: Remove tags starting with 'pattern'
        - `~keyword`: Protect tags containing 'keyword' from removal

        A bare `~` (empty keyword) is ignored rather than protecting every
        tag. Patterns made only of underscores/spaces are matched literally
        rather than acting as match-everything wildcards.

        Sets ``context.removed_tags`` to the removed tags in prompt order.
        """
        if not context.auto_hide_tags:
            return context

        # Parse tags from positive prompt
        tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]

        # Separate protected keywords (starting with ~) and removal patterns
        protected_keywords = set()
        auto_hide_patterns = []
        for raw_item in context.auto_hide_tags:
            item = raw_item.strip()
            if not item:
                continue
            if item.startswith('~'):
                keyword = item[1:].strip().lower()
                # An empty keyword is a substring of every tag and would
                # silently protect everything, disabling auto-hide.
                if keyword:
                    protected_keywords.add(keyword)
            else:
                auto_hide_patterns.append(item)

        # Build removal set
        to_remove = set()
        for pattern in auto_hide_patterns:
            to_remove.update(tag for tag in tags if self._matches_hide_pattern(pattern, tag))

        # Remove protected tags from the removal set
        protected_to_keep = {
            tag for tag in to_remove
            if any(keyword in tag.lower() for keyword in protected_keywords)
        }
        if protected_to_keep:
            to_remove -= protected_to_keep
            context.processing_log.append(f"Protected tags: {', '.join(sorted(protected_to_keep))}")

        # Apply removal
        filtered = [t for t in tags if t not in to_remove]
        # Record removals in prompt order (deduplicated) so the result is
        # deterministic rather than dependent on set iteration order.
        context.removed_tags = list(dict.fromkeys(t for t in tags if t in to_remove))
        context.positive_prompt = ", ".join(filtered)

        if to_remove:
            context.processing_log.append(f"Auto-hide removed {len(to_remove)} tags: {', '.join(sorted(to_remove))}")
        else:
            context.processing_log.append("Auto-hide: no tags matched")

        return context

    def _cleanup_prompt(self, context: PromptContext) -> PromptContext:
        """Clean up formatting of both the positive and negative prompts."""
        context.positive_prompt = self._clean_text(context.positive_prompt)
        context.negative_prompt = self._clean_text(context.negative_prompt)
        context.processing_log.append("Cleaned up formatting")
        return context

    def _clean_text(self, text: str) -> str:
        """
        Normalize a comma-separated prompt string.

        Collapses whitespace runs to single spaces, drops empty entries (runs
        of commas, including ones separated by whitespace such as
        "a, , , b"), normalizes every separator to ", ", and strips leading
        and trailing commas/whitespace. Falsy input yields "".
        """
        if not text:
            return ""
        # Remove extra whitespace (also turns newlines/tabs into single spaces)
        text = ' '.join(text.split())
        # Collapse a comma followed by any number of (whitespace, comma) groups.
        # The former `,\s*,+` left an empty entry behind for "a, , , b".
        text = re.sub(r',(?:\s*,)+', ',', text)
        # Normalize spacing around commas
        text = re.sub(r'\s*,\s*', ', ', text)
        # Strip leading/trailing commas and whitespace
        return text.strip(' ,')
def parse_tags_from_text(text: str) -> List[str]:
    """
    Split a comma-separated tag string into individual tags.

    Args:
        text: Comma-separated tag string; empty or None yields no tags

    Returns:
        Tags with surrounding whitespace removed; empty entries are dropped
    """
    tags: List[str] = []
    if text:
        for piece in text.split(','):
            cleaned = piece.strip()
            if cleaned:
                tags.append(cleaned)
    return tags