# NAIA / core/prompt_processor.py
# Uploaded by baqu2213 ("Upload 3 files"), commit 244c0fa (verified)
"""
NAIA-WEB Prompt Processor
Pipeline-based prompt processing with hooks
Reference: NAIA2.0/core/prompt_processor.py, NAIA2.0/modules/prompt_engineering_module.py
"""
import re
from dataclasses import dataclass, field
from typing import List, Set, Tuple
from utils.constants import QUALITY_TAGS_POSITIVE, QUALITY_TAGS_NEGATIVE
@dataclass
class PromptContext:
    """
    Mutable state handed from stage to stage of the prompt pipeline.

    Attributes:
        positive_prompt: Positive prompt text; pipeline stages overwrite it.
        negative_prompt: Negative prompt text.
        use_quality_tags: Whether the quality-tag injection stage should run.
        pre_prompt: Extra text placed ahead of the main-prompt tags.
        post_prompt: Extra text placed after the main-prompt tags.
        auto_hide_tags: Tags or patterns (e.g. ``_hair_``, ``~keyword``) to
            strip from the positive prompt.
        removed_tags: Tags that the auto-hide stage stripped out.
        processing_log: Human-readable notes recorded by each stage.
    """

    # Required prompt text
    positive_prompt: str
    negative_prompt: str

    # Feature toggle
    use_quality_tags: bool = field(default=True)

    # Text wrapped around the main prompt
    pre_prompt: str = field(default="")
    post_prompt: str = field(default="")

    # Auto-hide configuration and its outcome (fresh containers per instance)
    auto_hide_tags: Set[str] = field(default_factory=set)
    removed_tags: List[str] = field(default_factory=list)

    # Diagnostics
    processing_log: List[str] = field(default_factory=list)
class PromptProcessor:
    """
    Pipeline-based prompt processor.

    Processing order (see ``process``):
    1. Build the positive prompt: person tags (1girl, 2boys, ...) found in the
       main prompt are moved to the front (boys -> girls -> others), followed
       by the pre-prompt, the remaining main-prompt tags, and the post-prompt
    2. Inject quality tags (if enabled)
    3. Remove auto-hide tags from the positive prompt (this runs after step 2,
       so injected quality tags are subject to auto-hide as well)
    4. Clean up formatting of both positive and negative prompts
    """

    # Person tag sets for reordering (from NAIA2.0)
    PERSON_TAGS = {
        "boys": {"1boy", "2boys", "3boys", "4boys", "5boys", "6+boys"},
        "girls": {"1girl", "2girls", "3girls", "4girls", "5girls", "6+girls"},
        "others": {"1other", "2others", "3others", "4others", "5others", "6+others"}
    }
    ALL_PERSON_TAGS = PERSON_TAGS["boys"] | PERSON_TAGS["girls"] | PERSON_TAGS["others"]

    # Used by _clean_text: a comma followed by one or more further commas,
    # each optionally preceded by whitespace (", , ," -> ",").
    _DUPLICATE_COMMAS_RE = re.compile(r',(?:\s*,)+')
    # Any whitespace around a comma, normalized to ", ".
    _COMMA_SPACING_RE = re.compile(r'\s*,\s*')

    def process(self, context: PromptContext) -> PromptContext:
        """
        Run the full processing pipeline on a prompt context.

        Args:
            context: Initial prompt context. Each stage mutates it in place.

        Returns:
            The same, now processed, prompt context.
        """
        # Step 1: Build positive prompt with person tags + pre/post
        context = self._build_positive_prompt(context)
        # Step 2: Inject quality tags
        if context.use_quality_tags:
            context = self._inject_quality_tags(context)
        # Step 3: Remove auto-hide tags
        if context.auto_hide_tags:
            context = self._remove_auto_hide_tags(context)
        # Step 4: Clean up formatting
        context = self._cleanup_prompt(context)
        return context

    def _build_positive_prompt(self, context: PromptContext) -> PromptContext:
        """
        Combine pre-prompt, main prompt, and post-prompt.

        Person tags (1girl, 2boys, etc.) are extracted from the main prompt
        (case-insensitively) and moved to the front in order: boys -> girls ->
        others. Relative order inside each group is preserved. Person tags
        that appear in the pre/post-prompt are left where they are.

        Final order: [person tags], [pre-prompt], [main prompt], [post-prompt]
        """
        # Lower-cased lookup sets, built once per call rather than once per tag.
        person_lookup = {pt.lower() for pt in self.ALL_PERSON_TAGS}
        boys_lookup = {pt.lower() for pt in self.PERSON_TAGS["boys"]}
        girls_lookup = {pt.lower() for pt in self.PERSON_TAGS["girls"]}

        def group_rank(tag: str) -> int:
            lowered = tag.lower()
            if lowered in boys_lookup:
                return 0
            if lowered in girls_lookup:
                return 1
            return 2

        # Parse main prompt into non-empty, stripped tags
        main_tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]
        person_tags_found = [t for t in main_tags if t.lower() in person_lookup]
        other_main_tags = [t for t in main_tags if t.lower() not in person_lookup]

        # sorted() is stable, so tags keep their order within each group
        sorted_person_tags = sorted(person_tags_found, key=group_rank)
        if sorted_person_tags:
            context.processing_log.append(f"Person tags moved to front: {', '.join(sorted_person_tags)}")

        parts = []
        # 1. Person tags (extracted from main prompt)
        if sorted_person_tags:
            parts.append(", ".join(sorted_person_tags))
        # 2. Pre-prompt
        if context.pre_prompt.strip():
            parts.append(context.pre_prompt.strip())
            context.processing_log.append("Added pre-prompt")
        # 3. Main prompt (without person tags)
        if other_main_tags:
            parts.append(", ".join(other_main_tags))
        # 4. Post-prompt
        if context.post_prompt.strip():
            parts.append(context.post_prompt.strip())
            context.processing_log.append("Added post-prompt")

        context.positive_prompt = ", ".join(parts)
        return context

    def _inject_quality_tags(self, context: PromptContext) -> PromptContext:
        """
        Inject quality tags.

        Positive quality tags are appended to the END of the prompt only if
        the user's post_prompt does NOT contain "quality" (case-insensitive),
        so users can supply their own quality tags via post_prompt.

        Negative quality tags are appended unless the negative prompt already
        contains the signature substring "lowres, {bad}".
        """
        has_quality_in_post = "quality" in context.post_prompt.lower()
        if not has_quality_in_post:
            if context.positive_prompt:
                context.positive_prompt = f"{context.positive_prompt}, {QUALITY_TAGS_POSITIVE}"
            else:
                context.positive_prompt = QUALITY_TAGS_POSITIVE
            context.processing_log.append("Appended positive quality tags (post_prompt has no 'quality')")
        else:
            context.processing_log.append("Skipped positive quality tags (post_prompt has 'quality')")

        # "lowres, {bad}" is treated as evidence that the negative quality
        # tags are already present.
        negative_lower = context.negative_prompt.lower() if context.negative_prompt else ""
        has_quality_tags = "lowres, {bad}" in negative_lower
        if has_quality_tags:
            context.processing_log.append("Skipped negative quality tags (already present)")
        elif context.negative_prompt:
            context.negative_prompt = f"{context.negative_prompt}, {QUALITY_TAGS_NEGATIVE}"
            context.processing_log.append("Injected negative quality tags")
        else:
            context.negative_prompt = QUALITY_TAGS_NEGATIVE
            context.processing_log.append("Injected negative quality tags")
        return context

    @staticmethod
    def _matches_auto_hide_pattern(pattern: str, tag: str) -> bool:
        """
        Return True if ``tag`` matches a single auto-hide ``pattern``.

        Matching is case-insensitive. A pattern whose search term is empty or
        whitespace-only (e.g. ``___``) matches nothing, instead of matching
        every tag.
        """
        tag_lower = tag.lower()
        if pattern.startswith('__') and pattern.endswith('__') and len(pattern) > 4:
            # __pattern__: contains match, ignoring pattern underscores and tag spaces
            term = pattern[2:-2].replace('_', '').lower()
            return bool(term.strip()) and term in tag_lower.replace(' ', '')
        if pattern.startswith('_') and pattern.endswith('_') and len(pattern) > 2:
            # _pattern_: contains match, inner underscores act as spaces
            term = pattern[1:-1].replace('_', ' ').lower()
            return bool(term.strip()) and term in tag_lower
        if pattern.startswith('_') and not pattern.endswith('_'):
            # _pattern: ends-with match
            term = pattern[1:].replace('_', ' ').lower()
            return bool(term.strip()) and tag_lower.endswith(term)
        if pattern.endswith('_') and not pattern.startswith('_'):
            # pattern_: starts-with match
            term = pattern[:-1].replace('_', ' ').lower()
            return bool(term.strip()) and tag_lower.startswith(term)
        # Exact match
        return tag_lower == pattern.lower()

    def _remove_auto_hide_tags(self, context: PromptContext) -> PromptContext:
        """
        Remove auto-hide tags from the positive prompt with pattern support.

        Pattern syntax (from NAIA2.0), all case-insensitive:
        - `tag`: Exact match removal
        - `__pattern__`: Remove tags containing 'pattern', ignoring underscores
          in the pattern and spaces in the tag (e.g., __blondehair__ -> blonde hair)
        - `_pattern_`: Remove tags containing 'pattern', where underscores act as
          spaces (e.g., _hair_ -> blonde hair)
        - `_pattern`: Remove tags ending with 'pattern'
        - `pattern_`: Remove tags starting with 'pattern'
        - `~keyword`: Protect any tag containing 'keyword' from removal

        Sets ``context.removed_tags`` to the removed tags (unique, in the order
        they appear in the prompt).
        """
        if not context.auto_hide_tags:
            return context

        tags = [t.strip() for t in context.positive_prompt.split(',') if t.strip()]

        # Separate protected keywords (starting with ~) from removal patterns
        protected_keywords = set()
        auto_hide_patterns = []
        for item in context.auto_hide_tags:
            item = item.strip()
            if not item:
                continue
            if item.startswith('~'):
                keyword = item[1:].strip().lower()
                # A bare "~" gives "", which is a substring of every tag and
                # would silently protect everything, so ignore it.
                if keyword:
                    protected_keywords.add(keyword)
            else:
                auto_hide_patterns.append(item)

        to_remove = {
            tag for tag in tags
            if any(self._matches_auto_hide_pattern(p, tag) for p in auto_hide_patterns)
        }

        if protected_keywords:
            protected_to_keep = {
                tag for tag in to_remove
                if any(keyword in tag.lower() for keyword in protected_keywords)
            }
            to_remove -= protected_to_keep
            if protected_to_keep:
                context.processing_log.append(f"Protected tags: {', '.join(sorted(protected_to_keep))}")

        filtered = [t for t in tags if t not in to_remove]
        # dict.fromkeys de-duplicates while keeping first-seen prompt order
        context.removed_tags = [t for t in dict.fromkeys(tags) if t in to_remove]
        context.positive_prompt = ", ".join(filtered)
        if to_remove:
            context.processing_log.append(f"Auto-hide removed {len(to_remove)} tags: {', '.join(sorted(to_remove))}")
        else:
            context.processing_log.append("Auto-hide: no tags matched")
        return context

    def _cleanup_prompt(self, context: PromptContext) -> PromptContext:
        """Normalize formatting of both the positive and negative prompts."""
        context.positive_prompt = self._clean_text(context.positive_prompt)
        context.negative_prompt = self._clean_text(context.negative_prompt)
        context.processing_log.append("Cleaned up formatting")
        return context

    def _clean_text(self, text: str) -> str:
        """
        Normalize a comma-separated prompt string.

        Collapses whitespace runs (including newlines) into single spaces,
        merges runs of empty entries (", , ,") into one comma, normalizes
        comma spacing to ", ", and strips leading/trailing commas and spaces.
        Falsy input returns "".
        """
        if not text:
            return ""
        text = ' '.join(text.split())
        text = self._DUPLICATE_COMMAS_RE.sub(',', text)
        text = self._COMMA_SPACING_RE.sub(', ', text)
        return text.strip(' ,')
def parse_tags_from_text(text: str) -> List[str]:
    """
    Split a comma-separated tag string into a list of tags.

    Args:
        text: Comma-separated tag string; empty/None yields an empty list.

    Returns:
        The whitespace-stripped tags, with empty entries dropped.
    """
    collected: List[str] = []
    for chunk in (text or "").split(','):
        cleaned = chunk.strip()
        if cleaned:
            collected.append(cleaned)
    return collected