# NAIA/data/partition_loader.py
# (Hugging Face upload header residue: baqu2213, "Upload 2 files", commit 807d53a verified)
"""
NAIA-WEB Partition Loader
TGP file loading for Quick Search functionality
Reference: NAIA2.0/ui/remote/quick_search_tab.py (145-255)
"""
import struct
import lzma
import pickle
from pathlib import Path
from typing import Dict, List, Set, Optional
from collections import Counter
try:
import numpy as np
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
np = None
# Data directory
DATA_DIR = Path(__file__).parent / "quick_search"
class SinglePartitionStore:
    """
    Single partition storage (inverted index based) - Quick Search lightweight version.

    Holds a CSR-style event -> tag mapping (``_event_tag_indices`` /
    ``_event_tag_indptr``) plus an inverted tag -> events index
    (``_tag_to_events``), loaded from a compressed TGP file.
    Reference: NAIA2.0/ui/remote/quick_search_tab.py SinglePartitionStore class
    """
    MAGIC = b'TGP1'
    VERSION = 1

    def __init__(self):
        # Number of events stored in this partition.
        self.num_events: int = 0
        # CSR layout: tags of event i are _event_tag_indices[indptr[i]:indptr[i+1]].
        self._event_tag_indices = None
        self._event_tag_indptr = None
        self._event_counts = None
        # Inverted index: tag id -> int32 array of event indices.
        self._tag_to_events: Dict[int, object] = {}
        self._loaded: bool = False

    @classmethod
    def load(cls, input_path: str) -> 'SinglePartitionStore':
        """Load a partition (.tgp) file.

        Args:
            input_path: Path to the .tgp file.

        Returns:
            A loaded SinglePartitionStore.

        Raises:
            RuntimeError: If NumPy is unavailable.
            ValueError: If the file magic does not match.
        """
        # `np is None` is equivalent to `not HAS_NUMPY` (see module import guard).
        if np is None:
            raise RuntimeError("NumPy is required")
        store = cls()
        with open(input_path, 'rb') as f:
            magic = f.read(4)
            if magic != cls.MAGIC:
                raise ValueError(f"Invalid format: {magic}")
            _ = struct.unpack('<H', f.read(2))[0]  # version field, currently ignored
            compressed_len = struct.unpack('<I', f.read(4))[0]
            compressed = f.read(compressed_len)
        serialized = lzma.decompress(compressed)
        # SECURITY NOTE: pickle.loads can execute arbitrary code -- only load
        # partition files from a trusted source.
        data = pickle.loads(serialized)
        store.num_events = data['num_events']
        # .copy() detaches the arrays from the pickled byte buffers.
        store._event_tag_indices = np.frombuffer(data['event_tag_indices'], dtype=np.uint16).copy()
        store._event_tag_indptr = np.frombuffer(data['event_tag_indptr'], dtype=np.int32).copy()
        store._event_counts = np.frombuffer(data['event_counts'], dtype=np.int32).copy()
        store._tag_to_events = {
            int(k): np.frombuffer(v, dtype=np.int32).copy()
            for k, v in data['tag_to_events'].items()
        }
        store._loaded = True
        return store

    def filter_events(
        self,
        required_tags: Optional[List[str]] = None,
        excluded_tags: Optional[List[str]] = None,
        tag_to_id: Optional[Dict[str, int]] = None
    ):
        """Return a sorted int32 array of event indices matching the tag filters.

        Args:
            required_tags: Tags every returned event must have (AND). Any tag
                unknown to `tag_to_id` or the index yields an empty result.
            excluded_tags: Tags no returned event may have.
            tag_to_id: Tag-name -> tag-id mapping; if None, no tag filtering
                is applied (all events pass).

        Returns:
            np.ndarray of int32 event indices (plain list only if NumPy is missing).
        """
        if not self._loaded or np is None:
            return np.array([], dtype=np.int32) if np is not None else []
        if required_tags and tag_to_id:
            # Collect all posting lists up front; an unknown tag means no matches.
            postings = []
            for tag in required_tags:
                tag_id = tag_to_id.get(tag)
                if tag_id is None or tag_id not in self._tag_to_events:
                    return np.array([], dtype=np.int32)
                postings.append(self._tag_to_events[tag_id])
            # Intersect starting from the smallest posting list instead of
            # materializing set(range(num_events)) -- identical result, less work.
            postings.sort(key=len)
            candidates = set(postings[0])
            for events in postings[1:]:
                candidates &= set(events)
                if not candidates:
                    break
        else:
            candidates = set(range(self.num_events))
        if excluded_tags and tag_to_id:
            for tag in excluded_tags:
                tag_id = tag_to_id.get(tag)
                if tag_id is not None and tag_id in self._tag_to_events:
                    candidates -= set(self._tag_to_events[tag_id])
        return np.array(sorted(candidates), dtype=np.int32)

    def get_tag_counts(
        self,
        event_indices=None,
        id_to_tag: Optional[Dict[int, str]] = None
    ) -> Counter:
        """Count events per tag.

        Args:
            event_indices: Iterable of event indices to restrict the count to;
                None counts over all events.
            id_to_tag: Tag-id -> tag-name mapping; required (else empty Counter).

        Returns:
            Counter mapping tag name -> number of matching events.
        """
        if np is None or id_to_tag is None:
            return Counter()
        if event_indices is None:
            # Total counts: just the posting-list lengths.
            return Counter({
                id_to_tag[tag_id]: len(events)
                for tag_id, events in self._tag_to_events.items()
                if tag_id in id_to_tag
            })
        event_set = set(event_indices)
        return Counter({
            id_to_tag[tag_id]: len(set(events) & event_set)
            for tag_id, events in self._tag_to_events.items()
            if tag_id in id_to_tag
        })

    def get_event_tags(
        self,
        event_idx: int,
        id_to_tag: Optional[Dict[int, str]] = None
    ) -> Set[str]:
        """Return the set of tag names attached to one event.

        Args:
            event_idx: Event index; out-of-range values yield an empty set.
            id_to_tag: Tag-id -> tag-name mapping; required (else empty set).
        """
        if not self._loaded or id_to_tag is None:
            return set()
        if event_idx < 0 or event_idx >= self.num_events:
            return set()
        # CSR slice: tag ids for this event.
        start = self._event_tag_indptr[event_idx]
        end = self._event_tag_indptr[event_idx + 1]
        tag_ids = self._event_tag_indices[start:end]
        return {id_to_tag[int(tid)] for tid in tag_ids if int(tid) in id_to_tag}
class PartitionMetadata:
    """
    Metadata for partition files.

    Carries the global tag vocabulary (name <-> id maps), tag frequencies,
    and the per-partition info table.
    Reference: NAIA2.0/ui/remote/quick_search_tab.py metadata loading
    """
    MAGIC = b'TGPS'

    def __init__(self):
        self.tag_to_id: Dict[str, int] = {}
        self.id_to_tag: Dict[int, str] = {}
        self.tag_freq: Dict[str, int] = {}
        self.partitions: Dict[str, Dict] = {}
        self._loaded: bool = False

    @classmethod
    def load(cls, input_path: str) -> 'PartitionMetadata':
        """Read and decode a metadata (.tgpm) file.

        Raises:
            ValueError: If the file magic does not match.
        """
        meta = cls()
        with open(input_path, 'rb') as fh:
            magic = fh.read(4)
            if magic != cls.MAGIC:
                raise ValueError(f"Invalid metadata format: {magic}")
            fh.read(2)  # 2-byte version field, currently unused
            (payload_len,) = struct.unpack('<I', fh.read(4))
            payload = fh.read(payload_len)
        data = pickle.loads(lzma.decompress(payload))
        # Missing keys fall back to empty dicts so partially written
        # metadata still loads.
        for attr in ('tag_to_id', 'id_to_tag', 'tag_freq', 'partitions'):
            setattr(meta, attr, data.get(attr, {}))
        meta._loaded = True
        return meta

    @property
    def is_loaded(self) -> bool:
        """True once load() has completed successfully."""
        return self._loaded

    def get_partition_names(self) -> List[str]:
        """Return list of available partition names."""
        return [name for name in self.partitions]
class PartitionManager:
    """
    Manages partition loading and caching.

    Lazily loads the metadata file and individual partition stores from
    DATA_DIR, keeping loaded partitions cached in memory until unloaded.
    """

    def __init__(self):
        self._metadata: Optional[PartitionMetadata] = None
        self._loaded_partitions: Dict[str, SinglePartitionStore] = {}
        self._data_dir = DATA_DIR

    def _metadata_path(self) -> Path:
        # Single place that knows the metadata filename.
        return self._data_dir / "metadata.tgpm"

    def is_data_available(self) -> bool:
        """Check if partition data files are available."""
        return self._metadata_path().exists()

    def load_metadata(self) -> Optional[PartitionMetadata]:
        """Load partition metadata (cached after the first success).

        Returns:
            The metadata, or None when the file is missing or unreadable.
        """
        if self._metadata is not None:
            return self._metadata
        path = self._metadata_path()
        if not path.exists():
            return None
        try:
            self._metadata = PartitionMetadata.load(str(path))
        except Exception as exc:
            # Best-effort: report and signal failure to the caller.
            print(f"Error loading metadata: {exc}")
            return None
        return self._metadata

    def get_metadata(self) -> Optional[PartitionMetadata]:
        """Get loaded metadata (load if needed)."""
        if self._metadata is None:
            self.load_metadata()
        return self._metadata

    def load_partition(self, partition_name: str) -> Optional[SinglePartitionStore]:
        """Load a specific partition, returning the cached instance when present.

        Returns:
            The partition store, or None when the file is missing or unreadable.
        """
        cached = self._loaded_partitions.get(partition_name)
        if cached is not None:
            return cached
        path = self._data_dir / f"{partition_name}.tgp"
        if not path.exists():
            print(f"Partition file not found: {path}")
            return None
        try:
            store = SinglePartitionStore.load(str(path))
        except Exception as exc:
            print(f"Error loading partition {partition_name}: {exc}")
            return None
        self._loaded_partitions[partition_name] = store
        return store

    def unload_partition(self, partition_name: str):
        """Unload a partition to free memory."""
        self._loaded_partitions.pop(partition_name, None)

    def unload_all(self):
        """Unload all partitions."""
        self._loaded_partitions.clear()

    def get_partition_filename(self, rating: str, person: str) -> str:
        """
        Get partition filename from rating and person category.

        Args:
            rating: 'g', 's', 'q', or 'e'
            person: person category like '1girl_solo'

        Returns:
            Partition name like 'g_1girl_solo'
        """
        return f"{rating}_{person}"