# NAIA/core/autocomplete_service.py
"""
NAIA-WEB Autocomplete Service
Tag autocomplete functionality for prompt input fields
Reference: NAIA2.0/core/autocomplete_manager.py, NAIA2.0/core/tag_data_manager.py
"""
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import time
@dataclass
class TagResult:
    """Single tag search result"""
    tag: str  # tag text exactly as stored in the source dictionaries
    count: int  # usage frequency; higher counts sort first in search results
    category: str = "general" # general, artist, character
class AutocompleteService:
"""
Autocomplete service for tag suggestions.
Provides fast tag search with:
- Prefix matching (highest priority)
- Contains matching
- Category-aware search (general, artist, character)
- Frequency-based sorting
Usage:
service = AutocompleteService()
results = service.search("blue") # Returns list of TagResult
"""
_instance: Optional['AutocompleteService'] = None
def __new__(cls):
"""Singleton pattern for shared data across requests"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._generals: Dict[str, int] = {}
self._artists: Dict[str, int] = {}
self._characters: Dict[str, int] = {}
self._combined: Dict[str, Tuple[int, str]] = {} # tag -> (count, category)
self._load_data()
self._initialized = True
def _load_data(self):
"""Load tag data from source files"""
print("AutocompleteService: Loading tag data...")
start_time = time.time()
# Load generals (general tags)
try:
from data.autocomplete.result_dupl import generals
self._generals = dict(generals)
print(f" - Loaded {len(self._generals):,} general tags")
except ImportError as e:
print(f" - Failed to load generals: {e}")
self._generals = {}
# Load artists
try:
from data.autocomplete.artist_dictionary import artist_dict
# artist_dict has artist names as keys with counts
self._artists = {}
for key, value in artist_dict.items():
if isinstance(value, int):
self._artists[key] = value
elif isinstance(value, (list, tuple)) and len(value) > 0:
# Some entries might be [count, ...] format
self._artists[key] = value[0] if isinstance(value[0], int) else 0
print(f" - Loaded {len(self._artists)} artists")
except ImportError as e:
print(f" - Failed to load artists: {e}")
self._artists = {}
# Load characters
try:
from data.autocomplete.danbooru_character import character_dict_count
self._characters = dict(character_dict_count)
print(f" - Loaded {len(self._characters):,} characters")
except ImportError as e:
print(f" - Failed to load characters: {e}")
self._characters = {}
# Build combined index
self._build_combined_index()
elapsed = time.time() - start_time
print(f"AutocompleteService: Loaded {len(self._combined):,} total tags in {elapsed:.2f}s")
def _build_combined_index(self):
"""Build combined index with category information"""
self._combined = {}
# Add generals
for tag, count in self._generals.items():
self._combined[tag] = (count, "general")
# Add artists (may override generals with same name)
for tag, count in self._artists.items():
self._combined[tag] = (count, "artist")
# Add characters
for tag, count in self._characters.items():
if tag not in self._combined or count > self._combined[tag][0]:
self._combined[tag] = (count, "character")
def search(
self,
query: str,
limit: int = 20,
category: Optional[str] = None
) -> List[TagResult]:
"""
Search for tags matching query.
Args:
query: Search query (minimum 1 character)
limit: Maximum results to return
category: Optional filter by category ('general', 'artist', 'character')
Returns:
List of TagResult sorted by relevance and frequency
"""
if not query or len(query) < 1:
return []
query_lower = query.lower().strip()
# Select data source based on category
if category == "artist":
source = {k: (v, "artist") for k, v in self._artists.items()}
elif category == "character":
source = {k: (v, "character") for k, v in self._characters.items()}
elif category == "general":
source = {k: (v, "general") for k, v in self._generals.items()}
else:
source = self._combined
# Separate matches by type
exact_matches = []
prefix_matches = []
contains_matches = []
for tag, (count, cat) in source.items():
tag_lower = tag.lower()
if tag_lower == query_lower:
exact_matches.append(TagResult(tag=tag, count=count, category=cat))
elif tag_lower.startswith(query_lower):
prefix_matches.append(TagResult(tag=tag, count=count, category=cat))
elif query_lower in tag_lower:
contains_matches.append(TagResult(tag=tag, count=count, category=cat))
# Sort each group by count (descending)
exact_matches.sort(key=lambda x: x.count, reverse=True)
prefix_matches.sort(key=lambda x: x.count, reverse=True)
contains_matches.sort(key=lambda x: x.count, reverse=True)
# Combine: exact > prefix > contains
results = exact_matches + prefix_matches + contains_matches
return results[:limit]
def search_artists(self, query: str, limit: int = 20) -> List[TagResult]:
"""Search only artists"""
return self.search(query, limit=limit, category="artist")
def search_characters(self, query: str, limit: int = 20) -> List[TagResult]:
"""Search only characters"""
return self.search(query, limit=limit, category="character")
def search_generals(self, query: str, limit: int = 20) -> List[TagResult]:
"""Search only general tags"""
return self.search(query, limit=limit, category="general")
def get_popular_tags(self, limit: int = 100, category: Optional[str] = None) -> List[TagResult]:
"""Get most popular tags"""
if category == "artist":
source = [(k, v, "artist") for k, v in self._artists.items()]
elif category == "character":
source = [(k, v, "character") for k, v in self._characters.items()]
elif category == "general":
source = [(k, v, "general") for k, v in self._generals.items()]
else:
source = [(k, v, c) for k, (v, c) in self._combined.items()]
# Sort by count
source.sort(key=lambda x: x[1], reverse=True)
return [TagResult(tag=t, count=c, category=cat) for t, c, cat in source[:limit]]
def get_stats(self) -> Dict[str, int]:
"""Get statistics about loaded data"""
return {
"generals": len(self._generals),
"artists": len(self._artists),
"characters": len(self._characters),
"total": len(self._combined)
}
# Convenience function for simple usage
def search_tags(query: str, limit: int = 20) -> List[Dict]:
    """
    Simple function to search tags.

    Returns list of dicts: [{"tag": str, "count": int, "category": str}, ...]
    """
    hits = AutocompleteService().search(query, limit=limit)
    return [{"tag": h.tag, "count": h.count, "category": h.category} for h in hits]
def get_autocomplete_service() -> AutocompleteService:
    """Return the process-wide singleton AutocompleteService instance."""
    service = AutocompleteService()
    return service
# =============================================================================
# Gradio API Functions
# =============================================================================
def gradio_search_tags(query: str, limit: int = 20) -> List[List]:
    """
    Search tags for Gradio Dataframe component.

    Args:
        query: Search query
        limit: Maximum results

    Returns:
        List of [tag, count, category] rows for Dataframe display.
    """
    term = query.strip() if query else ""
    if not term:
        return []
    hits = get_autocomplete_service().search(term, limit=limit)
    return [[h.tag, h.count, h.category] for h in hits]
def gradio_search_tags_json(query: str, limit: int = 20) -> List[Dict]:
    """
    Search tags and return as JSON-serializable list.
    For JavaScript consumption via Gradio's js parameter.

    Returns:
        [{"tag": str, "count": int, "category": str}, ...]
    """
    term = query.strip() if query else ""
    if not term:
        return []
    service = get_autocomplete_service()
    return [
        {"tag": hit.tag, "count": hit.count, "category": hit.category}
        for hit in service.search(term, limit=limit)
    ]
def gradio_get_completion(current_text: str, cursor_position: int, limit: int = 10) -> List[Dict]:
    """
    Get autocomplete suggestions based on current cursor position.
    Extracts the current token (word being typed) and returns suggestions.

    Args:
        current_text: Full text content
        cursor_position: Cursor position in text
        limit: Maximum suggestions

    Returns:
        List of suggestions with metadata (including the replacement span
        [token_start, token_end) in the original text).
    """
    if not current_text:
        return []
    # Tokens are comma-separated; the active token begins just after the last
    # comma before the cursor, or at the start of the text if there is none.
    before_cursor = current_text[:cursor_position]
    comma_idx = before_cursor.rfind(',')
    token_start = comma_idx + 1 if comma_idx >= 0 else 0
    token = before_cursor[token_start:].strip()
    if not token:
        return []
    # Search for matches on the cleaned token.
    suggestions = get_autocomplete_service().search(token, limit=limit)
    return [
        {
            "tag": s.tag,
            "count": s.count,
            "category": s.category,
            "token_start": token_start,
            "token_end": cursor_position,
        }
        for s in suggestions
    ]
def preload_autocomplete_data() -> Dict:
    """
    Preload autocomplete data and return statistics.
    Call this at app startup to warm up the cache.

    Returns:
        Statistics about loaded data
    """
    stats = get_autocomplete_service().get_stats()
    return {"status": "loaded", "stats": stats}