|
|
|
|
|
""" |
|
|
generate_test_data.py - Generate test data for unified_KB |
|
|
|
|
|
Generates realistic test queries based on real user patterns: |
|
|
- Role-based: "Who is the ruler of Dubai?" |
|
|
- General info: "Tell me about G42" |
|
|
- Sensitive: "MbZ war crimes" |
|
|
- Direct name: "MbZ", "محمد بن زايد" |
|
|
- Org→Person: "Who is the CEO of ADNOC?" |
|
|
|
|
|
Usage: |
|
|
python generate_test_data.py |
|
|
python generate_test_data.py --output test_data.json |
|
|
""" |
|
|
|
|
|
import json |
|
|
import random |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from collections import defaultdict |
|
|
from typing import List, Dict, Any, Optional |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sampling quota per entity category. Keys are category ids; "count" is the
# number of entities to draw from that category and "name" is the label shown
# in progress output.
SAMPLE_CONFIG = {
    1: {"count": 15, "name": "State Basics"},
    2: {"count": 15, "name": "Constitutional Framework"},
    3: {"count": 25, "name": "Current Leadership"},
    4: {"count": 25, "name": "Royal Families"},
    5: {"count": 20, "name": "Foreign Policy"},
    6: {"count": 15, "name": "Controversial Issues"},
    7: {"count": 60, "name": "Key Entities"},
    8: {"count": 25, "name": "Social-Cultural"},
}
|
|
|
|
|
|
|
|
# Question patterns for role-based queries. The first two accept both
# {role} and {context}; the third needs only {role}.
ROLE_TEMPLATES = [
    "Who is the {role} of {context}?",
    "Who leads {context}?",
    "Who is the {role}?",
    "Tell me about the {role} of {context}",
]
|
|
|
|
|
# General-information question patterns, one list per entity type.
# Every pattern takes a single {name} placeholder.
GENERAL_TEMPLATES_PERSON = [
    "Who is {name}?",
    "Tell me about {name}",
    "What is {name}'s role?",
    "What does {name} do?",
]

GENERAL_TEMPLATES_ORG = [
    "What is {name}?",
    "Tell me about {name}",
    "What does {name} do?",
    "Explain {name}",
]

GENERAL_TEMPLATES_LOCATION = [
    "Where is {name}?",
    "Tell me about {name}",
    "What is {name}?",
    "What is {name} known for?",
]

# Also serves as the fallback for entity types not listed above.
GENERAL_TEMPLATES_CONCEPT = [
    "What is {name}?",
    "Explain {name}",
    "Tell me about {name}",
    "How does {name} work?",
]
|
|
|
|
|
# Patterns that combine an entity {name} with a sensitive-topic {trigger}.
SENSITIVE_TEMPLATES = [
    "{name} {trigger}",
    "What about {name} and {trigger}?",
    "Is {name} involved in {trigger}?",
    "{trigger} {name}",
]
|
|
|
|
|
# Patterns asking who is in charge of an organization ({org}).
ORG_PERSON_TEMPLATES = [
    "Who is the CEO of {org}?",
    "Who leads {org}?",
    "Who runs {org}?",
    "Who is the chairman of {org}?",
    "Who is the head of {org}?",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_unified_kb(data_dir: Path) -> Dict[str, Any]:
    """Read the unified_KB JSON files stored under *data_dir*.

    Args:
        data_dir: Directory containing ``entities.json`` and
            ``sensitive_topics.json`` (both read as UTF-8).

    Returns:
        A dict with the parsed contents under the keys ``"entities"`` and
        ``"sensitive_topics"``.
    """
    def _read_json(filename: str) -> Any:
        return json.loads((data_dir / filename).read_text(encoding='utf-8'))

    # Dict displays evaluate in order, so entities.json is read first.
    return {
        "entities": _read_json("entities.json"),
        "sensitive_topics": _read_json("sensitive_topics.json"),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_entities(entities: List[Dict], sensitive_topics: List[Dict]) -> List[Dict]:
    """Pick a category-stratified sample of entities.

    Every entity whose ``id`` is referenced by a sensitive topic
    (``source_entity_id``) is always included. After that, each category
    listed in ``SAMPLE_CONFIG`` contributes up to its configured ``count``
    of randomly chosen entities that were not picked already.

    Args:
        entities: All entities; each must have an ``'id'`` key.
        sensitive_topics: Topic records, each optionally naming a
            ``source_entity_id``.

    Returns:
        The sampled entities: sensitive ones first (in input order), then
        the random picks category by category.
    """
    flagged_ids = {topic.get('source_entity_id') for topic in sensitive_topics}

    # Bucket entities by category; a missing category counts as 0.
    buckets = defaultdict(list)
    for entity in entities:
        buckets[entity.get('category', 0)].append(entity)

    chosen = [entity for entity in entities if entity['id'] in flagged_ids]
    chosen_ids = {entity['id'] for entity in chosen}
    print(f" Added {len(chosen)} sensitive entities (mandatory)")

    for cat_id, config in SAMPLE_CONFIG.items():
        pool = buckets.get(cat_id, [])
        candidates = [entity for entity in pool if entity['id'] not in chosen_ids]
        quota = min(config['count'], len(candidates))

        if quota > 0:
            picks = random.sample(candidates, quota)
            chosen.extend(picks)
            chosen_ids.update(entity['id'] for entity in picks)

        print(f" Category {cat_id} ({config['name']}): sampled {quota}/{len(pool)}")

    return chosen
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_entity_name(entity: Dict) -> str:
    """Return the entity's English canonical name.

    Falls back to the entity's ``id`` (or ``''`` when that is missing too)
    if ``canonical_name`` has no ``'en'`` entry.
    """
    names = entity.get('canonical_name', {})
    fallback = entity.get('id', '')
    return names.get('en', fallback)
|
|
|
|
|
|
|
|
def get_arabic_name(entity: Dict) -> Optional[str]:
    """Return the entity's Arabic canonical name, or ``''`` if there is none."""
    names = entity.get('canonical_name', {})
    return names.get('ar', '')
|
|
|
|
|
|
|
|
def get_role(entity: Dict) -> Optional[str]:
    """Extract a role/position label for *entity*.

    The entity's ``subcategory`` is used when it is set and is not the
    placeholder ``'N/A'``. Otherwise the ``facts['must_answer']`` entries
    with ``fact_type == 'position'`` are scanned in order, and the text
    before the first ``' of '`` of the first usable fact is returned
    (e.g. ``'Chairman of ADNOC'`` -> ``'Chairman'``).

    Args:
        entity: Entity record.

    Returns:
        The role string, or ``None`` when no role can be derived.
    """
    subcategory = entity.get('subcategory', '')
    if subcategory and subcategory != 'N/A':
        return subcategory

    # ``or`` guards against JSON nulls for either level.
    facts = entity.get('facts') or {}
    for fact in facts.get('must_answer') or []:
        if not isinstance(fact, dict) or fact.get('fact_type') != 'position':
            continue
        role, separator, _ = (fact.get('fact') or '').partition(' of ')
        # Keep scanning instead of giving up: a later position fact may
        # still contain a "<role> of <context>" phrase.
        if separator and role:
            return role

    return None
|
|
|
|
|
|
|
|
def get_context(entity: Dict) -> Optional[str]:
    """Return the context an entity's role refers to.

    Looks at ``metadata['emirate']`` first and then at
    ``raw_content['primary_organization']``; the first non-empty value wins.

    Args:
        entity: Entity record.

    Returns:
        The emirate or organization, or ``None`` if neither is available.
    """
    # ``or {}`` tolerates JSON nulls as well as missing keys.
    metadata = entity.get('metadata') or {}
    emirate = metadata.get('emirate')
    if emirate:
        return emirate

    raw = entity.get('raw_content') or {}
    org = raw.get('primary_organization')
    if org:
        return org

    # The former ``data_sources`` / 'dhow_KB' branch was a no-op that could
    # only raise (on a null ``data_sources``), so it has been removed.
    return None
|
|
|
|
|
|
|
|
def get_aliases(entity: Dict) -> List[str]:
    """Collect every name the entity is known by.

    The English and Arabic canonical names come first (when non-empty),
    followed by the entries of ``entity['aliases']``. An alias entry may be
    a plain value or a dict carrying the alias under ``'name'``. Empty
    aliases and aliases already collected are skipped.
    """
    canonical = (get_entity_name(entity), get_arabic_name(entity))
    collected = [label for label in canonical if label]

    for entry in entity.get('aliases', []):
        candidate = entry.get('name', '') if isinstance(entry, dict) else entry
        if candidate and candidate not in collected:
            collected.append(candidate)

    return collected
|
|
|
|
|
|
|
|
def generate_role_queries(entity: Dict) -> List[Dict]:
    """Build role-based queries such as 'Who is the ruler of Dubai?'.

    Only entities of type ``'person'`` with a name and a role produce a
    query. When a context is known, one of the first two ``ROLE_TEMPLATES``
    is picked at random; otherwise the role-only template is used.

    Returns:
        A list with at most one query record whose ``expected`` answer is
        the entity's name.
    """
    name = get_entity_name(entity)
    role = get_role(entity)
    context = get_context(entity)

    if not name or entity.get('entity_type', '') != 'person' or not role:
        return []

    if context:
        question = random.choice(ROLE_TEMPLATES[:2]).format(role=role, context=context)
    else:
        question = ROLE_TEMPLATES[2].format(role=role)

    return [{
        "query": question,
        "expected": name,
        "category": entity.get('category', 0),
        "query_type": "role_based"
    }]
|
|
|
|
|
|
|
|
def generate_general_queries(entity: Dict) -> List[Dict]:
    """Build one or two general-information queries, e.g. 'Tell me about G42'.

    The template list is chosen by ``entity_type`` (person, organization,
    location); every other type, including a missing one, uses the concept
    templates. Entities without a name yield no queries.

    Returns:
        Query records whose ``expected`` answer is the entity's name.
    """
    name = get_entity_name(entity)
    if not name:
        return []

    entity_type = entity.get('entity_type', 'concept')
    typed_templates = (
        ('person', GENERAL_TEMPLATES_PERSON),
        ('organization', GENERAL_TEMPLATES_ORG),
        ('location', GENERAL_TEMPLATES_LOCATION),
    )
    templates = next(
        (pool for kind, pool in typed_templates if entity_type == kind),
        GENERAL_TEMPLATES_CONCEPT,
    )

    # Draw the count first, then the templates, to keep the RNG call order.
    how_many = min(random.randint(1, 2), len(templates))
    picked = random.sample(templates, how_many)

    return [
        {
            "query": pattern.format(name=name),
            "expected": name,
            "category": entity.get('category', 0),
            "query_type": "general_info"
        }
        for pattern in picked
    ]
|
|
|
|
|
|
|
|
def generate_sensitive_queries(entity: Dict, sensitive_topics: List[Dict]) -> List[Dict]:
    """Build sensitive queries from the trigger patterns tied to *entity*.

    For each topic whose ``source_entity_id`` equals the entity's ``id``,
    the first two ``trigger_patterns`` are each combined with the entity
    name through a randomly chosen entry of ``SENSITIVE_TEMPLATES``.

    Returns:
        Query records that additionally carry the ``trigger`` used.
    """
    subject = get_entity_name(entity)
    entity_id = entity.get('id', '')

    related = [
        topic for topic in sensitive_topics
        if topic.get('source_entity_id') == entity_id
    ]

    results = []
    for topic in related:
        for trigger in topic.get('trigger_patterns', [])[:2]:
            phrasing = random.choice(SENSITIVE_TEMPLATES)
            results.append({
                "query": phrasing.format(name=subject, trigger=trigger),
                "expected": subject,
                "category": entity.get('category', 0),
                "query_type": "sensitive",
                "trigger": trigger
            })

    return results
|
|
|
|
|
|
|
|
def generate_direct_name_queries(entity: Dict) -> List[Dict]:
    """Build queries that consist of nothing but a name or alias.

    Selection order:
      1. every string alias that is a short all-uppercase token (at most
         6 characters) or that contains a character from the Unicode
         Arabic block;
      2. up to two random picks from the remaining aliases, excluding the
         canonical English name.
    At most four aliases are turned into queries.

    Returns:
        Query records whose ``query`` is the alias itself and whose
        ``expected`` answer is the entity's English name.
    """
    name = get_entity_name(entity)
    aliases = get_aliases(entity)

    def _has_arabic(text: str) -> bool:
        # The Arabic block spans U+0600..U+06FF inclusive; the previous
        # strict comparison dropped both end points.
        return any('\u0600' <= ch <= '\u06ff' for ch in text)

    selected = [
        alias for alias in aliases
        if isinstance(alias, str)
        and ((len(alias) <= 6 and alias.isupper()) or _has_arabic(alias))
    ]

    remaining = [a for a in aliases if a not in selected and a != name]
    if remaining:
        selected.extend(random.sample(remaining, min(2, len(remaining))))

    return [
        {
            "query": alias,
            "expected": name,
            "category": entity.get('category', 0),
            "query_type": "direct_name"
        }
        for alias in selected[:4]
    ]
|
|
|
|
|
|
|
|
def generate_org_person_queries(entities: List[Dict]) -> List[Dict]:
    """Build org→person queries such as 'Who is the CEO of G42?'.

    Only the first 30 entities of type ``'organization'`` are considered,
    and those without a name are skipped. The ``expected`` field holds the
    organization's own name.

    Returns:
        One query record per usable organization.
    """
    organizations = [e for e in entities if e.get('entity_type') == 'organization'][:30]

    results = []
    for org in organizations:
        org_name = get_entity_name(org)
        if org_name:
            pattern = random.choice(ORG_PERSON_TEMPLATES)
            results.append({
                "query": pattern.format(org=org_name),
                "expected": org_name,
                "category": org.get('category', 0),
                "query_type": "org_person"
            })

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_test_data(data_dir: Path) -> List[Dict]:
    """Produce the full, shuffled test-query dataset and print a summary.

    Steps: load the knowledge base from *data_dir*, sample entities, run
    each query generator over the whole sample (one generator at a time),
    add the org→person queries, shuffle, and report counts by query type
    and by category.

    Args:
        data_dir: Directory holding the unified_KB JSON files.

    Returns:
        The shuffled list of query records.
    """
    rule = "=" * 60
    print(rule)
    print("Generating Test Data for unified_KB")
    print(rule)

    print("\n[1/4] Loading unified_KB...")
    kb_data = load_unified_kb(data_dir)
    entities = kb_data['entities']
    sensitive_topics = kb_data['sensitive_topics']
    print(f" Loaded {len(entities)} entities, {len(sensitive_topics)} sensitive topics")

    print("\n[2/4] Sampling entities...")
    sampled = sample_entities(entities, sensitive_topics)
    print(f" Sampled {len(sampled)} entities total")

    print("\n[3/4] Generating queries...")
    all_queries = []

    # Each generator runs over the whole sample before the next one starts,
    # which keeps the sequence of random draws stable for a given seed.
    per_entity_stages = (
        ("Role-based queries", generate_role_queries),
        ("General info queries", generate_general_queries),
        ("Sensitive queries", lambda e: generate_sensitive_queries(e, sensitive_topics)),
        ("Direct name queries", generate_direct_name_queries),
    )
    for label, build in per_entity_stages:
        batch = []
        for entity in sampled:
            batch.extend(build(entity))
        print(f" {label}: {len(batch)}")
        all_queries.extend(batch)

    org_person_queries = generate_org_person_queries(sampled)
    print(f" Org→Person queries: {len(org_person_queries)}")
    all_queries.extend(org_person_queries)

    random.shuffle(all_queries)

    print("\n[4/4] Summary...")
    total = len(all_queries)
    print(f" Total queries: {total}")

    type_counts = defaultdict(int)
    cat_counts = defaultdict(int)
    for record in all_queries:
        type_counts[record.get('query_type', 'unknown')] += 1
        cat_counts[record.get('category', 0)] += 1

    print("\n By query type:")
    for qtype, count in sorted(type_counts.items(), key=lambda pair: -pair[1]):
        pct = count / total * 100
        print(f" {qtype:<20} {count:>5} ({pct:.1f}%)")

    print("\n By category:")
    for cat, count in sorted(cat_counts.items()):
        print(f" Category {cat}: {count}")

    unique_entities = {record['expected'] for record in all_queries}
    print(f"\n Unique entities covered: {len(unique_entities)}")

    return all_queries
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: parse options, generate, write JSON.

    ``--data-dir`` and ``--output`` are resolved against the directory
    containing this script. ``--seed`` seeds the ``random`` module before
    any data is generated.
    """
    parser = argparse.ArgumentParser(description="Generate test data for unified_KB")
    option_specs = (
        ("--data-dir", str, "../uae_knowledge_build/data/unified_KB", "Path to unified_KB directory"),
        ("--output", str, "test_data.json", "Output file path"),
        ("--seed", int, 42, "Random seed for reproducibility"),
    )
    for flag, value_type, default, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    args = parser.parse_args()

    random.seed(args.seed)

    script_dir = Path(__file__).parent
    data_dir = script_dir / args.data_dir
    output_path = script_dir / args.output

    test_data = generate_test_data(data_dir)

    # ensure_ascii=False keeps non-ASCII text (e.g. Arabic) readable on disk.
    with output_path.open('w', encoding='utf-8') as sink:
        json.dump(test_data, sink, ensure_ascii=False, indent=2)

    print(f"\n Saved to: {output_path}")
    print("=" * 60)


if __name__ == "__main__":
    main()