Spaces:
Running
Running
| """ | |
| NAIA-WEB Autocomplete Service | |
| Tag autocomplete functionality for prompt input fields | |
| Reference: NAIA2.0/core/autocomplete_manager.py, NAIA2.0/core/tag_data_manager.py | |
| """ | |
| from typing import List, Dict, Tuple, Optional | |
| from dataclasses import dataclass | |
| import time | |
@dataclass
class TagResult:
    """Single tag search result.

    Fix: the @dataclass decorator was missing, so keyword construction
    (TagResult(tag=..., count=...)) at every call site raised TypeError.
    """
    tag: str                    # raw tag text as stored in the source data
    count: int                  # usage frequency (higher = more popular)
    category: str = "general"   # general, artist, character
class AutocompleteService:
    """
    Autocomplete service for tag suggestions.

    Provides fast tag search with:
    - Exact and prefix matching (highest priority)
    - Contains matching
    - Category-aware search (general, artist, character)
    - Frequency-based sorting within each relevance tier

    Usage:
        service = AutocompleteService()
        results = service.search("blue")  # Returns list of TagResult
    """

    # Process-wide singleton so the (large) tag dictionaries are loaded once.
    _instance: Optional['AutocompleteService'] = None

    def __new__(cls):
        """Singleton pattern for shared data across requests."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every AutocompleteService() call; only load once.
        if self._initialized:
            return
        self._generals: Dict[str, int] = {}
        self._artists: Dict[str, int] = {}
        self._characters: Dict[str, int] = {}
        self._combined: Dict[str, Tuple[int, str]] = {}  # tag -> (count, category)
        self._load_data()
        self._initialized = True

    def _load_data(self) -> None:
        """Load tag data from the bundled data modules.

        Each source is optional: a failed import is logged and that
        category is left empty, so the service degrades gracefully.
        """
        print("AutocompleteService: Loading tag data...")
        start_time = time.time()

        # Load generals (general tags)
        try:
            from data.autocomplete.result_dupl import generals
            self._generals = dict(generals)
            print(f" - Loaded {len(self._generals):,} general tags")
        except ImportError as e:
            print(f" - Failed to load generals: {e}")
            self._generals = {}

        # Load artists
        try:
            from data.autocomplete.artist_dictionary import artist_dict
            # artist_dict has artist names as keys with counts
            self._artists = {}
            for key, value in artist_dict.items():
                if isinstance(value, int):
                    self._artists[key] = value
                elif isinstance(value, (list, tuple)) and len(value) > 0:
                    # Some entries might be [count, ...] format
                    self._artists[key] = value[0] if isinstance(value[0], int) else 0
            print(f" - Loaded {len(self._artists)} artists")
        except ImportError as e:
            print(f" - Failed to load artists: {e}")
            self._artists = {}

        # Load characters
        try:
            from data.autocomplete.danbooru_character import character_dict_count
            self._characters = dict(character_dict_count)
            print(f" - Loaded {len(self._characters):,} characters")
        except ImportError as e:
            print(f" - Failed to load characters: {e}")
            self._characters = {}

        # Build combined index
        self._build_combined_index()
        elapsed = time.time() - start_time
        print(f"AutocompleteService: Loaded {len(self._combined):,} total tags in {elapsed:.2f}s")

    def _build_combined_index(self) -> None:
        """Build the combined tag -> (count, category) index.

        Precedence: artists always override generals with the same name;
        characters override only when their count is strictly higher.
        """
        self._combined = {}
        # Add generals
        for tag, count in self._generals.items():
            self._combined[tag] = (count, "general")
        # Add artists (may override generals with same name)
        for tag, count in self._artists.items():
            self._combined[tag] = (count, "artist")
        # Add characters (win only on a strictly higher count)
        for tag, count in self._characters.items():
            if tag not in self._combined or count > self._combined[tag][0]:
                self._combined[tag] = (count, "character")

    def _iter_tags(self, category: Optional[str]):
        """Yield (tag, count, category) triples for the requested category.

        Iterates the underlying dicts directly instead of copying each
        one into a fresh labeled dict per call (the previous behavior),
        keeping per-search memory overhead constant.
        """
        if category == "artist":
            for tag, count in self._artists.items():
                yield tag, count, "artist"
        elif category == "character":
            for tag, count in self._characters.items():
                yield tag, count, "character"
        elif category == "general":
            for tag, count in self._generals.items():
                yield tag, count, "general"
        else:
            for tag, (count, cat) in self._combined.items():
                yield tag, count, cat

    def search(
        self,
        query: str,
        limit: int = 20,
        category: Optional[str] = None
    ) -> List['TagResult']:
        """
        Search for tags matching query.

        Args:
            query: Search query (minimum 1 non-whitespace character)
            limit: Maximum results to return
            category: Optional filter by category ('general', 'artist', 'character')

        Returns:
            List of TagResult sorted by relevance (exact > prefix > contains)
            and by frequency within each relevance tier.
        """
        # Strip BEFORE the emptiness check so a whitespace-only query
        # returns nothing instead of prefix-matching every tag.
        query_lower = query.strip().lower() if query else ""
        if not query_lower:
            return []

        # Bucket matches by relevance tier
        exact_matches: List['TagResult'] = []
        prefix_matches: List['TagResult'] = []
        contains_matches: List['TagResult'] = []
        for tag, count, cat in self._iter_tags(category):
            tag_lower = tag.lower()
            if tag_lower == query_lower:
                exact_matches.append(TagResult(tag=tag, count=count, category=cat))
            elif tag_lower.startswith(query_lower):
                prefix_matches.append(TagResult(tag=tag, count=count, category=cat))
            elif query_lower in tag_lower:
                contains_matches.append(TagResult(tag=tag, count=count, category=cat))

        # Sort each tier by count (descending), then combine: exact > prefix > contains
        for bucket in (exact_matches, prefix_matches, contains_matches):
            bucket.sort(key=lambda x: x.count, reverse=True)
        results = exact_matches + prefix_matches + contains_matches
        return results[:limit]

    def search_artists(self, query: str, limit: int = 20) -> List['TagResult']:
        """Search only artists."""
        return self.search(query, limit=limit, category="artist")

    def search_characters(self, query: str, limit: int = 20) -> List['TagResult']:
        """Search only characters."""
        return self.search(query, limit=limit, category="character")

    def search_generals(self, query: str, limit: int = 20) -> List['TagResult']:
        """Search only general tags."""
        return self.search(query, limit=limit, category="general")

    def get_popular_tags(self, limit: int = 100, category: Optional[str] = None) -> List['TagResult']:
        """Get most popular tags, optionally filtered by category."""
        # Sort by count (descending) and keep the top `limit`
        ranked = sorted(self._iter_tags(category), key=lambda x: x[1], reverse=True)
        return [TagResult(tag=t, count=c, category=cat) for t, c, cat in ranked[:limit]]

    def get_stats(self) -> Dict[str, int]:
        """Get statistics about loaded data."""
        return {
            "generals": len(self._generals),
            "artists": len(self._artists),
            "characters": len(self._characters),
            "total": len(self._combined)
        }
| # Convenience function for simple usage | |
def search_tags(query: str, limit: int = 20) -> List[Dict]:
    """
    Simple function to search tags.

    Returns list of dicts: [{"tag": str, "count": int, "category": str}, ...]
    """
    matches = AutocompleteService().search(query, limit=limit)
    payload = []
    for m in matches:
        payload.append({"tag": m.tag, "count": m.count, "category": m.category})
    return payload
def get_autocomplete_service() -> AutocompleteService:
    """Get the singleton AutocompleteService instance."""
    service = AutocompleteService()
    return service
| # ============================================================================= | |
| # Gradio API Functions | |
| # ============================================================================= | |
def gradio_search_tags(query: str, limit: int = 20) -> List[List]:
    """
    Search tags for Gradio Dataframe component.

    Args:
        query: Search query
        limit: Maximum results

    Returns:
        List of [tag, count, category] rows for Dataframe display
    """
    cleaned = (query or "").strip()
    if not cleaned:
        return []
    matches = get_autocomplete_service().search(cleaned, limit=limit)
    rows = []
    for m in matches:
        rows.append([m.tag, m.count, m.category])
    return rows
def gradio_search_tags_json(query: str, limit: int = 20) -> List[Dict]:
    """
    Search tags and return as a JSON-serializable list.

    For JavaScript consumption via Gradio's js parameter.

    Returns:
        [{"tag": str, "count": int, "category": str}, ...]
    """
    cleaned = (query or "").strip()
    if not cleaned:
        return []
    matches = get_autocomplete_service().search(cleaned, limit=limit)
    out = []
    for m in matches:
        out.append({"tag": m.tag, "count": m.count, "category": m.category})
    return out
def gradio_get_completion(current_text: str, cursor_position: int, limit: int = 10) -> List[Dict]:
    """
    Get autocomplete suggestions based on current cursor position.

    Extracts the current token (word being typed) and returns suggestions.
    Tokens are comma-separated.

    Args:
        current_text: Full text content
        cursor_position: Cursor position in text
        limit: Maximum suggestions

    Returns:
        List of suggestion dicts with token-span metadata so the caller
        can replace the in-progress token.
    """
    if not current_text:
        return []

    # Only text left of the cursor matters for the in-progress token
    head = current_text[:cursor_position]

    # The current token begins right after the last comma (or at index 0)
    comma_idx = head.rfind(',')
    token_start = 0 if comma_idx < 0 else comma_idx + 1

    token = head[token_start:].strip()
    if not token:
        return []

    # Look up matches for the partial token
    suggestions = get_autocomplete_service().search(token, limit=limit)
    out = []
    for s in suggestions:
        out.append({
            "tag": s.tag,
            "count": s.count,
            "category": s.category,
            "token_start": token_start,
            "token_end": cursor_position
        })
    return out
def preload_autocomplete_data() -> Dict:
    """
    Preload autocomplete data and return statistics.

    Call this at app startup to warm up the cache.

    Returns:
        Dict with a status flag and statistics about the loaded data.
    """
    stats = get_autocomplete_service().get_stats()
    return {"status": "loaded", "stats": stats}