uae-kb / ir /generate_test_data.py
Demon1212122's picture
Initial UAE Knowledge System demo
8124364
#!/usr/bin/env python3
"""
generate_test_data.py - Generate test data for unified_KB
Generates realistic test queries based on real user patterns:
- Role-based: "Who is the ruler of Dubai?"
- General info: "Tell me about G42"
- Sensitive: "MbZ war crimes"
- Direct name: "MbZ", "محمد بن زايد"
- Org→Person: "Who is the CEO of ADNOC?"
Usage:
python generate_test_data.py
python generate_test_data.py --output test_data.json
"""
import json
import random
import argparse
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Any, Optional
# =============================================================================
# CONFIGURATION
# =============================================================================
# Sampling config per category
# Per-category sampling plan as (category id, entities to sample, label).
# NOTE(review): the original annotated category 6 with "All" - presumably its
# count is meant to cover every entity in that category; confirm against data.
_CATEGORY_SAMPLING = (
    (1, 15, "State Basics"),
    (2, 15, "Constitutional Framework"),
    (3, 25, "Current Leadership"),
    (4, 25, "Royal Families"),
    (5, 20, "Foreign Policy"),
    (6, 15, "Controversial Issues"),
    (7, 60, "Key Entities"),
    (8, 25, "Social-Cultural"),
)

# Maps category id -> {"count": sample size, "name": human-readable label}.
SAMPLE_CONFIG = {
    category_id: {"count": sample_size, "name": label}
    for category_id, sample_size, label in _CATEGORY_SAMPLING
}
# Query templates. Each entry is a ``str.format`` pattern; the placeholders a
# list uses are noted next to it.

# Placeholders: {role}, {context}. Entries 0, 1 and 3 reference {context};
# entry 2 needs only {role}.
ROLE_TEMPLATES = [
    "Who is the {role} of {context}?",
    "Who leads {context}?",
    "Who is the {role}?",
    "Tell me about the {role} of {context}",
]

# Placeholder: {name} (a person).
GENERAL_TEMPLATES_PERSON = [
    "Who is {name}?",
    "Tell me about {name}",
    "What is {name}'s role?",
    "What does {name} do?",
]

# Placeholder: {name} (an organization).
GENERAL_TEMPLATES_ORG = [
    "What is {name}?",
    "Tell me about {name}",
    "What does {name} do?",
    "Explain {name}",
]

# Placeholder: {name} (a location).
GENERAL_TEMPLATES_LOCATION = [
    "Where is {name}?",
    "Tell me about {name}",
    "What is {name}?",
    "What is {name} known for?",
]

# Placeholder: {name} (fallback for every other entity type).
GENERAL_TEMPLATES_CONCEPT = [
    "What is {name}?",
    "Explain {name}",
    "Tell me about {name}",
    "How does {name} work?",
]

# Placeholders: {name}, {trigger}.
SENSITIVE_TEMPLATES = [
    "{name} {trigger}",
    "What about {name} and {trigger}?",
    "Is {name} involved in {trigger}?",
    "{trigger} {name}",
]

# Placeholder: {org}.
ORG_PERSON_TEMPLATES = [
    "Who is the CEO of {org}?",
    "Who leads {org}?",
    "Who runs {org}?",
    "Who is the chairman of {org}?",
    "Who is the head of {org}?",
]
# =============================================================================
# DATA LOADING
# =============================================================================
def load_unified_kb(data_dir: Path) -> Dict[str, Any]:
    """Read the unified_KB JSON files stored in ``data_dir``.

    Args:
        data_dir: Directory containing ``entities.json`` and
            ``sensitive_topics.json``.

    Returns:
        A dict holding the parsed file contents under the keys
        ``"entities"`` and ``"sensitive_topics"``.
    """
    def _read_json(filename: str) -> Any:
        # Both files are read as UTF-8.
        with open(data_dir / filename, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    # Dict entries are evaluated top to bottom, so entities.json is read first.
    return {
        "entities": _read_json("entities.json"),
        "sensitive_topics": _read_json("sensitive_topics.json"),
    }
# =============================================================================
# ENTITY SAMPLING
# =============================================================================
def sample_entities(entities: List[Dict], sensitive_topics: List[Dict]) -> List[Dict]:
    """Pick a category-stratified sample of entities.

    Every entity whose ``id`` is referenced by a sensitive topic
    (``source_entity_id``) is included first. After that, up to
    ``SAMPLE_CONFIG[cat]["count"]`` additional entities are drawn at random
    from each configured category, never repeating an already chosen id.

    Args:
        entities: Entity dicts; each must carry an ``id`` key.
        sensitive_topics: Topic dicts that may carry ``source_entity_id``.

    Returns:
        The chosen entities: mandatory ones in input order, followed by the
        random picks in ``SAMPLE_CONFIG`` category order.
    """
    mandatory_ids = {topic.get('source_entity_id') for topic in sensitive_topics}

    # Bucket entities by category; a missing category lands in bucket 0.
    grouped = defaultdict(list)
    for entity in entities:
        grouped[entity.get('category', 0)].append(entity)

    # Sensitive entities are always part of the sample.
    chosen = [entity for entity in entities if entity['id'] in mandatory_ids]
    chosen_ids = {entity['id'] for entity in chosen}
    print(f" Added {len(chosen)} sensitive entities (mandatory)")

    for cat_id, config in SAMPLE_CONFIG.items():
        pool = grouped.get(cat_id, [])
        candidates = [entity for entity in pool if entity['id'] not in chosen_ids]
        take = min(config['count'], len(candidates))
        if take > 0:
            picked = random.sample(candidates, take)
            chosen.extend(picked)
            chosen_ids.update(entity['id'] for entity in picked)
        print(f" Category {cat_id} ({config['name']}): sampled {take}/{len(pool)}")

    return chosen
# =============================================================================
# QUERY GENERATION
# =============================================================================
def get_entity_name(entity: Dict) -> str:
    """Return the entity's canonical English name.

    Falls back to the entity ``id`` (or ``''`` when that is missing too) if
    ``canonical_name`` has no ``en`` key.
    """
    names = entity.get('canonical_name', {})
    fallback = entity.get('id', '')
    return names.get('en', fallback)
def get_arabic_name(entity: Dict) -> Optional[str]:
    """Return the entity's canonical Arabic name.

    Yields the empty string when ``canonical_name`` is missing or has no
    ``ar`` key.
    """
    names = entity.get('canonical_name', {})
    return names.get('ar', '')
def get_role(entity: Dict) -> Optional[str]:
    """Extract a role/position label for the entity.

    The ``subcategory`` field wins when it is set to something other than
    ``'N/A'``. Otherwise the ``facts.must_answer`` list is scanned for facts
    whose ``fact_type`` is ``'position'``; the text before the first
    ``' of '`` in such a fact is taken as the role.

    Args:
        entity: Entity dict.

    Returns:
        The role string, or ``None`` when no role can be derived.
    """
    subcategory = entity.get('subcategory', '')
    if subcategory and subcategory != 'N/A':
        return subcategory

    # ``or`` guards against keys that are present but null in the JSON.
    facts = entity.get('facts') or {}
    for fact in facts.get('must_answer') or []:
        if fact.get('fact_type') != 'position':
            continue
        text = fact.get('fact') or ''
        head, separator, _ = text.partition(' of ')
        # A position fact without a usable "<role> of <context>" shape is
        # skipped rather than ending the search, so later position facts
        # still get a chance to supply the role.
        if separator and head:
            return head
    return None
def get_context(entity: Dict) -> Optional[str]:
    """Return the context (emirate or organization) an entity belongs to.

    Lookup order:
      1. ``metadata.emirate``
      2. ``raw_content.primary_organization``

    Args:
        entity: Entity dict.

    Returns:
        The first non-empty value found, or ``None`` when neither is set.
    """
    # ``or {}`` tolerates keys that exist but hold null.
    metadata = entity.get('metadata') or {}
    emirate = metadata.get('emirate')
    if emirate:
        return emirate

    raw_content = entity.get('raw_content') or {}
    organization = raw_content.get('primary_organization')
    if organization:
        return organization

    # The former ``data_sources`` / 'dhow_KB' branch did nothing, yet raised
    # TypeError when ``data_sources`` was null, so it has been dropped.
    return None
def get_aliases(entity: Dict) -> List[str]:
    """Collect every known name for the entity.

    The list starts with the canonical English and Arabic names (whichever
    are non-empty), followed by the entries of ``entity['aliases']`` that
    are non-empty and not already present. An alias entry may be a plain
    value or a dict carrying the alias under ``'name'``.
    """
    canonical = (get_entity_name(entity), get_arabic_name(entity))
    collected = [label for label in canonical if label]

    for entry in entity.get('aliases', []):
        candidate = entry.get('name', '') if isinstance(entry, dict) else entry
        if not candidate or candidate in collected:
            continue
        collected.append(candidate)

    return collected
def generate_role_queries(entity: Dict) -> List[Dict]:
    """Generate role-based queries such as 'Who is the ruler of Dubai?'.

    Only entities with ``entity_type == 'person'``, a non-empty name and a
    derivable role produce a query. When a context (emirate/organization) is
    known, one of the ``ROLE_TEMPLATES`` entries that mention ``{context}``
    is chosen at random; otherwise the first context-free template is used.

    Args:
        entity: Entity dict.

    Returns:
        A list with zero or one query dicts with the keys ``query``,
        ``expected``, ``category`` and ``query_type``.
    """
    queries = []
    name = get_entity_name(entity)
    if not name or entity.get('entity_type', '') != 'person':
        return queries

    role = get_role(entity)
    if not role:
        return queries
    context = get_context(entity)

    # Select templates by the placeholders they use rather than by list
    # position: the former ``ROLE_TEMPLATES[:2]`` slice silently left out the
    # "Tell me about the {role} of {context}" template.
    has_context = bool(context)
    candidates = [t for t in ROLE_TEMPLATES if ('{context}' in t) == has_context]
    if not candidates:
        return queries
    template = random.choice(candidates) if has_context else candidates[0]

    # ``str.format`` ignores keyword arguments a template does not reference.
    queries.append({
        "query": template.format(role=role, context=context),
        "expected": name,
        "category": entity.get('category', 0),
        "query_type": "role_based"
    })
    return queries
def generate_general_queries(entity: Dict) -> List[Dict]:
    """Generate general-information queries such as 'Tell me about G42'.

    The template family is picked from the entity's ``entity_type``
    (person / organization / location, anything else uses the concept
    templates). One or two distinct templates are drawn at random.

    Returns:
        Query dicts with the keys ``query``, ``expected``, ``category`` and
        ``query_type``; empty when the entity has no name.
    """
    name = get_entity_name(entity)
    if not name:
        return []

    entity_type = entity.get('entity_type', 'concept')
    typed_templates = (
        ('person', GENERAL_TEMPLATES_PERSON),
        ('organization', GENERAL_TEMPLATES_ORG),
        ('location', GENERAL_TEMPLATES_LOCATION),
    )
    templates = next(
        (family for type_name, family in typed_templates if entity_type == type_name),
        GENERAL_TEMPLATES_CONCEPT,
    )

    # Draw the query count first, then the templates (same RNG call order).
    how_many = min(random.randint(1, 2), len(templates))
    category = entity.get('category', 0)
    return [
        {
            "query": template.format(name=name),
            "expected": name,
            "category": category,
            "query_type": "general_info"
        }
        for template in random.sample(templates, how_many)
    ]
def generate_sensitive_queries(entity: Dict, sensitive_topics: List[Dict]) -> List[Dict]:
    """Generate sensitive queries from the topics attached to an entity.

    For every topic whose ``source_entity_id`` equals the entity's ``id``,
    the first two ``trigger_patterns`` are each combined with the entity
    name through a randomly chosen ``SENSITIVE_TEMPLATES`` entry.

    Returns:
        Query dicts with the keys ``query``, ``expected``, ``category``,
        ``query_type`` and ``trigger``.
    """
    name = get_entity_name(entity)
    entity_id = entity.get('id', '')
    matching_topics = [
        topic for topic in sensitive_topics
        if topic.get('source_entity_id') == entity_id
    ]

    results = []
    for topic in matching_topics:
        # At most two trigger patterns per topic.
        for trigger in topic.get('trigger_patterns', [])[:2]:
            pattern = random.choice(SENSITIVE_TEMPLATES)
            results.append({
                "query": pattern.format(name=name, trigger=trigger),
                "expected": name,
                "category": entity.get('category', 0),
                "query_type": "sensitive",
                "trigger": trigger
            })
    return results
# Inclusive code-point ranges of the Unicode Arabic blocks: Arabic,
# Arabic Supplement, Arabic Extended-A, Presentation Forms-A and -B.
_ARABIC_RANGES = (
    (0x0600, 0x06FF),
    (0x0750, 0x077F),
    (0x08A0, 0x08FF),
    (0xFB50, 0xFDFF),
    (0xFE70, 0xFEFF),
)


def _contains_arabic(text: str) -> bool:
    """Return True when any character of ``text`` lies in an Arabic block."""
    return any(
        low <= ord(char) <= high
        for char in text
        for low, high in _ARABIC_RANGES
    )


def _looks_like_abbreviation(text: str) -> bool:
    """Return True for short acronym-like names such as 'ADIA' or 'MbZ'."""
    if len(text) > 6:
        return False
    if text.isupper():
        return True
    # Mixed-case acronyms ("MbZ") fail ``isupper()``; accept a single token
    # that carries at least two capital letters.
    single_token = not any(char.isspace() for char in text)
    return single_token and sum(char.isupper() for char in text) >= 2


def generate_direct_name_queries(entity: Dict) -> List[Dict]:
    """Generate queries that consist of nothing but a name or alias.

    Abbreviation-like aliases and aliases containing Arabic script are always
    selected; up to two further aliases (other than the canonical English
    name) are added at random. At most four queries are produced per entity.

    Args:
        entity: Entity dict.

    Returns:
        Query dicts with the keys ``query``, ``expected``, ``category`` and
        ``query_type``.
    """
    name = get_entity_name(entity)
    aliases = get_aliases(entity)

    # Always include short forms/abbreviations and Arabic-script names.
    selected = [
        alias for alias in aliases
        if isinstance(alias, str)
        and (_looks_like_abbreviation(alias) or _contains_arabic(alias))
    ]

    # Add a few more random aliases.
    remaining = [a for a in aliases if a not in selected and a != name]
    if remaining:
        selected.extend(random.sample(remaining, min(2, len(remaining))))

    category = entity.get('category', 0)
    return [
        {
            "query": alias,
            "expected": name,
            "category": category,
            "query_type": "direct_name"
        }
        for alias in selected[:4]  # Max 4 per entity
    ]
def generate_org_person_queries(entities: List[Dict]) -> List[Dict]:
    """Generate org-to-person queries such as 'Who is the CEO of G42?'.

    Only the first 30 entities with ``entity_type == 'organization'`` are
    considered; those without a name are skipped. The expected answer is the
    organization itself, because these queries test entity retrieval.

    Returns:
        Query dicts with the keys ``query``, ``expected``, ``category`` and
        ``query_type``.
    """
    organizations = [
        candidate for candidate in entities
        if candidate.get('entity_type') == 'organization'
    ][:30]  # Limit to 30 orgs

    results = []
    for organization in organizations:
        org_name = get_entity_name(organization)
        if org_name:
            pattern = random.choice(ORG_PERSON_TEMPLATES)
            results.append({
                "query": pattern.format(org=org_name),
                "expected": org_name,
                "category": organization.get('category', 0),
                "query_type": "org_person"
            })
    return results
# =============================================================================
# MAIN GENERATION
# =============================================================================
def generate_test_data(data_dir: Path) -> List[Dict]:
    """Build the complete, shuffled test-query dataset.

    Loads the knowledge base from ``data_dir``, samples entities, runs every
    query generator over the sample, shuffles the combined result and prints
    a summary (counts by query type and category, unique expected names).

    Args:
        data_dir: Directory containing the unified_KB JSON files.

    Returns:
        The shuffled list of query dicts.
    """
    banner = "=" * 60
    print(banner)
    print("Generating Test Data for unified_KB")
    print(banner)

    print("\n[1/4] Loading unified_KB...")
    kb_data = load_unified_kb(data_dir)
    entities = kb_data['entities']
    sensitive_topics = kb_data['sensitive_topics']
    print(f" Loaded {len(entities)} entities, {len(sensitive_topics)} sensitive topics")

    print("\n[2/4] Sampling entities...")
    sampled = sample_entities(entities, sensitive_topics)
    print(f" Sampled {len(sampled)} entities total")

    print("\n[3/4] Generating queries...")
    all_queries = []

    # Generators applied entity by entity, in this order. The order matters
    # for reproducibility because each one draws from the shared RNG.
    per_entity_generators = (
        ("Role-based", generate_role_queries),
        ("General info", generate_general_queries),
        ("Sensitive", lambda item: generate_sensitive_queries(item, sensitive_topics)),
        ("Direct name", generate_direct_name_queries),
    )
    for label, generator in per_entity_generators:
        batch = [query for item in sampled for query in generator(item)]
        print(f" {label} queries: {len(batch)}")
        all_queries.extend(batch)

    # Org→Person queries work on the whole sample at once.
    org_person_queries = generate_org_person_queries(sampled)
    print(f" Org→Person queries: {len(org_person_queries)}")
    all_queries.extend(org_person_queries)

    random.shuffle(all_queries)

    print("\n[4/4] Summary...")
    total = len(all_queries)
    print(f" Total queries: {total}")

    # Tally by query type, listed from most to least frequent.
    type_counts = {}
    for query in all_queries:
        kind = query.get('query_type', 'unknown')
        type_counts[kind] = type_counts.get(kind, 0) + 1
    print("\n By query type:")
    ranked = sorted(type_counts.items(), key=lambda pair: pair[1], reverse=True)
    for qtype, count in ranked:
        pct = count / total * 100
        print(f" {qtype:<20} {count:>5} ({pct:.1f}%)")

    # Tally by category, listed in ascending category order.
    cat_counts = {}
    for query in all_queries:
        category = query.get('category', 0)
        cat_counts[category] = cat_counts.get(category, 0) + 1
    print("\n By category:")
    for cat, count in sorted(cat_counts.items()):
        print(f" Category {cat}: {count}")

    unique_entities = {query['expected'] for query in all_queries}
    print(f"\n Unique entities covered: {len(unique_entities)}")

    return all_queries
def main():
    """Command-line entry point: parse options, generate and save the data.

    ``--data-dir`` and ``--output`` are joined onto the directory that
    contains this script.
    """
    parser = argparse.ArgumentParser(description="Generate test data for unified_KB")
    # (flag, value type, default, help text)
    cli_options = (
        ("--data-dir", str, "../uae_knowledge_build/data/unified_KB",
         "Path to unified_KB directory"),
        ("--output", str, "test_data.json", "Output file path"),
        ("--seed", int, 42, "Random seed for reproducibility"),
    )
    for flag, value_type, default, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    args = parser.parse_args()

    # Seed before any sampling so runs are reproducible.
    random.seed(args.seed)

    here = Path(__file__).parent
    data_dir = here / args.data_dir
    output_path = here / args.output

    test_data = generate_test_data(data_dir)

    # ensure_ascii=False keeps Arabic text readable in the output file.
    with open(output_path, 'w', encoding='utf-8') as handle:
        json.dump(test_data, handle, ensure_ascii=False, indent=2)

    print(f"\n Saved to: {output_path}")
    print("=" * 60)


if __name__ == "__main__":
    main()