|
|
from __future__ import annotations |
|
|
|
|
|
import re |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
# Lower-case English month names, plus the "sept" abbreviation, mapped to
# their 1-based calendar month number. DATE_PATTERN builds its month
# alternation from these keys, and _parse_date looks matches up here.
MONTHS = dict(
    january=1,
    february=2,
    march=3,
    april=4,
    may=5,
    june=6,
    july=7,
    august=8,
    september=9,
    sept=9,
    october=10,
    november=11,
    december=12,
)
|
|
|
|
|
|
|
|
# Sign applied to a coordinate for each lower-case compass hemisphere letter:
# north and east are positive, south and west are negative.
HEMISPHERE_MULTIPLIERS = dict(n=1, s=-1, e=1, w=-1)
|
|
|
|
|
|
|
|
# Compass-notation coordinate, e.g. "40.7 N", "74°W", "33 degrees South".
# Group 1 is the signed number, group 2 the hemisphere letter. The letter may be
# spelled out (North/South/East/West) but must not run into another letter, so
# ordinary words such as "we" in "in 2020 we" or "Sept" in "10 Sept" are not
# mistaken for a hemisphere. Digits may follow directly, as in "40°N74°W".
COORD_PATTERN = re.compile(
    r"([-+]?\d+(?:\.\d+)?)\s*(?:°|degrees|deg)?\s*([NnSsEeWw])"
    r"(?i:orth|outh|ast|est)?(?![A-Za-z])"
)

# Keyword-notation coordinate, e.g. "lat: 40.7" or "longitude=-74".
# Group 1 holds a latitude value and group 2 a longitude value; only one of
# them participates in any given match. The leading \b on each keyword stops
# words such as "salon 3" or "colon 5" from being read as a longitude.
LAT_LON_WORD_PATTERN = re.compile(
    r"\b(?:latitude|lat)\s*[:=]?\s*([-+]?\d+(?:\.\d+)?)"
    r"|\b(?:longitude|lon)\s*[:=]?\s*([-+]?\d+(?:\.\d+)?)",
    re.IGNORECASE,
)
|
|
|
|
|
# "Month day[st|nd|rd|th][,] year [era]" dates, e.g. "March 15, 44 BC".
# Groups: month name, day, year (optionally negative), era (or None).
# (?!\d) stops the year from being cut out of a longer number ("20201"), and
# the \b after the era keeps words such as "ADVANCED" from reading as "AD".
DATE_PATTERN = re.compile(
    r"\b("
    + "|".join(map(re.escape, MONTHS))
    + r")\s+(\d{1,2})(?:st|nd|rd|th)?(?:,\s*|\s+)(-?\d{1,4})(?!\d)"
    + r"(?:\s*(BCE|BC|CE|AD)\b)?",
    re.IGNORECASE,
)
|
|
|
|
|
# Bare year with optional era, e.g. "1650", "44 BC", "-500". Groups: year
# (optionally negative), era (or None). The (?<!\w) guard is used instead of
# \b because \b cannot match between a space and "-", which made "-500" come
# out as 500.
YEAR_ONLY_PATTERN = re.compile(r"(?<!\w)(-?\d{1,4})\s*(BCE|BC|CE|AD)?\b", re.IGNORECASE)

# Time followed by an "h" / "hour(s)" unit, e.g. "21h" or "6:30 hours".
# Groups: hour, minute (or None).
HOUR_PATTERN = re.compile(r"\b(\d{1,2})(?::(\d{2}))?\s*(?:hours?|h)\b", re.IGNORECASE)

# 12-hour clock time introduced by "at"/"around", e.g. "at 9 pm" or
# "around 6:15am". Groups: hour, minute (or None); the match always ends in
# "am" or "pm".
SEASONAL_HOUR_PATTERN = re.compile(
    r"\b(?:at|around)\s*(\d{1,2})(?::(\d{2}))?\s*(?:am|pm)\b", re.IGNORECASE
)
|
|
|
|
|
|
|
|
@dataclass()
class ParsedPrompt:
    """Structured context pulled out of a free-text prompt.

    Every extracted field defaults to ``None``, meaning "not found in the
    prompt"; ``confidence`` and ``residual_text`` default to an empty result.
    """

    # Coordinates as parsed; compass-notation values carry the hemisphere sign
    # (south and west negative).
    lat: Optional[float] = None
    lon: Optional[float] = None

    # Calendar date parts; BC/BCE years are stored as negative numbers.
    year: Optional[int] = None
    month: Optional[int] = None
    day: Optional[int] = None

    # Time of day; am/pm times are converted to a 0-23 hour by the parser.
    hour: Optional[int] = None
    minute: Optional[int] = None

    # Heuristic score summed from the individual extractors.
    confidence: float = 0.0

    # Prompt text carried alongside the parsed fields.
    residual_text: str = ""
|
|
|
|
|
|
|
|
def _apply_hemisphere(value: float, hemisphere: str) -> float:
    """Return *value* signed according to its compass *hemisphere* letter.

    The hemisphere letter is authoritative. The magnitude of *value* is kept,
    and the sign comes from ``HEMISPHERE_MULTIPLIERS`` (N/E positive, S/W
    negative). For example, ``"-33 S"`` yields ``-33.0``; multiplying the
    signed value would give ``+33.0``, a northern latitude.

    Args:
        value: Numeric coordinate as written, possibly already signed.
        hemisphere: Hemisphere letter, matched case-insensitively.

    Returns:
        The signed coordinate. A letter missing from the table leaves
        *value* unchanged.
    """
    multiplier = HEMISPHERE_MULTIPLIERS.get(hemisphere.lower())
    if multiplier is None:
        return value
    return abs(value) * multiplier
|
|
|
|
|
|
|
|
def _parse_coordinates(text: str) -> tuple[Optional[float], Optional[float], float]:
    """Extract a latitude/longitude pair from free text.

    Two notations are tried, in priority order:

    1. Compass notation matched by ``COORD_PATTERN`` (e.g. ``"40.7 N, 74.0 W"``).
       The first in-range N/S value and the first in-range E/W value are used,
       but only when both are present. The pair adds 0.5 confidence.
    2. Keyword notation matched by ``LAT_LON_WORD_PATTERN`` (e.g. ``"lat 40.7"``).
       Used only when step 1 did not yield a pair; the first in-range value for
       each component adds 0.2 confidence.

    Values outside the valid ranges (|lat| <= 90, |lon| <= 180) are skipped.
    This discards numbers that merely sit next to a stray letter, such as
    ``"2020 w"``, and lets a later valid value be used instead.

    Args:
        text: The prompt to scan.

    Returns:
        ``(lat, lon, confidence)``. Either coordinate may be ``None``, and the
        confidence is capped at 0.6.
    """
    max_abs_lat = 90.0
    max_abs_lon = 180.0

    lat: Optional[float] = None
    lon: Optional[float] = None
    confidence = 0.0

    lat_candidate: Optional[float] = None
    lon_candidate: Optional[float] = None
    for value_str, hemisphere in COORD_PATTERN.findall(text):
        hemi = hemisphere.lower()
        adjusted = _apply_hemisphere(float(value_str), hemi)
        if hemi in ("n", "s"):
            if lat_candidate is None and abs(adjusted) <= max_abs_lat:
                lat_candidate = adjusted
        elif hemi in ("e", "w"):
            if lon_candidate is None and abs(adjusted) <= max_abs_lon:
                lon_candidate = adjusted

    # A lone compass value is treated as too weak on its own; require the pair.
    if lat_candidate is not None and lon_candidate is not None:
        lat, lon = lat_candidate, lon_candidate
        confidence += 0.5

    if lat is None or lon is None:
        for lat_text, lon_text in LAT_LON_WORD_PATTERN.findall(text):
            # findall yields "" for the alternative that did not participate.
            if lat is None and lat_text and abs(float(lat_text)) <= max_abs_lat:
                lat = float(lat_text)
                confidence += 0.2
            if lon is None and lon_text and abs(float(lon_text)) <= max_abs_lon:
                lon = float(lon_text)
                confidence += 0.2

    return lat, lon, min(confidence, 0.6)
|
|
|
|
|
|
|
|
def _convert_year(raw_year: str, era: Optional[str]) -> int: |
|
|
year = int(raw_year) |
|
|
if era: |
|
|
era = era.upper() |
|
|
if era in ("BCE", "BC"): |
|
|
return -abs(year) |
|
|
return year |
|
|
|
|
|
|
|
|
def _non_year_spans(text: str) -> list[tuple[int, int]]:
    """Return spans of *text* that belong to coordinates or clock times.

    Numbers inside these spans must not be reinterpreted as bare years.
    Compass-notation matches are claimed only when their magnitude is a
    plausible coordinate (<= 180). A loose match such as "2020 w" in
    "in 2020 we" therefore does not hide a genuine year.
    """
    spans = [
        match.span()
        for match in COORD_PATTERN.finditer(text)
        if abs(float(match.group(1))) <= 180.0
    ]
    for pattern in (LAT_LON_WORD_PATTERN, HOUR_PATTERN, SEASONAL_HOUR_PATTERN):
        spans.extend(match.span() for match in pattern.finditer(text))
    return spans


def _is_standalone_number(
    text: str, start: int, end: int, claimed: list[tuple[int, int]]
) -> bool:
    """Return True if ``text[start:end]`` is neither a decimal fragment nor inside *claimed*."""
    if text[start - 1 : start] == ".":
        return False  # fractional digits, e.g. the "7128" of "40.7128"
    if text[end : end + 1] == "." and text[end + 1 : end + 2].isdigit():
        return False  # integer part of a decimal, e.g. the "40" of "40.7128"
    return not any(s < end and start < e for s, e in claimed)


def _parse_date(text: str) -> tuple[Optional[int], Optional[int], Optional[int], float]:
    """Extract a calendar date, or failing that a bare year, from *text*.

    A ``DATE_PATTERN`` match ("March 15, 44 BC") is preferred. Otherwise the
    first ``YEAR_ONLY_PATTERN`` number in [-5000, 3000] is used. Digits that
    belong to a decimal number, a coordinate or a clock time are skipped, so
    "40.7 N, 74.0 W in 1650" yields 1650 rather than 40.

    Args:
        text: The prompt to scan.

    Returns:
        ``(year, month, day, confidence)``:
        - a full date gives all three parts with confidence 0.4;
        - a bare year gives ``(year, None, None, 0.2)``;
        - otherwise the result is ``(None, None, None, 0.0)``.
    """
    match = DATE_PATTERN.search(text)
    if match:
        month_name, day_str, year_str, era = match.groups()
        month = MONTHS.get(month_name.lower())
        day = int(day_str)
        year = _convert_year(year_str, era)
        return year, month, day, 0.4

    claimed = _non_year_spans(text)
    for candidate in YEAR_ONLY_PATTERN.finditer(text):
        start, end = candidate.span(1)
        if not _is_standalone_number(text, start, end, claimed):
            continue
        year_str, era = candidate.groups()
        year = _convert_year(year_str, era)
        if -5000 <= year <= 3000:
            return year, None, None, 0.2

    return None, None, None, 0.0
|
|
|
|
|
|
|
|
def _parse_hour(text: str) -> tuple[Optional[int], Optional[int], float]:
    """Extract a time of day from *text*.

    An ``HOUR_PATTERN`` match ("21h", "6:30 hours") is preferred and scores
    0.2. Otherwise a 12-hour ``SEASONAL_HOUR_PATTERN`` match ("at 9 pm") is
    converted to a 0-23 hour and scores 0.15. Matches whose hour or minute is
    out of range, such as "30 hours" or "at 13 pm", are skipped so that a
    later valid time can still be found.

    Args:
        text: The prompt to scan.

    Returns:
        ``(hour, minute, confidence)``, or ``(None, None, 0.0)`` if no valid
        time is present. A missing minute component is reported as 0.
    """
    for match in HOUR_PATTERN.finditer(text):
        hour = int(match.group(1))
        minute = int(match.group(2) or 0)
        if hour <= 23 and minute <= 59:
            return hour, minute, 0.2

    for match in SEASONAL_HOUR_PATTERN.finditer(text):
        hour = int(match.group(1))
        minute = int(match.group(2) or 0)
        if not (1 <= hour <= 12 and minute <= 59):
            continue
        # The pattern ends in (?:am|pm), so the last two characters are the
        # meridiem. Checking only the suffix avoids matching "am"/"pm"
        # anywhere else in the matched text.
        meridiem = match.group(0)[-2:].lower()
        if meridiem == "pm" and hour < 12:
            hour += 12
        elif meridiem == "am" and hour == 12:
            hour = 0
        return hour, minute, 0.15

    return None, None, 0.0
|
|
|
|
|
|
|
|
def parse_prompt_context(prompt: Optional[str]) -> ParsedPrompt:
    """Pull location, date and time-of-day hints out of a free-text prompt.

    The coordinate, date and hour extractors each scan the whole prompt
    independently. Their confidence scores are summed and clamped to 1.0.

    Args:
        prompt: The user's prompt. ``None`` or ``""`` yields an empty
            ``ParsedPrompt``.

    Returns:
        A ``ParsedPrompt`` whose ``residual_text`` is the unmodified prompt.
    """
    if not prompt:
        return ParsedPrompt(residual_text="")

    lat, lon, coord_conf = _parse_coordinates(prompt)
    year, month, day, date_conf = _parse_date(prompt)
    hour, minute, hour_conf = _parse_hour(prompt)

    combined = sum((coord_conf, date_conf, hour_conf))
    return ParsedPrompt(
        lat=lat, lon=lon,
        year=year, month=month, day=day,
        hour=hour, minute=minute,
        confidence=min(combined, 1.0),
        residual_text=prompt,
    )
|
|
|
|
|
|
|
|
|
|
|
|