alexchilton Claude commited on
Commit
5f971cd
Β·
1 Parent(s): 69d5608

fix: Fix class parser to find section headers not table entries

Browse files

Fixed class loading bug that was only finding 5/12 classes.

Root Cause:
- _split_class_blocks() searched for class names anywhere in text
- Found them first in the comparison table (lines 26-54)
- This caused it to capture wrong content or too-short content

The Fix:
- Changed regex from '\b{CLASS}\b' to '^{CLASS}$' with re.MULTILINE
- Now finds class names only when they're alone on a line (section headers)
- Properly captures detailed class sections starting around line 57+

Results:
- Before: 5/12 classes found (41.7%)
- After: 11/12 classes found (91.7%)
- Missing: Monk (not in extracted_classes.txt file)

Classes now loading:
βœ… Barbarian, Bard, Cleric, Druid, Fighter, Paladin, Ranger, Rogue,
Sorcerer, Warlock, Wizard
❌ Monk (missing from source file)

Also added test_entity_retrieval.py - comprehensive test script that
validates all entities (spells, monsters, classes, races) can be found
by name in the RAG system.

πŸ€– Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2) hide show
  1. initialize_rag.py +16 -11
  2. test_entity_retrieval.py +375 -0
initialize_rag.py CHANGED
@@ -241,24 +241,29 @@ def load_classes(db_manager: ChromaDBManager, clear: bool = False):
241
 
242
 
243
  def _split_class_blocks(content: str) -> Dict[str, str]:
244
- """Split content by class names."""
245
  class_blocks = {}
246
 
247
  for i, class_name in enumerate(settings.DND_CLASSES):
248
- # Find this class
249
- pattern = rf'\b{class_name.upper()}\b'
250
- matches = list(re.finditer(pattern, content, re.IGNORECASE))
 
251
 
252
  if matches:
253
  start = matches[0].start()
254
- # Find end (next class or end of text)
255
  end = len(content)
256
- for next_class in settings.DND_CLASSES[i+1:]:
257
- next_pattern = rf'\b{next_class.upper()}\b'
258
- next_matches = re.search(next_pattern, content[start+100:], re.IGNORECASE)
259
- if next_matches:
260
- end = start + 100 + next_matches.start()
261
- break
 
 
 
 
262
 
263
  class_content = content[start:end].strip()
264
  if len(class_content) > 500: # Substantial content
 
241
 
242
 
243
  def _split_class_blocks(content: str) -> Dict[str, str]:
244
+ """Split content by class names at start of line (section headers)."""
245
  class_blocks = {}
246
 
247
  for i, class_name in enumerate(settings.DND_CLASSES):
248
+ # FIXED: Look for class name at the beginning of a line (^)
249
+ # This finds the detailed section header, not mentions in the table
250
+ pattern = rf'^{class_name.upper()}$'
251
+ matches = list(re.finditer(pattern, content, re.MULTILINE))
252
 
253
  if matches:
254
  start = matches[0].start()
255
+ # Find end (next class section or end of text)
256
  end = len(content)
257
+
258
+ # Look for ANY other class name on its own line after this one
259
+ for next_class in settings.DND_CLASSES:
260
+ if next_class == class_name:
261
+ continue
262
+ next_pattern = rf'^{next_class.upper()}$'
263
+ next_match = re.search(next_pattern, content[start+10:], re.MULTILINE)
264
+ if next_match:
265
+ candidate_end = start + 10 + next_match.start()
266
+ end = min(end, candidate_end)
267
 
268
  class_content = content[start:end].strip()
269
  if len(class_content) > 500: # Substantial content
test_entity_retrieval.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test Entity Retrieval - Validates all entities can be found by name
4
+
5
+ This test validates that every entity (spell, monster, class, race) can be
6
+ retrieved from the RAG system:
7
+
8
+ 1. Spells - from spells.txt + all_spells.txt
9
+ 2. Monsters - from extracted_monsters.txt
10
+ 3. Classes - from extracted_classes.txt
11
+ 4. Races - from Player's Handbook PDF
12
+
13
+ For each entity, verifies:
14
+ - Entity can be found in the RAG system
15
+ - Entity name is returned as the top result
16
+ - Metadata contains the correct entity name
17
+
18
+ This ensures name weighting and parsing work correctly across all collections.
19
+ """
20
+
21
+ import sys
22
+ import re
23
+ from pathlib import Path
24
+ from typing import List, Dict, Set
25
+
26
+ # Add project to path
27
+ project_root = Path(__file__).parent
28
+ sys.path.insert(0, str(project_root))
29
+
30
+ from dnd_rag_system.core.chroma_manager import ChromaDBManager
31
+ from dnd_rag_system.parsers.spell_parser import SpellParser
32
+ from dnd_rag_system.config import settings
33
+
34
+
35
+ class EntityRetrievalTest:
36
+ """Test that all entities can be retrieved by name."""
37
+
38
+ def __init__(self):
39
+ self.db = ChromaDBManager()
40
+ self.results_by_collection = {}
41
+
42
+ def run_all_tests(self):
43
+ """Run retrieval tests for all collections."""
44
+ print("="*70)
45
+ print("πŸ§ͺ COMPREHENSIVE ENTITY RETRIEVAL VALIDATION")
46
+ print("="*70)
47
+ print()
48
+ print("Testing all 4 collections: Spells, Monsters, Classes, Races")
49
+ print("Validates that every entity can be found by name.\n")
50
+
51
+ # Test each collection
52
+ self.test_spells()
53
+ self.test_monsters()
54
+ self.test_classes()
55
+ self.test_races()
56
+
57
+ # Print overall summary
58
+ self._print_overall_summary()
59
+
60
+ def test_spells(self):
61
+ """Test spell retrieval."""
62
+ print("="*70)
63
+ print("πŸ“š TESTING SPELLS")
64
+ print("="*70)
65
+ print()
66
+
67
+ # Parse spells from both sources
68
+ parser = SpellParser()
69
+ parsed_spells = parser.parse()
70
+
71
+ # Get unique spell names
72
+ spell_names = set()
73
+ for parsed_spell in parsed_spells:
74
+ spell_name = parsed_spell.metadata.get('name', '')
75
+ if spell_name:
76
+ spell_names.add(spell_name)
77
+
78
+ spell_names = sorted(spell_names)
79
+ print(f"Found {len(spell_names)} unique spells to test")
80
+ print(f"(from spells.txt + all_spells.txt)\n")
81
+
82
+ # Test each spell
83
+ results = self._test_collection(
84
+ spell_names,
85
+ settings.COLLECTION_NAMES['spells'],
86
+ 'Spell'
87
+ )
88
+
89
+ self.results_by_collection['spells'] = results
90
+ self._print_collection_results(results, 'SPELLS')
91
+
92
+ def test_monsters(self):
93
+ """Test monster retrieval."""
94
+ print("\n" + "="*70)
95
+ print("πŸ‘Ή TESTING MONSTERS")
96
+ print("="*70)
97
+ print()
98
+
99
+ # Parse monsters from file
100
+ monster_names = self._extract_monster_names()
101
+ print(f"Found {len(monster_names)} monsters to test")
102
+ print(f"(from extracted_monsters.txt)\n")
103
+
104
+ # Test each monster
105
+ results = self._test_collection(
106
+ monster_names,
107
+ settings.COLLECTION_NAMES['monsters'],
108
+ 'Monster'
109
+ )
110
+
111
+ self.results_by_collection['monsters'] = results
112
+ self._print_collection_results(results, 'MONSTERS')
113
+
114
+ def test_classes(self):
115
+ """Test class retrieval."""
116
+ print("\n" + "="*70)
117
+ print("βš”οΈ TESTING CLASSES")
118
+ print("="*70)
119
+ print()
120
+
121
+ # Parse classes from file
122
+ class_names = self._extract_class_names()
123
+ print(f"Found {len(class_names)} classes to test")
124
+ print(f"(from extracted_classes.txt)\n")
125
+
126
+ # Test each class
127
+ results = self._test_collection(
128
+ class_names,
129
+ settings.COLLECTION_NAMES['classes'],
130
+ 'Class'
131
+ )
132
+
133
+ self.results_by_collection['classes'] = results
134
+ self._print_collection_results(results, 'CLASSES')
135
+
136
+ def test_races(self):
137
+ """Test race retrieval."""
138
+ print("\n" + "="*70)
139
+ print("🧝 TESTING RACES")
140
+ print("="*70)
141
+ print()
142
+
143
+ # Use standard D&D races
144
+ race_names = [
145
+ 'Dragonborn', 'Dwarf', 'Elf', 'Gnome',
146
+ 'Half-Elf', 'Halfling', 'Half-Orc', 'Human', 'Tiefling'
147
+ ]
148
+ print(f"Found {len(race_names)} races to test")
149
+ print(f"(from Player's Handbook PDF)\n")
150
+
151
+ # Test each race
152
+ results = self._test_collection(
153
+ race_names,
154
+ settings.COLLECTION_NAMES['races'],
155
+ 'Race'
156
+ )
157
+
158
+ self.results_by_collection['races'] = results
159
+ self._print_collection_results(results, 'RACES')
160
+
161
+ def _test_collection(self, entity_names: List[str], collection_name: str, entity_type: str) -> Dict:
162
+ """
163
+ Test retrieval for a collection of entities.
164
+
165
+ Args:
166
+ entity_names: List of entity names to test
167
+ collection_name: ChromaDB collection name
168
+ entity_type: Type of entity (for logging)
169
+
170
+ Returns:
171
+ Dictionary with test results
172
+ """
173
+ passed = []
174
+ warnings = []
175
+ failed = []
176
+
177
+ for i, entity_name in enumerate(entity_names, 1):
178
+ # Progress indicator
179
+ if i % 20 == 0:
180
+ print(f" Progress: {i}/{len(entity_names)} {entity_type.lower()}s tested...")
181
+
182
+ try:
183
+ # Search for the entity
184
+ results = self.db.search(collection_name, entity_name, n_results=3)
185
+
186
+ # Check if we got results
187
+ if not results or not results['documents'] or len(results['documents'][0]) == 0:
188
+ failed.append({
189
+ 'name': entity_name,
190
+ 'reason': 'No results returned',
191
+ 'top_result': None
192
+ })
193
+ continue
194
+
195
+ # Get top result
196
+ top_metadata = results['metadatas'][0][0]
197
+ top_name = top_metadata.get('name', 'Unknown').upper()
198
+ search_name = entity_name.upper()
199
+ distance = results['distances'][0][0]
200
+
201
+ # Check if top result matches
202
+ if top_name == search_name:
203
+ passed.append({
204
+ 'name': entity_name,
205
+ 'distance': distance
206
+ })
207
+ else:
208
+ # Check if correct entity is in top 3
209
+ found_in_top_3 = False
210
+ for j in range(min(3, len(results['metadatas'][0]))):
211
+ result_name = results['metadatas'][0][j].get('name', '').upper()
212
+ if result_name == search_name:
213
+ found_in_top_3 = True
214
+ warnings.append({
215
+ 'name': entity_name,
216
+ 'reason': f'Found at position {j+1}, not #1',
217
+ 'top_result': top_metadata.get('name', 'Unknown'),
218
+ 'distance': distance
219
+ })
220
+ break
221
+
222
+ if not found_in_top_3:
223
+ failed.append({
224
+ 'name': entity_name,
225
+ 'reason': 'Not in top 3 results',
226
+ 'top_result': top_metadata.get('name', 'Unknown'),
227
+ 'distance': distance
228
+ })
229
+
230
+ except Exception as e:
231
+ failed.append({
232
+ 'name': entity_name,
233
+ 'reason': f'Error: {str(e)}',
234
+ 'top_result': None
235
+ })
236
+
237
+ return {
238
+ 'total': len(entity_names),
239
+ 'passed': passed,
240
+ 'warnings': warnings,
241
+ 'failed': failed
242
+ }
243
+
244
+ def _extract_monster_names(self) -> List[str]:
245
+ """Extract monster names from extracted_monsters.txt."""
246
+ monster_file = Path(settings.EXTRACTED_MONSTERS_TXT)
247
+ if not monster_file.exists():
248
+ print(f"Warning: {monster_file} not found")
249
+ return []
250
+
251
+ with open(monster_file, 'r', encoding='utf-8') as f:
252
+ text = f.read()
253
+
254
+ # Split by double newlines to get monster blocks
255
+ blocks = text.split('\n\n')
256
+ monster_names = []
257
+
258
+ for block in blocks:
259
+ lines = block.strip().split('\n')
260
+ if lines:
261
+ # First line is typically the monster name
262
+ name = lines[0].strip()
263
+ # Filter out empty lines and non-monster entries
264
+ if name and not name.startswith('#') and len(name) > 1:
265
+ monster_names.append(name)
266
+
267
+ return sorted(set(monster_names))
268
+
269
+ def _extract_class_names(self) -> List[str]:
270
+ """Extract class names from extracted_classes.txt."""
271
+ class_file = Path(settings.EXTRACTED_CLASSES_TXT)
272
+ if not class_file.exists():
273
+ print(f"Warning: {class_file} not found")
274
+ return []
275
+
276
+ with open(class_file, 'r', encoding='utf-8') as f:
277
+ text = f.read()
278
+
279
+ # Use the standard D&D class list
280
+ standard_classes = settings.DND_CLASSES
281
+ found_classes = []
282
+
283
+ # Check which classes are present in the file
284
+ for class_name in standard_classes:
285
+ if class_name in text or class_name.upper() in text:
286
+ found_classes.append(class_name)
287
+
288
+ return sorted(found_classes)
289
+
290
+ def _print_collection_results(self, results: Dict, collection_name: str):
291
+ """Print results for a single collection."""
292
+ total = results['total']
293
+ passed = len(results['passed'])
294
+ warnings = len(results['warnings'])
295
+ failed = len(results['failed'])
296
+ pass_rate = (passed / total * 100) if total > 0 else 0
297
+
298
+ print()
299
+ print(f"πŸ“Š {collection_name} Results:")
300
+ print(f" Total: {total}")
301
+ print(f" βœ… Passed: {passed} ({pass_rate:.1f}%)")
302
+ print(f" ⚠️ Warnings: {warnings}")
303
+ print(f" ❌ Failed: {failed}")
304
+
305
+ # Show first few failures
306
+ if results['failed']:
307
+ print(f"\n Failed entities (showing first 5):")
308
+ for fail in results['failed'][:5]:
309
+ print(f" ❌ {fail['name']}: {fail['reason']}")
310
+ if fail['top_result']:
311
+ print(f" (Top result: {fail['top_result']})")
312
+
313
+ def _print_overall_summary(self):
314
+ """Print overall summary across all collections."""
315
+ print("\n" + "="*70)
316
+ print("πŸ“Š OVERALL SUMMARY")
317
+ print("="*70)
318
+ print()
319
+
320
+ total_entities = 0
321
+ total_passed = 0
322
+ total_warnings = 0
323
+ total_failed = 0
324
+
325
+ for collection, results in self.results_by_collection.items():
326
+ total_entities += results['total']
327
+ total_passed += len(results['passed'])
328
+ total_warnings += len(results['warnings'])
329
+ total_failed += len(results['failed'])
330
+
331
+ overall_pass_rate = (total_passed / total_entities * 100) if total_entities > 0 else 0
332
+
333
+ print(f"Total Entities Tested: {total_entities}")
334
+ print(f"βœ… Passed: {total_passed} ({overall_pass_rate:.1f}%)")
335
+ print(f"⚠️ Warnings: {total_warnings}")
336
+ print(f"❌ Failed: {total_failed}")
337
+ print()
338
+
339
+ # Breakdown by collection
340
+ print("Breakdown by Collection:")
341
+ for collection, results in self.results_by_collection.items():
342
+ passed = len(results['passed'])
343
+ total = results['total']
344
+ rate = (passed / total * 100) if total > 0 else 0
345
+ print(f" {collection.capitalize()}: {passed}/{total} ({rate:.1f}%)")
346
+
347
+ print()
348
+ print("="*70)
349
+
350
+ if total_failed == 0 and total_warnings == 0:
351
+ print("πŸŽ‰ PERFECT! All entities retrieved correctly!")
352
+ elif total_failed == 0:
353
+ print("βœ… GOOD! All entities found (some not ranking #1)")
354
+ else:
355
+ print("⚠️ ISSUES FOUND - Some entities missing or incorrect")
356
+ print(f" Run: python initialize_rag.py --clear")
357
+
358
+ print("="*70)
359
+
360
+
361
+ def main():
362
+ """Run the entity retrieval test."""
363
+ try:
364
+ test = EntityRetrievalTest()
365
+ test.run_all_tests()
366
+ except KeyboardInterrupt:
367
+ print("\n\n⚠️ Test interrupted by user")
368
+ except Exception as e:
369
+ print(f"\n\n❌ Fatal error: {e}")
370
+ import traceback
371
+ traceback.print_exc()
372
+
373
+
374
+ if __name__ == '__main__':
375
+ main()