baqu2213 committed on
Commit
807d53a
·
verified ·
1 Parent(s): b998b38

Upload 2 files

Browse files
Files changed (2) hide show
  1. data/partition_loader.py +1 -1
  2. data/tag_store.py +50 -85
data/partition_loader.py CHANGED
@@ -119,7 +119,7 @@ class SinglePartitionStore:
119
  if not HAS_NUMPY or id_to_tag is None:
120
  return Counter()
121
 
122
- if event_indices is None or len(event_indices) == 0:
123
  # Total tag counts
124
  return Counter({
125
  id_to_tag[tag_id]: len(events)
 
119
  if not HAS_NUMPY or id_to_tag is None:
120
  return Counter()
121
 
122
+ if event_indices is None:
123
  # Total tag counts
124
  return Counter({
125
  id_to_tag[tag_id]: len(events)
data/tag_store.py CHANGED
@@ -35,12 +35,13 @@ class QuickSearchResult:
35
  self.tags = []
36
 
37
 
 
38
  class TagStore:
39
  """
40
  High-level interface for Quick Search functionality.
41
 
42
  Provides:
43
- - Partition loading and management
44
  - Tag filtering by rating and person category
45
  - Include/Exclude tag filtering
46
  - Random event sampling for prompt generation
@@ -66,38 +67,38 @@ class TagStore:
66
  def load_partition(self, rating: str, person: str) -> bool:
67
  """
68
  Load partition for given rating and person category.
69
-
70
  Args:
71
  rating: 'g', 's', 'q', or 'e'
72
  person: Person category like '1girl_solo'
73
-
74
  Returns:
75
  True if loaded successfully
76
  """
77
  partition_name = self._manager.get_partition_filename(rating, person)
 
 
 
 
78
 
79
- if partition_name == self._current_partition_name and self._current_partition is not None:
80
- return True # Already loaded
81
-
82
- # Unload previous partition to save memory
83
- if self._current_partition_name:
84
- self._manager.unload_partition(self._current_partition_name)
85
-
86
  partition = self._manager.load_partition(partition_name)
87
- if partition is None:
88
- self._current_partition = None
89
- self._current_partition_name = ""
90
- return False
91
-
92
- self._current_partition = partition
93
- self._current_partition_name = partition_name
94
- return True
 
 
95
 
96
  def get_event_count(self) -> int:
97
  """Get number of events in current partition"""
98
- if self._current_partition is None:
99
- return 0
100
- return self._current_partition.num_events
101
 
102
  def get_filtered_event_count(
103
  self,
@@ -105,20 +106,19 @@ class TagStore:
105
  exclude_tags: Optional[List[str]] = None
106
  ) -> int:
107
  """Get number of events matching filter criteria"""
108
- if self._current_partition is None:
109
  return 0
110
 
111
  metadata = self._manager.get_metadata()
112
  if metadata is None:
113
  return 0
114
 
115
- filtered = self._current_partition.filter_events(
116
  required_tags=include_tags,
117
  excluded_tags=exclude_tags,
118
  tag_to_id=metadata.tag_to_id
119
  )
120
-
121
- return len(filtered)
122
 
123
  def get_top_tags(
124
  self,
@@ -129,17 +129,14 @@ class TagStore:
129
  ) -> List[TagInfo]:
130
  """
131
  Get most frequent tags in current partition with filters applied.
132
-
133
- Returns tags sorted by frequency (descending).
134
  """
135
- if self._current_partition is None:
136
  return []
137
 
138
  metadata = self._manager.get_metadata()
139
  if metadata is None:
140
  return []
141
 
142
- # Filter events
143
  filtered_indices = self._current_partition.filter_events(
144
  required_tags=include_tags,
145
  excluded_tags=exclude_tags,
@@ -149,23 +146,25 @@ class TagStore:
149
  if len(filtered_indices) == 0:
150
  return []
151
 
152
- # Count tags in filtered events
153
  tag_counts = self._current_partition.get_tag_counts(
154
  event_indices=filtered_indices,
155
  id_to_tag=metadata.id_to_tag
156
  )
157
 
158
- # Exclude already included/excluded tags from results
159
  excluded_set = set(include_tags or []) | set(exclude_tags or [])
160
-
161
- # Sort by count and return with pagination
162
  sorted_tags = sorted(
163
  [(tag, count) for tag, count in tag_counts.items() if tag not in excluded_set],
164
  key=lambda x: x[1],
165
  reverse=True
166
- )[offset:offset + limit]
167
 
168
- return [TagInfo(tag=tag, count=count) for tag, count in sorted_tags]
 
 
 
169
 
170
  def get_total_tag_count(
171
  self,
@@ -173,14 +172,13 @@ class TagStore:
173
  exclude_tags: Optional[List[str]] = None
174
  ) -> int:
175
  """Get total number of unique tags matching filter criteria"""
176
- if self._current_partition is None:
177
  return 0
178
 
179
  metadata = self._manager.get_metadata()
180
  if metadata is None:
181
  return 0
182
 
183
- # Filter events
184
  filtered_indices = self._current_partition.filter_events(
185
  required_tags=include_tags,
186
  excluded_tags=exclude_tags,
@@ -190,15 +188,13 @@ class TagStore:
190
  if len(filtered_indices) == 0:
191
  return 0
192
 
193
- # Count tags in filtered events
194
  tag_counts = self._current_partition.get_tag_counts(
195
  event_indices=filtered_indices,
196
  id_to_tag=metadata.id_to_tag
197
  )
198
-
199
- # Exclude already included/excluded tags
200
  excluded_set = set(include_tags or []) | set(exclude_tags or [])
201
- return len([tag for tag in tag_counts.keys() if tag not in excluded_set])
202
 
203
  def generate_random_prompt(
204
  self,
@@ -208,67 +204,43 @@ class TagStore:
208
  exclude_tags: Optional[List[str]] = None
209
  ) -> QuickSearchResult:
210
  """
211
- Generate a random prompt from the partition.
212
-
213
- Args:
214
- rating: Rating code ('g', 's', 'q', 'e')
215
- person: Person category
216
- include_tags: Tags that must be present
217
- exclude_tags: Tags that must not be present
218
-
219
- Returns:
220
- QuickSearchResult with generated prompt
221
  """
222
- # Load partition
223
  if not self.load_partition(rating, person):
224
- return QuickSearchResult(
225
- success=False,
226
- error_message=f"Failed to load partition: {rating}_{person}"
227
- )
228
 
229
  metadata = self._manager.get_metadata()
230
  if metadata is None:
231
- return QuickSearchResult(
232
- success=False,
233
- error_message="Metadata not available"
234
- )
235
-
236
- # Get auto-tags for person category
237
- auto_tags = PERSON_AUTO_TAGS.get(person, [])
238
-
239
- # Combine include tags with auto tags
240
- all_include = list(set((include_tags or []) + auto_tags))
241
 
242
  # Filter events
243
  filtered_indices = self._current_partition.filter_events(
244
- required_tags=all_include if all_include else None,
245
  excluded_tags=exclude_tags,
246
  tag_to_id=metadata.tag_to_id
247
  )
248
 
249
  if len(filtered_indices) == 0:
250
  return QuickSearchResult(
251
- success=False,
252
  error_message="No events match the filter criteria",
253
  event_count=0
254
  )
255
 
256
- # Select random event
257
- random_idx = random.choice(filtered_indices)
258
-
259
- # Get tags for this event
260
  event_tags = self._current_partition.get_event_tags(
261
- event_idx=int(random_idx),
262
  id_to_tag=metadata.id_to_tag
263
  )
264
 
265
  if not event_tags:
266
- return QuickSearchResult(
267
- success=False,
268
- error_message="Failed to get tags for selected event"
269
- )
270
 
271
- # Convert to list and create prompt
272
  tags_list = sorted(list(event_tags))
273
  prompt = ", ".join(tags_list)
274
 
@@ -286,13 +258,6 @@ class TagStore:
286
  ) -> List[TagInfo]:
287
  """
288
  Search for tags matching query string.
289
-
290
- Args:
291
- query: Search query (partial match)
292
- limit: Maximum results
293
-
294
- Returns:
295
- List of matching tags with frequencies
296
  """
297
  metadata = self._manager.get_metadata()
298
  if metadata is None or not query:
 
35
  self.tags = []
36
 
37
 
38
+
39
  class TagStore:
40
  """
41
  High-level interface for Quick Search functionality.
42
 
43
  Provides:
44
+ - Partition loading and management (supports multiple active partitions)
45
  - Tag filtering by rating and person category
46
  - Include/Exclude tag filtering
47
  - Random event sampling for prompt generation
 
67
  def load_partition(self, rating: str, person: str) -> bool:
68
  """
69
  Load partition for given rating and person category.
70
+
71
  Args:
72
  rating: 'g', 's', 'q', or 'e'
73
  person: Person category like '1girl_solo'
74
+
75
  Returns:
76
  True if loaded successfully
77
  """
78
  partition_name = self._manager.get_partition_filename(rating, person)
79
+
80
+ # If already loaded, do nothing
81
+ if self._current_partition_name == partition_name and self._current_partition is not None:
82
+ return True
83
 
84
+ # Load new partition
 
 
 
 
 
 
85
  partition = self._manager.load_partition(partition_name)
86
+ if partition:
87
+ # Unload previous if different
88
+ if self._current_partition_name and self._current_partition_name != partition_name:
89
+ self._manager.unload_partition(self._current_partition_name)
90
+
91
+ self._current_partition = partition
92
+ self._current_partition_name = partition_name
93
+ return True
94
+
95
+ return False
96
 
97
  def get_event_count(self) -> int:
98
  """Get number of events in current partition"""
99
+ if self._current_partition:
100
+ return self._current_partition.num_events
101
+ return 0
102
 
103
  def get_filtered_event_count(
104
  self,
 
106
  exclude_tags: Optional[List[str]] = None
107
  ) -> int:
108
  """Get number of events matching filter criteria"""
109
+ if not self._current_partition:
110
  return 0
111
 
112
  metadata = self._manager.get_metadata()
113
  if metadata is None:
114
  return 0
115
 
116
+ filtered_indices = self._current_partition.filter_events(
117
  required_tags=include_tags,
118
  excluded_tags=exclude_tags,
119
  tag_to_id=metadata.tag_to_id
120
  )
121
+ return len(filtered_indices)
 
122
 
123
  def get_top_tags(
124
  self,
 
129
  ) -> List[TagInfo]:
130
  """
131
  Get most frequent tags in current partition with filters applied.
 
 
132
  """
133
+ if not self._current_partition:
134
  return []
135
 
136
  metadata = self._manager.get_metadata()
137
  if metadata is None:
138
  return []
139
 
 
140
  filtered_indices = self._current_partition.filter_events(
141
  required_tags=include_tags,
142
  excluded_tags=exclude_tags,
 
146
  if len(filtered_indices) == 0:
147
  return []
148
 
149
+ # Get counts for these events
150
  tag_counts = self._current_partition.get_tag_counts(
151
  event_indices=filtered_indices,
152
  id_to_tag=metadata.id_to_tag
153
  )
154
 
155
+ # Remove excluded tags and sort
156
  excluded_set = set(include_tags or []) | set(exclude_tags or [])
157
+
 
158
  sorted_tags = sorted(
159
  [(tag, count) for tag, count in tag_counts.items() if tag not in excluded_set],
160
  key=lambda x: x[1],
161
  reverse=True
162
+ )
163
 
164
+ # Pagination
165
+ paged_tags = sorted_tags[offset:offset + limit]
166
+
167
+ return [TagInfo(tag=tag, count=count) for tag, count in paged_tags]
168
 
169
  def get_total_tag_count(
170
  self,
 
172
  exclude_tags: Optional[List[str]] = None
173
  ) -> int:
174
  """Get total number of unique tags matching filter criteria"""
175
+ if not self._current_partition:
176
  return 0
177
 
178
  metadata = self._manager.get_metadata()
179
  if metadata is None:
180
  return 0
181
 
 
182
  filtered_indices = self._current_partition.filter_events(
183
  required_tags=include_tags,
184
  excluded_tags=exclude_tags,
 
188
  if len(filtered_indices) == 0:
189
  return 0
190
 
 
191
  tag_counts = self._current_partition.get_tag_counts(
192
  event_indices=filtered_indices,
193
  id_to_tag=metadata.id_to_tag
194
  )
195
+
 
196
  excluded_set = set(include_tags or []) | set(exclude_tags or [])
197
+ return len([t for t in tag_counts if t not in excluded_set])
198
 
199
  def generate_random_prompt(
200
  self,
 
204
  exclude_tags: Optional[List[str]] = None
205
  ) -> QuickSearchResult:
206
  """
207
+ Generate a random prompt from the dataset based on criteria.
 
 
 
 
 
 
 
 
 
208
  """
209
+ # Ensure correct partition is loaded
210
  if not self.load_partition(rating, person):
211
+ return QuickSearchResult(success=False, error_message=f"Failed to load partition for {rating}/{person}")
 
 
 
212
 
213
  metadata = self._manager.get_metadata()
214
  if metadata is None:
215
+ return QuickSearchResult(success=False, error_message="Metadata not available")
 
 
 
 
 
 
 
 
 
216
 
217
  # Filter events
218
  filtered_indices = self._current_partition.filter_events(
219
+ required_tags=include_tags,
220
  excluded_tags=exclude_tags,
221
  tag_to_id=metadata.tag_to_id
222
  )
223
 
224
  if len(filtered_indices) == 0:
225
  return QuickSearchResult(
226
+ success=False,
227
  error_message="No events match the filter criteria",
228
  event_count=0
229
  )
230
 
231
+ # Pick random event
232
+ choice_idx = random.choice(filtered_indices)
233
+
234
+ # Get tags for event
235
  event_tags = self._current_partition.get_event_tags(
236
+ event_idx=int(choice_idx),
237
  id_to_tag=metadata.id_to_tag
238
  )
239
 
240
  if not event_tags:
241
+ return QuickSearchResult(success=False, error_message="Failed to retrieve tags for selected event")
 
 
 
242
 
243
+ # Format prompt
244
  tags_list = sorted(list(event_tags))
245
  prompt = ", ".join(tags_list)
246
 
 
258
  ) -> List[TagInfo]:
259
  """
260
  Search for tags matching query string.
 
 
 
 
 
 
 
261
  """
262
  metadata = self._manager.get_metadata()
263
  if metadata is None or not query: