#!/usr/bin/env python3
"""
generate_test_data.py - Generate test data for unified_KB

Generates realistic test queries based on real user patterns:
  - Role-based:   "Who is the ruler of Dubai?"
  - General info: "Tell me about G42"
  - Sensitive:    "MbZ war crimes"
  - Direct name:  "MbZ", "محمد بن زايد"
  - Org→Person:   "Who is the CEO of ADNOC?"

Usage:
    python generate_test_data.py
    python generate_test_data.py --output test_data.json
    python generate_test_data.py --data-dir path/to/unified_KB --seed 7
"""

import json
import random
import argparse
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Any, Optional


# =============================================================================
# CONFIGURATION
# =============================================================================

# Sampling config per category: how many extra entities to draw at random from
# each category id, on top of the sensitive entities that are always included.
SAMPLE_CONFIG = {
    1: {"count": 15, "name": "State Basics"},
    2: {"count": 15, "name": "Constitutional Framework"},
    3: {"count": 25, "name": "Current Leadership"},
    4: {"count": 25, "name": "Royal Families"},
    5: {"count": 20, "name": "Foreign Policy"},
    6: {"count": 15, "name": "Controversial Issues"},  # All
    7: {"count": 60, "name": "Key Entities"},
    8: {"count": 25, "name": "Social-Cultural"},
}

# Query templates. Placeholders are filled with str.format(); templates that
# mention {context} are only used when a context is known for the entity.
ROLE_TEMPLATES = [
    "Who is the {role} of {context}?",
    "Who leads {context}?",
    "Who is the {role}?",
    "Tell me about the {role} of {context}",
]

GENERAL_TEMPLATES_PERSON = [
    "Who is {name}?",
    "Tell me about {name}",
    "What is {name}'s role?",
    "What does {name} do?",
]

GENERAL_TEMPLATES_ORG = [
    "What is {name}?",
    "Tell me about {name}",
    "What does {name} do?",
    "Explain {name}",
]

GENERAL_TEMPLATES_LOCATION = [
    "Where is {name}?",
    "Tell me about {name}",
    "What is {name}?",
    "What is {name} known for?",
]

GENERAL_TEMPLATES_CONCEPT = [
    "What is {name}?",
    "Explain {name}",
    "Tell me about {name}",
    "How does {name} work?",
]

# entity_type -> general-info templates; anything unlisted uses the concept set.
_GENERAL_TEMPLATES_BY_TYPE = {
    'person': GENERAL_TEMPLATES_PERSON,
    'organization': GENERAL_TEMPLATES_ORG,
    'location': GENERAL_TEMPLATES_LOCATION,
}

SENSITIVE_TEMPLATES = [
    "{name} {trigger}",
    "What about {name} and {trigger}?",
    "Is {name} involved in {trigger}?",
    "{trigger} {name}",
]

ORG_PERSON_TEMPLATES = [
    "Who is the CEO of {org}?",
    "Who leads {org}?",
    "Who runs {org}?",
    "Who is the chairman of {org}?",
    "Who is the head of {org}?",
]


# =============================================================================
# DATA LOADING
# =============================================================================

def _read_json(path: Path) -> Any:
    """Parse a UTF-8 encoded JSON file and return its content."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_unified_kb(data_dir: Path) -> Dict[str, Any]:
    """Load the unified_KB entity and sensitive-topic files.

    Args:
        data_dir: Directory containing ``entities.json`` and
            ``sensitive_topics.json``.

    Returns:
        ``{"entities": ..., "sensitive_topics": ...}`` holding the parsed
        JSON content of each file (lists of dicts, per how callers use them).

    Raises:
        FileNotFoundError: If either file is missing.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    return {
        "entities": _read_json(data_dir / "entities.json"),
        "sensitive_topics": _read_json(data_dir / "sensitive_topics.json"),
    }


# =============================================================================
# ENTITY SAMPLING
# =============================================================================

def sample_entities(entities: List[Dict], sensitive_topics: List[Dict]) -> List[Dict]:
    """Select the entities to generate queries for, stratified by category.

    Every entity referenced by a sensitive topic (``source_entity_id``) is
    included first. Then, for each category in SAMPLE_CONFIG, up to
    ``count`` not-yet-selected entities of that category are drawn at random.
    Categories absent from SAMPLE_CONFIG contribute only sensitive entities.

    Progress is printed to stdout. Draws from the global ``random`` state, so
    the result is reproducible under a fixed ``random.seed``.

    Raises:
        KeyError: If an entity has no ``id`` key.
    """
    # Topics lacking a source entity would otherwise add None to the set.
    sensitive_ids = {s.get('source_entity_id') for s in sensitive_topics}
    sensitive_ids.discard(None)

    by_category = defaultdict(list)
    for e in entities:
        by_category[e.get('category', 0)].append(e)

    # Sensitive entities are mandatory and keep their KB order.
    sampled = [e for e in entities if e['id'] in sensitive_ids]
    sampled_ids = {e['id'] for e in sampled}
    print(f" Added {len(sampled)} sensitive entities (mandatory)")

    for cat_id, config in SAMPLE_CONFIG.items():
        cat_entities = by_category.get(cat_id, [])
        available = [e for e in cat_entities if e['id'] not in sampled_ids]
        sample_count = min(config['count'], len(available))
        if sample_count > 0:
            for e in random.sample(available, sample_count):
                sampled.append(e)
                sampled_ids.add(e['id'])
        print(f" Category {cat_id} ({config['name']}): sampled {sample_count}/{len(cat_entities)}")

    return sampled


# =============================================================================
# QUERY GENERATION
# =============================================================================

def _make_query(query: str, expected: str, entity: Dict, query_type: str, **extra: Any) -> Dict:
    """Build one test-case record; ``extra`` keys are appended after the core four."""
    record = {
        "query": query,
        "expected": expected,
        "category": entity.get('category', 0),
        "query_type": query_type,
    }
    record.update(extra)
    return record


def get_entity_name(entity: Dict) -> str:
    """Return the canonical English name, falling back to the entity id (or '')."""
    names = entity.get('canonical_name') or {}
    return names.get('en') or entity.get('id') or ''


def get_arabic_name(entity: Dict) -> Optional[str]:
    """Return the canonical Arabic name, or None when absent/empty."""
    names = entity.get('canonical_name') or {}
    return names.get('ar') or None


def get_role(entity: Dict) -> Optional[str]:
    """Extract a role/position for the entity, or None if none can be found.

    Prefers ``subcategory`` (unless empty or 'N/A'). Otherwise scans
    ``facts.must_answer`` for the first ``position`` fact of the form
    "<role> of <something>" and returns the "<role>" part.
    """
    subcategory = entity.get('subcategory', '')
    if subcategory and subcategory != 'N/A':
        return subcategory

    facts = entity.get('facts') or {}
    for fact in facts.get('must_answer') or []:
        if fact.get('fact_type') != 'position':
            continue
        text = fact.get('fact') or ''
        # A position fact without " of " carries no extractable role; keep
        # looking at later position facts instead of giving up.
        if ' of ' in text:
            return text.partition(' of ')[0]
    return None


def get_context(entity: Dict) -> Optional[str]:
    """Return the entity's context: its emirate, else its primary organization."""
    emirate = (entity.get('metadata') or {}).get('emirate')
    if emirate:
        return emirate

    org = (entity.get('raw_content') or {}).get('primary_organization')
    # TODO(review): entities sourced from dhow_KB may hold organization info
    # elsewhere; not used yet.
    return org or None


def get_aliases(entity: Dict) -> List[str]:
    """Return distinct non-empty names: English, Arabic, then listed aliases.

    ``aliases`` entries may be plain strings or dicts with a ``name`` key.
    """
    aliases: List[str] = []
    candidates = [get_entity_name(entity), get_arabic_name(entity)]
    for alias in entity.get('aliases') or []:
        candidates.append(alias.get('name', '') if isinstance(alias, dict) else alias)
    for name in candidates:
        if name and name not in aliases:
            aliases.append(name)
    return aliases


def generate_role_queries(entity: Dict) -> List[Dict]:
    """Generate role-based queries: 'Who is the ruler of Dubai?'

    Only person entities with a resolvable role produce a query. With a known
    context, one template mentioning ``{context}`` is chosen at random;
    otherwise the context-free "Who is the {role}?" form is used.
    """
    name = get_entity_name(entity)
    if not name or entity.get('entity_type', '') != 'person':
        return []

    role = get_role(entity)
    if not role:
        return []

    context = get_context(entity)
    if context:
        # Select by placeholder rather than a fixed slice so every
        # context-bearing template (including the last one) is eligible.
        template = random.choice([t for t in ROLE_TEMPLATES if '{context}' in t])
    else:
        template = next(t for t in ROLE_TEMPLATES if '{context}' not in t)

    query = template.format(role=role, context=context)
    return [_make_query(query, name, entity, "role_based")]


def generate_general_queries(entity: Dict) -> List[Dict]:
    """Generate 1-2 general info queries: 'Tell me about G42'.

    Templates are chosen by ``entity_type`` (person / organization /
    location, anything else treated as a concept).
    """
    name = get_entity_name(entity)
    if not name:
        return []

    entity_type = entity.get('entity_type', 'concept')
    templates = _GENERAL_TEMPLATES_BY_TYPE.get(entity_type, GENERAL_TEMPLATES_CONCEPT)

    num_queries = random.randint(1, 2)
    selected_templates = random.sample(templates, min(num_queries, len(templates)))
    return [
        _make_query(template.format(name=name), name, entity, "general_info")
        for template in selected_templates
    ]


def generate_sensitive_queries(
    entity: Dict,
    sensitive_topics: List[Dict],
    max_triggers_per_topic: int = 2,
) -> List[Dict]:
    """Generate sensitive queries by pairing the entity name with trigger phrases.

    For every topic whose ``source_entity_id`` matches the entity's id, the
    first ``max_triggers_per_topic`` ``trigger_patterns`` each yield one query
    from a randomly chosen SENSITIVE_TEMPLATES entry. The trigger is recorded
    in the output under ``"trigger"``.
    """
    name = get_entity_name(entity)
    if not name:
        return []
    entity_id = entity.get('id', '')

    queries = []
    for topic in sensitive_topics:
        if topic.get('source_entity_id') != entity_id:
            continue
        for trigger in (topic.get('trigger_patterns') or [])[:max_triggers_per_topic]:
            template = random.choice(SENSITIVE_TEMPLATES)
            query = template.format(name=name, trigger=trigger)
            queries.append(_make_query(query, name, entity, "sensitive", trigger=trigger))
    return queries


def _contains_arabic(text: str) -> bool:
    """Return True if any character lies in the Unicode Arabic block (U+0600-U+06FF)."""
    return any('\u0600' <= ch <= '\u06ff' for ch in text)


def _is_short_form(alias: str) -> bool:
    """Return True for short acronym-like aliases such as 'ADIA', 'G42' or 'MbZ'."""
    if len(alias) > 6:
        return False
    if alias.isupper():
        return True
    # Mixed-case initialisms (MbZ, MbR) fail isupper(); accept single tokens
    # with at least two capitals. The whitespace check keeps e.g. "Al Ain" out.
    return not any(ch.isspace() for ch in alias) and sum(ch.isupper() for ch in alias) >= 2


def generate_direct_name_queries(entity: Dict) -> List[Dict]:
    """Generate direct name/alias queries (the query is the bare alias).

    Short forms and Arabic names are always considered first, then up to two
    other random aliases (excluding the canonical English name). At most
    four queries are produced per entity.
    """
    name = get_entity_name(entity)
    if not name:
        return []
    aliases = get_aliases(entity)

    selected = [
        alias for alias in aliases
        if isinstance(alias, str) and (_is_short_form(alias) or _contains_arabic(alias))
    ]

    remaining = [a for a in aliases if a not in selected and a != name]
    if remaining:
        selected.extend(random.sample(remaining, min(2, len(remaining))))

    return [_make_query(alias, name, entity, "direct_name") for alias in selected[:4]]


def generate_org_person_queries(entities: List[Dict], limit: int = 30) -> List[Dict]:
    """Generate org→person queries: 'Who is the CEO of G42?'

    Uses the first ``limit`` organization entities in ``entities`` order.
    The expected answer is the organization itself, since these cases test
    entity retrieval rather than person resolution.
    """
    orgs = [e for e in entities if e.get('entity_type') == 'organization']

    queries = []
    for org in orgs[:limit]:
        org_name = get_entity_name(org)
        if not org_name:
            continue
        template = random.choice(ORG_PERSON_TEMPLATES)
        queries.append(_make_query(template.format(org=org_name), org_name, org, "org_person"))
    return queries
# =============================================================================
# MAIN GENERATION
# =============================================================================

def _run_per_entity(generator: Any, entities: List[Dict]) -> List[Dict]:
    """Apply a per-entity query generator to every entity and concatenate results."""
    queries: List[Dict] = []
    for entity in entities:
        queries.extend(generator(entity))
    return queries


def _print_summary(all_queries: List[Dict]) -> None:
    """Print total, per-query-type and per-category counts, and distinct answers."""
    print(f" Total queries: {len(all_queries)}")

    type_counts = defaultdict(int)
    cat_counts = defaultdict(int)
    for q in all_queries:
        type_counts[q.get('query_type', 'unknown')] += 1
        cat_counts[q.get('category', 0)] += 1

    print("\n By query type:")
    for qtype, count in sorted(type_counts.items(), key=lambda x: -x[1]):
        # type_counts is only non-empty when all_queries is, so no zero division.
        pct = count / len(all_queries) * 100
        print(f" {qtype:<20} {count:>5} ({pct:.1f}%)")

    print("\n By category:")
    for cat, count in sorted(cat_counts.items()):
        print(f" Category {cat}: {count}")

    # Counted by expected *name*, so distinct entities sharing a name collapse.
    unique_entities = {q['expected'] for q in all_queries}
    print(f"\n Unique entities covered: {len(unique_entities)}")


def generate_test_data(data_dir: Path) -> List[Dict]:
    """Generate the complete, shuffled test dataset from a unified_KB directory.

    Loads the KB, samples entities, runs every query generator over the
    sample, shuffles the combined result and prints a summary to stdout.
    Uses the global ``random`` state; seed it beforehand for reproducibility.

    Args:
        data_dir: Directory containing ``entities.json`` and
            ``sensitive_topics.json``.

    Returns:
        List of query records (``query``, ``expected``, ``category``,
        ``query_type`` and, for sensitive queries, ``trigger``).
    """
    print("=" * 60)
    print("Generating Test Data for unified_KB")
    print("=" * 60)

    print("\n[1/4] Loading unified_KB...")
    kb_data = load_unified_kb(data_dir)
    entities = kb_data['entities']
    sensitive_topics = kb_data['sensitive_topics']
    print(f" Loaded {len(entities)} entities, {len(sensitive_topics)} sensitive topics")

    print("\n[2/4] Sampling entities...")
    sampled = sample_entities(entities, sensitive_topics)
    print(f" Sampled {len(sampled)} entities total")

    print("\n[3/4] Generating queries...")
    # Stages run in this fixed order because each draws from the global RNG;
    # reordering them would change the dataset produced for a given seed.
    # The resulting mix of query types depends on the KB content; no target
    # percentages are enforced.
    stages = [
        ("Role-based queries",
         lambda: _run_per_entity(generate_role_queries, sampled)),
        ("General info queries",
         lambda: _run_per_entity(generate_general_queries, sampled)),
        ("Sensitive queries",
         lambda: _run_per_entity(lambda e: generate_sensitive_queries(e, sensitive_topics), sampled)),
        ("Direct name queries",
         lambda: _run_per_entity(generate_direct_name_queries, sampled)),
        ("Org→Person queries",
         lambda: generate_org_person_queries(sampled)),
    ]

    all_queries: List[Dict] = []
    for label, run_stage in stages:
        stage_queries = run_stage()
        print(f" {label}: {len(stage_queries)}")
        all_queries.extend(stage_queries)

    random.shuffle(all_queries)

    print("\n[4/4] Summary...")
    _print_summary(all_queries)

    return all_queries


def main(argv: Optional[List[str]] = None) -> None:
    """CLI entry point: generate the dataset and write it to a JSON file.

    Args:
        argv: Arguments to parse. When None, argparse reads the process
            command line.

    Exits with status 2 (argparse usage error) if the data directory does
    not exist.
    """
    parser = argparse.ArgumentParser(description="Generate test data for unified_KB")
    parser.add_argument(
        "--data-dir",
        type=str,
        default="../uae_knowledge_build/data/unified_KB",
        help="Path to unified_KB directory"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="test_data.json",
        help="Output file path"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility"
    )
    args = parser.parse_args(argv)

    random.seed(args.seed)

    # Relative paths are resolved against this script's directory, not the
    # current working directory; absolute paths are used as given.
    script_dir = Path(__file__).parent
    data_dir = script_dir / args.data_dir
    output_path = script_dir / args.output

    # Fail fast with a usage error rather than a traceback from deep in loading.
    if not data_dir.is_dir():
        parser.error(f"unified_KB directory not found: {data_dir}")

    test_data = generate_test_data(data_dir)

    # Allow --output to point into a directory that does not exist yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(test_data, f, ensure_ascii=False, indent=2)

    print(f"\n Saved to: {output_path}")
    print("=" * 60)


if __name__ == "__main__":
    main()