Spaces:
Running
Running
enh: Add templates to handle queries like subregion, region
Browse files- dataset/config.yaml +3 -2
- dataset/scripts/export_training_data.py +30 -13
- dataset/scripts/generate_samples.py +153 -11
- dataset/scripts/sql_templates.py +200 -0
dataset/config.yaml
CHANGED
|
@@ -18,12 +18,13 @@ countries:
|
|
| 18 |
# template_id gets enough coverage after uniform sampling + stratified split.
|
| 19 |
sample_targets:
|
| 20 |
direct_lookup: 500
|
| 21 |
-
|
|
|
|
| 22 |
multi_adjacency: 300
|
| 23 |
containment: 1200 # 4 templates (contain_01..04)
|
| 24 |
intersection: 1200 # 4 templates (intersect_01..04)
|
| 25 |
buffer: 1200 # 5 templates (buffer_01..05)
|
| 26 |
-
chained:
|
| 27 |
difference: 900 # 2 templates, one is mixed (diff_02)
|
| 28 |
border_corridor: 300
|
| 29 |
set_operations: 900
|
|
|
|
| 18 |
# template_id gets enough coverage after uniform sampling + stratified split.
|
| 19 |
sample_targets:
|
| 20 |
direct_lookup: 500
|
| 21 |
+
disambiguation: 1500 # 3 templates (disambiguate_01..03) - "Puri, Odisha" pattern
|
| 22 |
+
adjacency: 1800 # 6 templates (adj_01..06) - adj_06 is counties
|
| 23 |
multi_adjacency: 300
|
| 24 |
containment: 1200 # 4 templates (contain_01..04)
|
| 25 |
intersection: 1200 # 4 templates (intersect_01..04)
|
| 26 |
buffer: 1200 # 5 templates (buffer_01..05)
|
| 27 |
+
chained: 3300 # 11 templates (chained_01..11) - 10/11 are coastal/inland regions
|
| 28 |
difference: 900 # 2 templates, one is mixed (diff_02)
|
| 29 |
border_corridor: 300
|
| 30 |
set_operations: 900
|
dataset/scripts/export_training_data.py
CHANGED
|
@@ -206,17 +206,19 @@ def sample_to_sql_pair(sample: Dict[str, Any]) -> Optional[Dict]:
|
|
| 206 |
# Derived from the same SQL samples: selected_candidates → PlacesResult JSON.
|
| 207 |
# ---------------------------------------------------------------------------
|
| 208 |
|
| 209 |
-
_PLACE_SYSTEM = """You are a geographic entity extractor. Extract
|
| 210 |
|
| 211 |
OUTPUT FORMAT:
|
| 212 |
{"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
|
| 213 |
"country" and "subtype" are optional; omit if not applicable.
|
| 214 |
|
| 215 |
RULES:
|
| 216 |
-
- Extract
|
| 217 |
-
-
|
|
|
|
|
|
|
| 218 |
- No duplicate place names.
|
| 219 |
-
- "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous
|
| 220 |
- "subtype": include only when the geographic level is clear from the query.
|
| 221 |
|
| 222 |
SUBTYPES:
|
|
@@ -224,17 +226,23 @@ country, dependency, region, county, localadmin, locality, macrohood, neighborho
|
|
| 224 |
- Default to locality for cities/towns; omit for physical features (oceans, seas, rivers, lakes, basins, mountains, ranges, peninsulas, islands, terrain areas).
|
| 225 |
|
| 226 |
EXAMPLES:
|
| 227 |
-
Query: "
|
| 228 |
-
-> {"places": [{"place": "
|
| 229 |
|
| 230 |
-
Query: "
|
| 231 |
-
-> {"places": [{"place": "
|
| 232 |
|
| 233 |
-
Query: "
|
| 234 |
-
-> {"places": [{"place": "
|
| 235 |
|
| 236 |
-
Query: "
|
| 237 |
-
-> {"places": [{"place": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
Query: "northern half of India"
|
| 240 |
-> {"places": [{"place": "India", "subtype": "country"}]}
|
|
@@ -245,6 +253,15 @@ Query: "what's within 50 km of Paris?"
|
|
| 245 |
Query: "countries the Nile crosses"
|
| 246 |
-> {"places": [{"place": "Nile"}]}
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
Query: "merge Nairobi and Mombasa"
|
| 249 |
-> {"places": [{"place": "Nairobi", "subtype": "locality"}, {"place": "Mombasa", "subtype": "locality"}]}"""
|
| 250 |
|
|
@@ -281,7 +298,7 @@ def sample_to_place_pair(sample: Dict[str, Any]) -> Optional[Dict]:
|
|
| 281 |
Uses selected_candidates to determine the correct PlacesResult output.
|
| 282 |
Skips samples where no valid places can be derived.
|
| 283 |
"""
|
| 284 |
-
selected_ids =
|
| 285 |
if not selected_ids:
|
| 286 |
return None
|
| 287 |
|
|
|
|
| 206 |
# Derived from the same SQL samples: selected_candidates → PlacesResult JSON.
|
| 207 |
# ---------------------------------------------------------------------------
|
| 208 |
|
| 209 |
+
_PLACE_SYSTEM = """You are a geographic entity extractor. Extract the place names the user is asking about and return valid JSON only.
|
| 210 |
|
| 211 |
OUTPUT FORMAT:
|
| 212 |
{"places": [{"place": "<name>", "country": "<ISO-2>", "subtype": "<subtype>"}]}
|
| 213 |
"country" and "subtype" are optional; omit if not applicable.
|
| 214 |
|
| 215 |
RULES:
|
| 216 |
+
- Extract the place(s) that are the target of the query.
|
| 217 |
+
- When a place is followed by its containing region, state, or country as disambiguation context ("Puri, Odisha", "Lisboa, Portugal", "Goa, India", "Manchester in US"), extract ONLY the specific place. Do not return the container as a separate place — record its info on the target using `country` (ISO-2) when unambiguous.
|
| 218 |
+
- When a query names two or more distinct anchors joined by words like "and", "both", "between", "or" ("France and Germany", "between Nairobi and Mombasa"), or mixes an admin area with a physical feature as independent anchors ("part of Ecuador in the Amazon basin"), extract every anchor in the order they appear.
|
| 219 |
+
- Do not infer or expand category nouns like "regions", "districts", "counties", "rivers", "mountains" when they refer to a type rather than a specific place ("regions of India" -> extract "India" only).
|
| 220 |
- No duplicate place names.
|
| 221 |
+
- "country": ISO 3166-1 alpha-2. Include only if explicitly mentioned or unambiguous.
|
| 222 |
- "subtype": include only when the geographic level is clear from the query.
|
| 223 |
|
| 224 |
SUBTYPES:
|
|
|
|
| 226 |
- Default to locality for cities/towns; omit for physical features (oceans, seas, rivers, lakes, basins, mountains, ranges, peninsulas, islands, terrain areas).
|
| 227 |
|
| 228 |
EXAMPLES:
|
| 229 |
+
Query: "Puri, Odisha"
|
| 230 |
+
-> {"places": [{"place": "Puri", "subtype": "locality", "country": "IN"}]}
|
| 231 |
|
| 232 |
+
Query: "Lisboa, Portugal"
|
| 233 |
+
-> {"places": [{"place": "Lisboa", "subtype": "locality", "country": "PT"}]}
|
| 234 |
|
| 235 |
+
Query: "Goa, India"
|
| 236 |
+
-> {"places": [{"place": "Goa", "subtype": "region", "country": "IN"}]}
|
| 237 |
|
| 238 |
+
Query: "Manchester in US"
|
| 239 |
+
-> {"places": [{"place": "Manchester", "subtype": "locality", "country": "US"}]}
|
| 240 |
+
|
| 241 |
+
Query: "Springfield, Illinois"
|
| 242 |
+
-> {"places": [{"place": "Springfield", "subtype": "locality", "country": "US"}]}
|
| 243 |
+
|
| 244 |
+
Query: "coastal districts of Brazil"
|
| 245 |
+
-> {"places": [{"place": "Brazil", "subtype": "country"}]}
|
| 246 |
|
| 247 |
Query: "northern half of India"
|
| 248 |
-> {"places": [{"place": "India", "subtype": "country"}]}
|
|
|
|
| 253 |
Query: "countries the Nile crosses"
|
| 254 |
-> {"places": [{"place": "Nile"}]}
|
| 255 |
|
| 256 |
+
Query: "part of Ecuador in the Amazon basin"
|
| 257 |
+
-> {"places": [{"place": "Ecuador", "subtype": "country"}, {"place": "Amazon basin"}]}
|
| 258 |
+
|
| 259 |
+
Query: "Amazon basin inside Ecuador"
|
| 260 |
+
-> {"places": [{"place": "Amazon basin"}, {"place": "Ecuador", "subtype": "country"}]}
|
| 261 |
+
|
| 262 |
+
Query: "which regions border both France and Germany?"
|
| 263 |
+
-> {"places": [{"place": "France", "subtype": "country"}, {"place": "Germany", "subtype": "country"}]}
|
| 264 |
+
|
| 265 |
Query: "merge Nairobi and Mombasa"
|
| 266 |
-> {"places": [{"place": "Nairobi", "subtype": "locality"}, {"place": "Mombasa", "subtype": "locality"}]}"""
|
| 267 |
|
|
|
|
| 298 |
Uses selected_candidates to determine the correct PlacesResult output.
|
| 299 |
Skips samples where no valid places can be derived.
|
| 300 |
"""
|
| 301 |
+
selected_ids = sample.get("target", {}).get("selected_candidates", [])
|
| 302 |
if not selected_ids:
|
| 303 |
return None
|
| 304 |
|
dataset/scripts/generate_samples.py
CHANGED
|
@@ -95,12 +95,27 @@ def load_relation_tables(intermediate_dir: Path, quiet: bool = False) -> Dict[st
|
|
| 95 |
return tables
|
| 96 |
|
| 97 |
|
| 98 |
-
def sample_adjacency_anchor(
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
if adjacency_df.empty:
|
| 101 |
return None
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
return {
|
| 105 |
'anchor_id': row['anchor_id'],
|
| 106 |
'anchor_name': row['anchor_name'],
|
|
@@ -140,6 +155,37 @@ def sample_containment_anchor(containment_df: pd.DataFrame) -> Optional[Dict[str
|
|
| 140 |
}
|
| 141 |
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
def sample_cross_source_anchor(cross_source_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 144 |
"""Sample a random cross-source relation."""
|
| 145 |
if cross_source_df.empty:
|
|
@@ -517,12 +563,89 @@ def generate_template_based_sample(
|
|
| 517 |
|
| 518 |
# Question
|
| 519 |
question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
|
| 520 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
elif template.family == "adjacency":
|
| 522 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
if not anchor:
|
| 524 |
return None
|
| 525 |
-
|
| 526 |
sql = template.sql_template.format(
|
| 527 |
anchor_id=anchor['anchor_id'],
|
| 528 |
target_subtype=anchor['target_subtype']
|
|
@@ -918,11 +1041,30 @@ def generate_template_based_sample(
|
|
| 918 |
table_key = 'landlocked_containment_pairs'
|
| 919 |
else:
|
| 920 |
table_key = 'containment_pairs'
|
| 921 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 922 |
if not anchor:
|
| 923 |
return None
|
| 924 |
|
| 925 |
-
target_subtype
|
|
|
|
|
|
|
|
|
|
| 926 |
|
| 927 |
sql = template.sql_template.format(
|
| 928 |
anchor_id=anchor['container_id'],
|
|
@@ -1182,14 +1324,14 @@ def generate_template_based_sample(
|
|
| 1182 |
or anchor.get('id')
|
| 1183 |
)
|
| 1184 |
selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor_id_to_find]
|
| 1185 |
-
|
| 1186 |
return TrainingSample(
|
| 1187 |
id=sample_id,
|
| 1188 |
question=question,
|
| 1189 |
candidates=candidates,
|
| 1190 |
target={
|
| 1191 |
"selected_candidates": selected_candidate_ids,
|
| 1192 |
-
"sql": sql
|
| 1193 |
},
|
| 1194 |
metadata={
|
| 1195 |
"task_family": template.family,
|
|
|
|
| 95 |
return tables
|
| 96 |
|
| 97 |
|
| 98 |
+
def sample_adjacency_anchor(
|
| 99 |
+
adjacency_df: pd.DataFrame,
|
| 100 |
+
target_subtype: Optional[str] = None,
|
| 101 |
+
) -> Optional[Dict[str, Any]]:
|
| 102 |
+
"""Sample a random adjacency pair, optionally filtered by target_subtype.
|
| 103 |
+
|
| 104 |
+
When ``target_subtype`` is provided, only rows whose neighbouring feature
|
| 105 |
+
matches that subtype are considered. This lets subtype-specific templates
|
| 106 |
+
(e.g. "neighbouring counties of X") guarantee coverage instead of relying
|
| 107 |
+
on whatever subtype the random draw happens to produce.
|
| 108 |
+
"""
|
| 109 |
if adjacency_df.empty:
|
| 110 |
return None
|
| 111 |
+
|
| 112 |
+
df = adjacency_df
|
| 113 |
+
if target_subtype is not None:
|
| 114 |
+
df = df[df['target_subtype'] == target_subtype]
|
| 115 |
+
if df.empty:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
row = df.sample(n=1).iloc[0]
|
| 119 |
return {
|
| 120 |
'anchor_id': row['anchor_id'],
|
| 121 |
'anchor_name': row['anchor_name'],
|
|
|
|
| 155 |
}
|
| 156 |
|
| 157 |
|
| 158 |
+
def sample_disambiguation_anchor(
    containment_df: pd.DataFrame,
    contained_subtypes: List[str],
    container_subtypes: List[str],
) -> Optional[Dict[str, Any]]:
    """Draw one random (contained, container) row from a containment table.

    A row is eligible only when its ``contained_subtype`` is listed in
    ``contained_subtypes`` and its ``container_subtype`` is listed in
    ``container_subtypes``.  Intended for "Puri, Odisha"-style
    disambiguation templates, where the contained entity is the query
    target and the container only supplies context.

    Args:
        containment_df: Table with the ``contained_*`` / ``container_*``
            id, name and subtype columns.
        contained_subtypes: Accepted subtypes for the contained entity.
        container_subtypes: Accepted subtypes for the containing entity.

    Returns:
        A dict with the six id/name/subtype fields of the sampled row, or
        ``None`` when the table is empty or no row passes both filters.
    """
    if containment_df.empty:
        return None

    # Build the two subtype masks separately, then keep rows passing both.
    contained_ok = containment_df['contained_subtype'].isin(contained_subtypes)
    container_ok = containment_df['container_subtype'].isin(container_subtypes)
    eligible = containment_df[contained_ok & container_ok]
    if eligible.empty:
        return None

    picked = eligible.sample(n=1).iloc[0]
    wanted_fields = (
        'contained_id',
        'contained_name',
        'contained_subtype',
        'container_id',
        'container_name',
        'container_subtype',
    )
    return {field: picked[field] for field in wanted_fields}
|
| 187 |
+
|
| 188 |
+
|
| 189 |
def sample_cross_source_anchor(cross_source_df: pd.DataFrame) -> Optional[Dict[str, Any]]:
|
| 190 |
"""Sample a random cross-source relation."""
|
| 191 |
if cross_source_df.empty:
|
|
|
|
| 563 |
|
| 564 |
# Question
|
| 565 |
question = random.choice(template.question_hints).format(anchor_name=anchor['name'])
|
| 566 |
+
|
| 567 |
+
elif template.family == "disambiguation":
|
| 568 |
+
# "Puri, Odisha" style: pick a (contained, container) pair whose
|
| 569 |
+
# subtypes match the template, build candidates that include the
|
| 570 |
+
# container + same-name distractors so the model must read the CSV
|
| 571 |
+
# to pick the right entry.
|
| 572 |
+
_disambig_subtypes = {
|
| 573 |
+
"disambiguate_01": (["locality"], ["region", "county", "localadmin"]),
|
| 574 |
+
"disambiguate_02": (["locality"], ["country"]),
|
| 575 |
+
"disambiguate_03": (["region", "dependency"], ["country"]),
|
| 576 |
+
}
|
| 577 |
+
contained_sts, container_sts = _disambig_subtypes.get(
|
| 578 |
+
template.template_id, (["locality"], ["country"])
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
pair = sample_disambiguation_anchor(
|
| 582 |
+
tables["containment_pairs"], contained_sts, container_sts
|
| 583 |
+
)
|
| 584 |
+
if not pair:
|
| 585 |
+
return None
|
| 586 |
+
|
| 587 |
+
candidates = build_candidate_list(
|
| 588 |
+
con, pair["contained_id"], pair["contained_name"], "divisions_area",
|
| 589 |
+
num_candidates=10, difficulty="hard"
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
# Ensure the container is among the candidates so the model can
|
| 593 |
+
# ground the disambiguation context (e.g. "Odisha").
|
| 594 |
+
if not any(c.id == pair["container_id"] for c in candidates):
|
| 595 |
+
container_rows = con.execute(
|
| 596 |
+
'SELECT id, names."primary" AS name, subtype, country, region, admin_level '
|
| 597 |
+
'FROM read_parquet(?) WHERE id = ? LIMIT 1',
|
| 598 |
+
[DIVISIONS_AREA_PATH, pair["container_id"]]
|
| 599 |
+
).fetchdf()
|
| 600 |
+
if container_rows.empty:
|
| 601 |
+
return None
|
| 602 |
+
crow = container_rows.iloc[0]
|
| 603 |
+
|
| 604 |
+
def _nn(v):
|
| 605 |
+
return None if pd.isna(v) else v
|
| 606 |
+
|
| 607 |
+
container_cand = Candidate(
|
| 608 |
+
candidate_id="temp",
|
| 609 |
+
source="divisions_area",
|
| 610 |
+
id=pair["container_id"],
|
| 611 |
+
name=_nn(crow["name"]),
|
| 612 |
+
subtype=_nn(crow["subtype"]),
|
| 613 |
+
country=_nn(crow["country"]),
|
| 614 |
+
region=_nn(crow["region"]),
|
| 615 |
+
admin_level=_nn(crow["admin_level"]),
|
| 616 |
+
similarity=0.95,
|
| 617 |
+
)
|
| 618 |
+
# Insert the container right after the true target and drop the
|
| 619 |
+
# last filler distractor so the total stays at 10.
|
| 620 |
+
candidates = [candidates[0], container_cand] + candidates[1:-1]
|
| 621 |
+
for i, c in enumerate(candidates, 1):
|
| 622 |
+
c.candidate_id = f"c{i}"
|
| 623 |
+
|
| 624 |
+
sql = template.sql_template.format(anchor_id=pair["contained_id"])
|
| 625 |
+
|
| 626 |
+
question = random.choice(template.question_hints).format(
|
| 627 |
+
anchor_name=pair["contained_name"],
|
| 628 |
+
container_name=pair["container_name"],
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
# Only the contained entity is the query target — the container is
|
| 632 |
+
# disambiguation context and stays in candidates but NOT in
|
| 633 |
+
# selected_candidates. The model learns to use the container row of
|
| 634 |
+
# the CSV (via country/region columns) to pick the right same-name
|
| 635 |
+
# locality.
|
| 636 |
+
anchor = {"id": pair["contained_id"], "name": pair["contained_name"]}
|
| 637 |
+
|
| 638 |
elif template.family == "adjacency":
|
| 639 |
+
# If the template pins a target_subtype (e.g. adj_02='region',
|
| 640 |
+
# adj_06='county'), honour it so the sampled pair is guaranteed to
|
| 641 |
+
# match the question phrasing ("neighbouring counties of X").
|
| 642 |
+
anchor = sample_adjacency_anchor(
|
| 643 |
+
tables['adjacency_pairs'],
|
| 644 |
+
target_subtype=template.target_subtype,
|
| 645 |
+
)
|
| 646 |
if not anchor:
|
| 647 |
return None
|
| 648 |
+
|
| 649 |
sql = template.sql_template.format(
|
| 650 |
anchor_id=anchor['anchor_id'],
|
| 651 |
target_subtype=anchor['target_subtype']
|
|
|
|
| 1041 |
table_key = 'landlocked_containment_pairs'
|
| 1042 |
else:
|
| 1043 |
table_key = 'containment_pairs'
|
| 1044 |
+
|
| 1045 |
+
# chained_10/11 need a country-level anchor ("coastal states of
|
| 1046 |
+
# India") and region-level targets, so filter the containment pairs
|
| 1047 |
+
# to (container=country, contained=region) before sampling.
|
| 1048 |
+
_chained_subtype_filter = {
|
| 1049 |
+
"chained_10": ("country", "region"),
|
| 1050 |
+
"chained_11": ("country", "region"),
|
| 1051 |
+
}
|
| 1052 |
+
df = tables.get(table_key, tables['containment_pairs'])
|
| 1053 |
+
filt = _chained_subtype_filter.get(template.template_id)
|
| 1054 |
+
if filt:
|
| 1055 |
+
df = df[
|
| 1056 |
+
(df['container_subtype'] == filt[0])
|
| 1057 |
+
& (df['contained_subtype'] == filt[1])
|
| 1058 |
+
]
|
| 1059 |
+
|
| 1060 |
+
anchor = sample_containment_anchor(df)
|
| 1061 |
if not anchor:
|
| 1062 |
return None
|
| 1063 |
|
| 1064 |
+
# Prefer the template-pinned target_subtype when set (e.g. chained_10
|
| 1065 |
+
# always wants 'region') so the SQL filter and question phrasing stay
|
| 1066 |
+
# in sync regardless of what the sampled pair happens to contain.
|
| 1067 |
+
target_subtype = template.target_subtype or anchor.get('contained_subtype', 'locality')
|
| 1068 |
|
| 1069 |
sql = template.sql_template.format(
|
| 1070 |
anchor_id=anchor['container_id'],
|
|
|
|
| 1324 |
or anchor.get('id')
|
| 1325 |
)
|
| 1326 |
selected_candidate_ids = [c.candidate_id for c in candidates if c.id == anchor_id_to_find]
|
| 1327 |
+
|
| 1328 |
return TrainingSample(
|
| 1329 |
id=sample_id,
|
| 1330 |
question=question,
|
| 1331 |
candidates=candidates,
|
| 1332 |
target={
|
| 1333 |
"selected_candidates": selected_candidate_ids,
|
| 1334 |
+
"sql": sql,
|
| 1335 |
},
|
| 1336 |
metadata={
|
| 1337 |
"task_family": template.family,
|
dataset/scripts/sql_templates.py
CHANGED
|
@@ -25,6 +25,9 @@ correct parquet path from the candidates table.
|
|
| 25 |
Template families
|
| 26 |
-----------------
|
| 27 |
direct_lookup Simple single-feature fetch by ID.
|
|
|
|
|
|
|
|
|
|
| 28 |
adjacency ST_Touches — features sharing a border.
|
| 29 |
multi_adjacency Features that simultaneously touch TWO anchors.
|
| 30 |
containment ST_Within / ST_Contains — hierarchical nesting.
|
|
@@ -125,6 +128,96 @@ TEMPLATES = [
|
|
| 125 |
],
|
| 126 |
),
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
# ── ADJACENCY ────────────────────────────────────────────────────────────
|
| 129 |
|
| 130 |
SQLTemplate(
|
|
@@ -215,6 +308,38 @@ TEMPLATES = [
|
|
| 215 |
],
|
| 216 |
),
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
# ── MULTI-ADJACENCY ──────────────────────────────────────────────────────
|
| 219 |
|
| 220 |
SQLTemplate(
|
|
@@ -1555,6 +1680,81 @@ TEMPLATES = [
|
|
| 1555 |
],
|
| 1556 |
),
|
| 1557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1558 |
# ── NATURAL EARTH CONTAINMENT ───────────────────────────────────────────
|
| 1559 |
# contain_04: NE anchor (sea/gulf/bay), find countries that touch it.
|
| 1560 |
# Uses containment handler via containment_pairs.
|
|
|
|
| 25 |
Template families
|
| 26 |
-----------------
|
| 27 |
direct_lookup Simple single-feature fetch by ID.
|
| 28 |
+
disambiguation "Place, Container" queries like "Puri, Odisha" — lookup by
|
| 29 |
+
ID after resolving an ambiguous name via containing region
|
| 30 |
+
or country mentioned in the query.
|
| 31 |
adjacency ST_Touches — features sharing a border.
|
| 32 |
multi_adjacency Features that simultaneously touch TWO anchors.
|
| 33 |
containment ST_Within / ST_Contains — hierarchical nesting.
|
|
|
|
| 128 |
],
|
| 129 |
),
|
| 130 |
|
| 131 |
+
# ── DISAMBIGUATION ──────────────────────────────────────────────────────
|
| 132 |
+
# "Puri, Odisha", "Lisbon, Portugal", "Goa, India" — a common real-world
|
| 133 |
+
# query pattern where users give a place plus its containing region or
|
| 134 |
+
# country to disambiguate same-name localities.
|
| 135 |
+
# SQL is a plain lookup by id (disambiguation happens at candidate-pick
|
| 136 |
+
# time). Candidates include same-name localities in other regions plus
|
| 137 |
+
# the container, so the model must read the CSV to choose correctly.
|
| 138 |
+
#
|
| 139 |
+
# disambiguate_01: locality scoped by its region / county
|
| 140 |
+
# disambiguate_02: locality scoped by its country
|
| 141 |
+
# disambiguate_03: region / dependency scoped by its country
|
| 142 |
+
|
| 143 |
+
SQLTemplate(
|
| 144 |
+
template_id="disambiguate_01",
|
| 145 |
+
family="disambiguation",
|
| 146 |
+
sql_difficulty="easy",
|
| 147 |
+
anchor_source="divisions_area",
|
| 148 |
+
num_anchors=1,
|
| 149 |
+
sql_template=(
|
| 150 |
+
"SELECT ST_AsGeoJSON(geometry) AS geometry,"
|
| 151 |
+
" names.\"primary\" AS name, id, subtype, country, region"
|
| 152 |
+
" FROM read_parquet('divisions_area')"
|
| 153 |
+
" WHERE id = '{anchor_id}'"
|
| 154 |
+
),
|
| 155 |
+
question_hints=[
|
| 156 |
+
"{anchor_name}, {container_name}",
|
| 157 |
+
"{anchor_name} in {container_name}",
|
| 158 |
+
"the {anchor_name} that's in {container_name}",
|
| 159 |
+
"show me {anchor_name}, {container_name}",
|
| 160 |
+
"where is {anchor_name}, {container_name}?",
|
| 161 |
+
"map of {anchor_name} ({container_name})",
|
| 162 |
+
"{anchor_name} ({container_name})",
|
| 163 |
+
"{anchor_name} {container_name}",
|
| 164 |
+
"pull up {anchor_name} in {container_name}",
|
| 165 |
+
"find {anchor_name} in {container_name}",
|
| 166 |
+
],
|
| 167 |
+
),
|
| 168 |
+
|
| 169 |
+
SQLTemplate(
|
| 170 |
+
template_id="disambiguate_02",
|
| 171 |
+
family="disambiguation",
|
| 172 |
+
sql_difficulty="easy",
|
| 173 |
+
anchor_source="divisions_area",
|
| 174 |
+
num_anchors=1,
|
| 175 |
+
sql_template=(
|
| 176 |
+
"SELECT ST_AsGeoJSON(geometry) AS geometry,"
|
| 177 |
+
" names.\"primary\" AS name, id, subtype, country"
|
| 178 |
+
" FROM read_parquet('divisions_area')"
|
| 179 |
+
" WHERE id = '{anchor_id}'"
|
| 180 |
+
),
|
| 181 |
+
question_hints=[
|
| 182 |
+
"{anchor_name}, {container_name}",
|
| 183 |
+
"{anchor_name} in {container_name}",
|
| 184 |
+
"{anchor_name}, {container_name}.",
|
| 185 |
+
"show me {anchor_name}, {container_name}",
|
| 186 |
+
"where is {anchor_name} in {container_name}?",
|
| 187 |
+
"the {anchor_name} that's in {container_name}",
|
| 188 |
+
"map of {anchor_name}, {container_name}",
|
| 189 |
+
"pull up {anchor_name} ({container_name})",
|
| 190 |
+
"find {anchor_name} in {container_name}",
|
| 191 |
+
"{anchor_name} {container_name}",
|
| 192 |
+
],
|
| 193 |
+
),
|
| 194 |
+
|
| 195 |
+
SQLTemplate(
|
| 196 |
+
template_id="disambiguate_03",
|
| 197 |
+
family="disambiguation",
|
| 198 |
+
sql_difficulty="easy",
|
| 199 |
+
anchor_source="divisions_area",
|
| 200 |
+
num_anchors=1,
|
| 201 |
+
sql_template=(
|
| 202 |
+
"SELECT ST_AsGeoJSON(geometry) AS geometry,"
|
| 203 |
+
" names.\"primary\" AS name, id, subtype, country"
|
| 204 |
+
" FROM read_parquet('divisions_area')"
|
| 205 |
+
" WHERE id = '{anchor_id}'"
|
| 206 |
+
),
|
| 207 |
+
question_hints=[
|
| 208 |
+
"{anchor_name}, {container_name}",
|
| 209 |
+
"{anchor_name} state of {container_name}",
|
| 210 |
+
"the {anchor_name} region in {container_name}",
|
| 211 |
+
"show me {anchor_name}, {container_name}",
|
| 212 |
+
"where is {anchor_name} in {container_name}?",
|
| 213 |
+
"map of {anchor_name}, {container_name}",
|
| 214 |
+
"{anchor_name} ({container_name})",
|
| 215 |
+
"{anchor_name} province of {container_name}",
|
| 216 |
+
"pull up {anchor_name} in {container_name}",
|
| 217 |
+
"find {anchor_name} {container_name}",
|
| 218 |
+
],
|
| 219 |
+
),
|
| 220 |
+
|
| 221 |
# ── ADJACENCY ────────────────────────────────────────────────────────────
|
| 222 |
|
| 223 |
SQLTemplate(
|
|
|
|
| 308 |
],
|
| 309 |
),
|
| 310 |
|
| 311 |
+
SQLTemplate(
|
| 312 |
+
template_id="adj_06",
|
| 313 |
+
family="adjacency",
|
| 314 |
+
sql_difficulty="medium",
|
| 315 |
+
anchor_source="divisions_area",
|
| 316 |
+
num_anchors=1,
|
| 317 |
+
target_subtype="county",
|
| 318 |
+
sql_template=(
|
| 319 |
+
"WITH a AS ("
|
| 320 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 321 |
+
")"
|
| 322 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 323 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 324 |
+
" FROM read_parquet('divisions_area') AS b, a"
|
| 325 |
+
" WHERE b.id != '{anchor_id}'"
|
| 326 |
+
" AND b.subtype = '{target_subtype}'"
|
| 327 |
+
" AND ST_Touches(a.geometry, b.geometry)"
|
| 328 |
+
),
|
| 329 |
+
question_hints=[
|
| 330 |
+
"neighbouring counties of {anchor_name}",
|
| 331 |
+
"neighbouring districts of {anchor_name}",
|
| 332 |
+
"which counties border {anchor_name}?",
|
| 333 |
+
"which districts border {anchor_name}?",
|
| 334 |
+
"counties adjacent to {anchor_name}",
|
| 335 |
+
"districts next to {anchor_name}",
|
| 336 |
+
"counties sharing a border with {anchor_name}",
|
| 337 |
+
"what counties touch {anchor_name}?",
|
| 338 |
+
"nearby counties of {anchor_name}",
|
| 339 |
+
"counties along the {anchor_name} boundary",
|
| 340 |
+
],
|
| 341 |
+
),
|
| 342 |
+
|
| 343 |
# ── MULTI-ADJACENCY ──────────────────────────────────────────────────────
|
| 344 |
|
| 345 |
SQLTemplate(
|
|
|
|
| 1680 |
],
|
| 1681 |
),
|
| 1682 |
|
| 1683 |
+
# chained_10 / chained_11: coastal and inland REGIONS of a country.
|
| 1684 |
+
# Same pattern as chained_06/07 but with target_subtype='region' and
|
| 1685 |
+
# container forced to a country so phrasings like "coastal states of
|
| 1686 |
+
# India" / "inland provinces of Kenya" work correctly.
|
| 1687 |
+
|
| 1688 |
+
SQLTemplate(
|
| 1689 |
+
template_id="chained_10",
|
| 1690 |
+
family="chained",
|
| 1691 |
+
sql_difficulty="hard",
|
| 1692 |
+
anchor_source="divisions_area",
|
| 1693 |
+
num_anchors=1,
|
| 1694 |
+
target_subtype="region",
|
| 1695 |
+
sql_template=(
|
| 1696 |
+
"WITH country AS ("
|
| 1697 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1698 |
+
")"
|
| 1699 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1700 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1701 |
+
" FROM read_parquet('divisions_area') AS b, country"
|
| 1702 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1703 |
+
" AND ST_Within(b.geometry, country.geometry)"
|
| 1704 |
+
" AND EXISTS ("
|
| 1705 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1706 |
+
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 1707 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1708 |
+
" )"
|
| 1709 |
+
),
|
| 1710 |
+
question_hints=[
|
| 1711 |
+
"coastal states of {anchor_name}",
|
| 1712 |
+
"coastal regions of {anchor_name}",
|
| 1713 |
+
"coastal provinces of {anchor_name}",
|
| 1714 |
+
"which states of {anchor_name} are on the coast?",
|
| 1715 |
+
"regions of {anchor_name} with sea access",
|
| 1716 |
+
"states of {anchor_name} that border the ocean",
|
| 1717 |
+
"maritime states of {anchor_name}",
|
| 1718 |
+
"seaside regions of {anchor_name}",
|
| 1719 |
+
"which provinces of {anchor_name} touch the sea?",
|
| 1720 |
+
"states of {anchor_name} along the coast",
|
| 1721 |
+
],
|
| 1722 |
+
),
|
| 1723 |
+
|
| 1724 |
+
SQLTemplate(
|
| 1725 |
+
template_id="chained_11",
|
| 1726 |
+
family="chained",
|
| 1727 |
+
sql_difficulty="hard",
|
| 1728 |
+
anchor_source="divisions_area",
|
| 1729 |
+
num_anchors=1,
|
| 1730 |
+
target_subtype="region",
|
| 1731 |
+
sql_template=(
|
| 1732 |
+
"WITH country AS ("
|
| 1733 |
+
" SELECT geometry FROM read_parquet('divisions_area') WHERE id = '{anchor_id}'"
|
| 1734 |
+
")"
|
| 1735 |
+
" SELECT b.id, b.names.\"primary\" AS name, b.subtype, b.country,"
|
| 1736 |
+
" ST_AsGeoJSON(b.geometry) AS geometry"
|
| 1737 |
+
" FROM read_parquet('divisions_area') AS b, country"
|
| 1738 |
+
" WHERE b.subtype = '{target_subtype}'"
|
| 1739 |
+
" AND ST_Within(b.geometry, country.geometry)"
|
| 1740 |
+
" AND NOT EXISTS ("
|
| 1741 |
+
" SELECT 1 FROM read_parquet('natural_earth') AS n"
|
| 1742 |
+
" WHERE n.subtype IN ('ocean', 'sea')"
|
| 1743 |
+
" AND ST_Intersects(b.geometry, n.geometry)"
|
| 1744 |
+
" )"
|
| 1745 |
+
),
|
| 1746 |
+
question_hints=[
|
| 1747 |
+
"landlocked states of {anchor_name}",
|
| 1748 |
+
"inland regions of {anchor_name}",
|
| 1749 |
+
"non-coastal states of {anchor_name}",
|
| 1750 |
+
"which states of {anchor_name} have no coast?",
|
| 1751 |
+
"inland provinces of {anchor_name}",
|
| 1752 |
+
"regions of {anchor_name} without sea access",
|
| 1753 |
+
"interior states of {anchor_name}",
|
| 1754 |
+
"states of {anchor_name} that don't border the ocean",
|
| 1755 |
+
],
|
| 1756 |
+
),
|
| 1757 |
+
|
| 1758 |
# ── NATURAL EARTH CONTAINMENT ───────────────────────────────────────────
|
| 1759 |
# contain_04: NE anchor (sea/gulf/bay), find countries that touch it.
|
| 1760 |
# Uses containment handler via containment_pairs.
|