# AirTrackLM: LLM4STP Adapted for ADS-B Air Track Prediction

## Complete Architecture & Implementation Plan

---

## 1. Executive Summary

We adapt the LLM4STP multi-feature fusion architecture (originally designed for maritime AIS ship trajectory prediction) to **ADS-B air track data**. The model is a **decoder-only transformer**. Its input combines per-feature kinematic embeddings with four specialized embedding types (Prompt, Uncertainty, Geohash, and Temporal), and it is pretrained on **next-state prediction**. Once pretrained, the model can be adapted to downstream tasks such as activity classification.

This design is grounded in published results from:
- **FTP-LLM** (arXiv:2501.17459): LLaMA-3.1-8B for flight trajectory prediction
- **H3-CLM** (arXiv:2405.09596): H3 geohash + causal LM for maritime trajectories
- **GeoFormer** (arXiv:2311.05092): GPT-style geospatial tokenization
- **TrAISformer** (arXiv:2109.03958): discrete tokenization of AIS features

---

## 2. System Architecture Overview

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                               RAW ADS-B INPUT                               │
│                 (timestamp, latitude, longitude, altitude)                  │
└──────────────────────────────────────┬──────────────────────────────────────┘
                                       │
                                       ▼
┌─────────────────────────────────────────────────────────────────────────────┐
│                         FEATURE DERIVATION PIPELINE                         │
│                                                                             │
│  Raw:     lat, lon, alt                                                     │
│  Derived: COG, SOG, ROT, altitude_rate                                      │
│  Meta:    timestamp → (hour, day_of_week, month)                            │
│                                                                             │
│  Output per timestep:                                                       │
│    state_t = [lat, lon, alt, COG, SOG, ROT, alt_rate]                       │
└──────────────────────────────────────┬──────────────────────────────────────┘
                                       │
                                       ▼
┌─────────────────────────────────────────────────────────────────────────────┐
│                           TOKENIZATION / ENCODING                           │
│                                                                             │
│    ┌──────────────┐  ┌──────────────┐  ┌──────────────┐                     │
│    │ Geohash      │  │ Continuous   │  │ Temporal     │                     │
│    │ Tokenizer    │  │ Discretizer  │  │ Encoder      │                     │
│    │              │  │              │  │              │                     │
│    │ lat,lon,alt  │  │ COG,SOG,ROT, │  │ hour,dow,    │                     │
│    │ → H3 cell +  │  │ alt_rate     │  │ month        │                     │
│    │ alt_band     │  │ → bin IDs    │  │ → time IDs   │                     │
│    └──────┬───────┘  └──────┬───────┘  └──────┬───────┘                     │
│           │                 │                 │                             │
│           ▼                 ▼                 ▼                             │
│    ┌──────────────┐  ┌──────────────┐  ┌──────────────┐                     │
│    │ Geohash      │  │ Feature      │  │ Temporal     │                     │
│    │ Embedding    │  │ Embeddings   │  │ Embedding    │                     │
│    │ Table        │  │ Tables       │  │ Table        │                     │
│    │ (d_model)    │  │ (d_model)    │  │ (d_model)    │                     │
│    └──────┬───────┘  └──────┬───────┘  └──────┬───────┘                     │
│           │                 │                 │                             │
└───────────┼─────────────────┼─────────────────┼─────────────────────────────┘
            │                 │                 │
            ▼                 ▼                 ▼
┌─────────────────────────────────────────────────────────────────────────────┐
│                           EMBEDDING FUSION LAYER                            │
│                                                                             │
│     ┌────────────┐    ┌────────────┐    ┌────────────┐    ┌────────────┐    │
│     │ Geohash    │    │ Feature    │    │ Temporal   │    │ Uncertainty│    │
│     │ Embed      │    │ Embeds     │    │ Embed      │    │ Embed      │    │
│     │ (d_model)  │    │ (d_model)  │    │ (d_model)  │    │ (d_model)  │    │
│     └─────┬──────┘    └─────┬──────┘    └─────┬──────┘    └─────┬──────┘    │
│           │                 │                 │                 │           │
│           └─────────────────┴────────┬────────┴─────────────────┘           │
│                                      ▼                                      │
│                E_state = E_geo + E_feat + E_temp + E_uncert                 │
│                                      │                                      │
│                                      ▼                                      │
│                ┌───────────────────────────────────────────┐                │
│                │ Prompt Embedding (prepended prefix)       │                │
│                │ [PROMPT_1, PROMPT_2, ..., PROMPT_k]       │                │
│                └─────────────────────┬─────────────────────┘                │
│                                      │                                      │
│                                      ▼                                      │
│         Input: [PROMPT_TOKENS | STATE_1 | STATE_2 | ... | STATE_T]          │
│                                      │                                      │
│                                      ▼                                      │
│                         Linear Projection → d_model                         │
│                                      │                                      │
│                                      ▼                                      │
│                     + Positional Encoding (sinusoidal)                      │
│                                                                             │
└──────────────────────────────────────┬──────────────────────────────────────┘
                                       │
                                       ▼
┌─────────────────────────────────────────────────────────────────────────────┐
│                      DECODER-ONLY TRANSFORMER BACKBONE                      │
│                                                                             │
│          ┌───────────────────────────────────────────────────────┐          │
│          │ Transformer Block ×N_layers                           │          │
│          │                                                       │          │
│          │      ┌─────────────────────────────────────────┐      │          │
│          │      │ Causal Multi-Head Self-Attention        │      │          │
│          │      │ (masked: each position attends only     │      │          │
│          │      │ to itself and earlier positions)        │      │          │
│          │      └────────────────────┬────────────────────┘      │          │
│          │                           │                           │          │
│          │                           ▼                           │          │
│          │      ┌─────────────────────────────────────────┐      │          │
│          │      │ LayerNorm + Residual Connection         │      │          │
│          │      └────────────────────┬────────────────────┘      │          │
│          │                           │                           │          │
│          │                           ▼                           │          │
│          │      ┌─────────────────────────────────────────┐      │          │
│          │      │ Feed-Forward Network                    │      │          │
│          │      │ (Linear → GELU → Linear)                │      │          │
│          │      │ d_model → 4*d_model → d_model           │      │          │
│          │      └────────────────────┬────────────────────┘      │          │
│          │                           │                           │          │
│          │                           ▼                           │          │
│          │      ┌─────────────────────────────────────────┐      │          │
│          │      │ LayerNorm + Residual Connection         │      │          │
│          │      └─────────────────────────────────────────┘      │          │
│          │                                                       │          │
│          └───────────────────────────────────────────────────────┘          │
│                                                                             │
└──────────────────────────────────────┬──────────────────────────────────────┘
                                       │
                                       ▼
┌─────────────────────────────────────────────────────────────────────────────┐
│                                OUTPUT HEADS                                 │
│                                                                             │
│      ┌───────────────────────────────────────────────────────────────┐      │
│      │ PRETRAINING: Next-State Prediction Head                       │      │
│      │                                                               │      │
│      │ For each position t, predict state at t+1:                    │      │
│      │                                                               │      │
│      │   h_t → Linear → softmax → P(geohash_token_{t+1})             │      │
│      │   h_t → Linear → softmax → P(COG_bin_{t+1})                   │      │
│      │   h_t → Linear → softmax → P(SOG_bin_{t+1})                   │      │
│      │   h_t → Linear → softmax → P(ROT_bin_{t+1})                   │      │
│      │   h_t → Linear → softmax → P(alt_rate_bin_{t+1})              │      │
│      │   h_t → Linear → softmax → P(alt_band_{t+1})                  │      │
│      │                                                               │      │
│      │ Loss = Σ CrossEntropy(predicted_feature, true_feature)        │      │
│      └───────────────────────────────────────────────────────────────┘      │
│                                                                             │
│      ┌───────────────────────────────────────────────────────────────┐      │
│      │ DOWNSTREAM: Activity Classification Head                      │      │
│      │ (attached after pretraining, frozen or fine-tuned)            │      │
│      │                                                               │      │
│      │ h_T or mean(h_1:T) → MLP → softmax → class label              │      │
│      │ (h_T: last position; h_[BOS] would see no states)             │      │
│      └───────────────────────────────────────────────────────────────┘      │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```

---

## 3. The Four Embedding Types (Detailed)

### 3.1 Geohash Embeddings: Spatial Position Encoding

**Purpose**: Encode the aircraft's 3D geographic position as a discrete token.

**Method**: We use Uber's **H3 hexagonal hierarchical spatial index**. For en-route flight we use resolution 5 (average hexagon area ≈ 253 km², average edge ≈ 9.85 km). For terminal areas, resolution 7 is an option (≈ 5.16 km², edge ≈ 1.41 km). This follows the H3-CLM approach, adapted to aviation's larger spatial scale.

**3D Extension**: Since aircraft operate in 3D, we combine the H3 cell with an **altitude band**:
```
Geohash Token = H3_cell_index × N_alt_bands + alt_band_index

Altitude bands (1,000 ft increments):
  Band 0:   0 - 1,000 ft          (ground / taxi)
  Band 1:   1,000 - 2,000 ft      (initial climb / approach)
  ...
  Band 44:  44,000 - 45,000 ft
  Band 45:  45,000 ft and above   (high cruise; higher values clipped into this band)

N_alt_bands = 46
```
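
The banding and joint-token arithmetic can be made concrete with a minimal NumPy sketch. It assumes two things: altitudes are already in feet, and `h3_cell_id` is an integer index into a cell vocabulary. The helper names are hypothetical. Negative altitudes are clipped into band 0, and anything at or above 45,000 ft goes into band 45.

```python
import numpy as np

N_ALT_BANDS = 46        # bands 0..45
BAND_WIDTH_FT = 1_000   # 1,000 ft per band

def altitude_band(alt_ft):
    """Map altitude (ft) to a band index in [0, N_ALT_BANDS - 1], clipping both ends."""
    bands = np.floor(np.asarray(alt_ft, dtype=float) / BAND_WIDTH_FT).astype(int)
    return np.clip(bands, 0, N_ALT_BANDS - 1)

def combined_geo_token(h3_cell_id, alt_band_id):
    """Joint 3D token: one id per (H3 cell, altitude band) pair."""
    return np.asarray(h3_cell_id) * N_ALT_BANDS + np.asarray(alt_band_id)

bands = altitude_band([-50, 999, 1_000, 35_000, 51_000])
print(bands.tolist())                                 # [0, 0, 1, 35, 45]
print(combined_geo_token(7, bands).tolist())          # [322, 322, 323, 357, 367]
```

The joint id is useful mainly for counting the vocabulary; the factored embedding never needs to materialize it.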

**Vocabulary size**: At H3 resolution 5, the number of unique cells covering typical airspace is roughly 100K-200K. Combined with altitude bands, that becomes `~200K × 46 ≈ 9.2M` joint tokens, too many for a direct embedding table.

**Solution: Factored Embedding**:
```
E_geohash = E_h3[h3_cell_id] + E_alt[alt_band_id]

E_h3:  learned embedding table, vocab = N_h3_cells (~200K, or 50K buckets with the hashing trick)
E_alt: learned embedding table, vocab = 46

Both tables have rows of dimension d_model.
```

The **hashing trick**: Map H3 cell indices through a hash function into a fixed vocabulary of ~50,000 buckets. This bounds memory at the cost of collisions. With ~200K cells and 50K buckets, about four cells share each bucket on average, and those cells may be far apart, so spatial discrimination is partly lost.
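
One caveat applies when implementing the hashing trick in Python. The built-in `hash()` is salted per process for strings, so the same H3 index can land in different buckets from one run to the next. The sketch below uses a fixed digest instead. It assumes cells arrive as H3 index strings, as returned by h3-py v4's `h3.latlng_to_cell(lat, lng, res)`. `h3_bucket` is a hypothetical name.

```python
import hashlib

N_BUCKETS = 50_000

def h3_bucket(h3_index: str, n_buckets: int = N_BUCKETS) -> int:
    """Deterministically map an H3 index string to a bucket in [0, n_buckets).

    A fixed 64-bit BLAKE2b digest keeps the mapping identical across processes,
    machines, and the train/inference split (unlike Python's salted hash()).
    """
    digest = hashlib.blake2b(h3_index.encode("ascii"), digest_size=8).digest()
    return int.from_bytes(digest, "little") % n_buckets

cell = "85283473fffffff"                      # an example resolution-5 H3 index string
print(h3_bucket(cell) == h3_bucket(cell))     # True
print(0 <= h3_bucket(cell) < N_BUCKETS)       # True
```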

**Why H3 over traditional geohash**: H3 offers three properties that matter for trajectory continuity:
- Near-uniform cell area, with far less distortion toward the poles than latitude/longitude geohash rectangles.
- Hierarchical nesting.
- Consistent neighbor relationships: each hexagon has six neighbors at equal center-to-center distance. Only the 12 pentagons per resolution differ.

### 3.2 Temporal Embeddings: When Is the Aircraft Flying?

**Purpose**: Encode temporal context, since time of day affects traffic density, routes, and behavior.

**Method**: Additive composition of multiple temporal scales:
```
E_temporal = E_hour[hour_of_day] + E_dow[day_of_week] + E_month[month]

E_hour:  24 entries (rush hour vs. night patterns)
E_dow:    7 entries (weekday vs. weekend traffic)
E_month: 12 entries (seasonal routes, weather patterns)

Each table has rows of dimension d_model.
```
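
Here is a small sketch of extracting the three temporal indices. It assumes UTC, which is what Unix timestamps encode, and 0-based ids with Monday = 0 and January = 0. Local time might follow daily traffic cycles more closely, but the plan does not specify a time zone. `temporal_ids` is a hypothetical helper.

```python
from datetime import datetime, timezone

def temporal_ids(unix_seconds: float) -> tuple[int, int, int]:
    """Return (hour 0-23, day_of_week 0-6 with Monday = 0, month 0-11), all in UTC."""
    dt = datetime.fromtimestamp(unix_seconds, tz=timezone.utc)
    return dt.hour, dt.weekday(), dt.month - 1   # month shifted to a 0-based table index

print(temporal_ids(1_705_276_800))   # 2024-01-15 00:00 UTC, a Monday -> (0, 0, 0)
```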

**Optional: Sinusoidal Minute-of-Hour Encoding**: For resolution finer than one hour:
```
E_minute = Linear([sin(2π × minute / 60), cos(2π × minute / 60)]) → d_model
```
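
A NumPy sketch of this cyclic encoding follows. A random matrix stands in for the learned `Linear(2, d_model)` layer; that stand-in is for illustration only.

```python
import numpy as np

def cyclic_minute_features(minute):
    """(sin, cos) of the minute-of-hour angle, shape (..., 2)."""
    angle = 2 * np.pi * np.asarray(minute, dtype=float) / 60.0
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)

# Hypothetical stand-in for a learned Linear(2, d_model) layer
d_model = 256
rng = np.random.default_rng(0)
W = rng.normal(scale=0.02, size=(2, d_model))
b = np.zeros(d_model)

def minute_embedding(minute):
    """Project the cyclic features to d_model, shape (..., d_model)."""
    return cyclic_minute_features(minute) @ W + b

print(minute_embedding([0, 15, 59]).shape)   # (3, 256)
```

Encoding the angle rather than the raw minute keeps minute 59 adjacent to minute 0.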

### 3.3 Uncertainty Embeddings: How Confident Are We?

**Purpose**: Encode how predictable the current trajectory state is. We use a smoothness score computed from the data as a proxy for uncertainty; it is not the model's own predictive uncertainty. Aircraft in straight-and-level cruise have low uncertainty; aircraft maneuvering near airports have high uncertainty.

**Method**: Compute a **trajectory smoothness score** from recent states, then discretize:

```
Uncertainty sources (sliding window of the k = 5 most recent states, t-k+1 .. t):

  1. Position variance:  σ²_pos = var(Δlat) + var(Δlon)
  2. Heading variance:   σ²_COG = circular_var(COG_{t-k+1:t})
  3. Speed variance:     σ²_SOG = var(SOG_{t-k+1:t})
  4. Altitude variance:  σ²_alt = var(alt_rate_{t-k+1:t})

Combined uncertainty score:
  U_t = w1·σ²_pos + w2·σ²_COG + w3·σ²_SOG + w4·σ²_alt

Discretize into N_uncert = 16 bins (quantile binning on training data)

E_uncertainty = E_uncert_table[bin(U_t)] → d_model
```
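
A NumPy sketch of the score and the quantile binning follows. It assumes equal weights and raw (unnormalized) variances. It also assumes a score is produced only once k states exist; earlier positions are NaN. All function names are hypothetical.

```python
import numpy as np

def circular_variance(deg):
    """1 - mean resultant length: 0 for identical headings, up to 1 for spread-out ones."""
    rad = np.radians(deg)
    return 1.0 - np.hypot(np.mean(np.sin(rad)), np.mean(np.cos(rad)))

def uncertainty_scores(lat, lon, cog, sog, alt_rate, k=5, w=(1.0, 1.0, 1.0, 1.0)):
    """U_t over the window t-k+1..t; NaN for the first k-1 positions."""
    lat, lon, cog, sog, alt_rate = map(np.asarray, (lat, lon, cog, sog, alt_rate))
    U = np.full(len(lat), np.nan)
    for t in range(k - 1, len(lat)):
        s = slice(t - k + 1, t + 1)
        var_pos = np.var(np.diff(lat[s])) + np.var(np.diff(lon[s]))
        U[t] = (w[0] * var_pos + w[1] * circular_variance(cog[s])
                + w[2] * np.var(sog[s]) + w[3] * np.var(alt_rate[s]))
    return U

def quantile_edges(train_scores, n_bins=16):
    """n_bins - 1 interior edges so each bin holds ~1/n_bins of the training scores."""
    scores = np.asarray(train_scores, dtype=float)
    return np.quantile(scores[~np.isnan(scores)], np.linspace(0, 1, n_bins + 1)[1:-1])

def to_bins(scores, edges):
    """Bin ids 0..len(edges); NaN scores must be dropped or handled before this call."""
    return np.searchsorted(edges, scores, side="right")
```

Because the edges come from training-set quantiles, each of the 16 bins holds roughly the same number of training states.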

**Weights w1-w4**: Hyperparameters tuned on validation data. The four terms have very different units: squared degrees, a dimensionless circular variance, squared knots, and squared ft/min. The weights therefore also act as normalizers. Learning them jointly with the model would require a differentiable substitute for the binning step, since quantile binning passes no gradient.

**During inference**: For multi-step prediction, U_t can be recomputed from the model's own predicted states as the rollout proceeds. MC-Dropout or ensemble disagreement can additionally estimate the model's predictive uncertainty.

### 3.4 Prompt Embeddings: Task and Context Metadata

**Purpose**: Provide metadata context about the flight, analogous to system prompts in LLMs. This enables task conditioning and multi-task learning.

**Method**: Learnable prompt tokens prepended to the trajectory:

```
Prompt token vocabulary:
  - Aircraft category: [HEAVY, LARGE, SMALL, ROTORCRAFT, GLIDER, UAV, UNKNOWN]  (7)
  - Flight phase:      [CLIMB, CRUISE, DESCENT, APPROACH, GROUND, UNKNOWN]      (6)
  - Region:            [CONUS, EUROPE, ASIA, OTHER]                             (4)
  - Task:              [PREDICT, CLASSIFY, DETECT_ANOMALY]                      (3)
  - Special:           [BOS, EOS, PAD, MASK]                                    (4)

Total prompt vocab: 24 tokens (the two UNKNOWN entries are distinct tokens)

Prompt sequence (prepended):
  [BOS, TASK_TOKEN, AIRCRAFT_TOKEN, PHASE_TOKEN, REGION_TOKEN]

Each token has a learned embedding of dimension d_model.
```
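
The following sketch turns flight metadata into prompt ids. The plan lists UNKNOWN in both the aircraft and the phase groups; here they are assumed to be two distinct tokens (`UNKNOWN_AIRCRAFT`, `UNKNOWN_PHASE`). A missing region is assumed to map to `OTHER`. The vocabulary ordering is arbitrary.

```python
# Hypothetical vocabulary layout: 7 + 6 + 4 + 3 + 4 = 24 tokens
PROMPT_VOCAB = (
    ["HEAVY", "LARGE", "SMALL", "ROTORCRAFT", "GLIDER", "UAV", "UNKNOWN_AIRCRAFT"]
    + ["CLIMB", "CRUISE", "DESCENT", "APPROACH", "GROUND", "UNKNOWN_PHASE"]
    + ["CONUS", "EUROPE", "ASIA", "OTHER"]
    + ["PREDICT", "CLASSIFY", "DETECT_ANOMALY"]
    + ["BOS", "EOS", "PAD", "MASK"]
)
TOKEN_ID = {tok: i for i, tok in enumerate(PROMPT_VOCAB)}

def prompt_ids(task="PREDICT", aircraft=None, phase=None, region=None):
    """[BOS, TASK, AIRCRAFT, PHASE, REGION] as ids; missing metadata maps to UNKNOWN / OTHER."""
    return [
        TOKEN_ID["BOS"],
        TOKEN_ID[task],
        TOKEN_ID[aircraft or "UNKNOWN_AIRCRAFT"],
        TOKEN_ID[phase or "UNKNOWN_PHASE"],
        TOKEN_ID[region or "OTHER"],
    ]

print(prompt_ids("PREDICT", "HEAVY", "CRUISE", "EUROPE"))   # [20, 17, 0, 8, 14]
print(prompt_ids("CLASSIFY"))                               # [20, 18, 6, 12, 16]
```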

**For downstream classification**: Change TASK_TOKEN to CLASSIFY, then classify from the output at the last trajectory position or from a mean over the state positions. The BOS output cannot serve this purpose: under causal masking, BOS attends only to itself, so it carries no trajectory information.

---

## 4. Feature Derivation Pipeline

### 4.1 Raw Input
```
timestamp  (Unix epoch seconds)
latitude   (degrees, WGS84)
longitude  (degrees, WGS84)
altitude   (feet, barometric or geometric)
```

### 4.2 Derived Features

```python
import numpy as np

EARTH_RADIUS_NM = 3440.065  # mean Earth radius in nautical miles


def derive_features(timestamps, lats, lons, alts):
    """
    Derive COG, SOG, ROT, and altitude rate from raw position data.

    All inputs: array-likes of shape (N,) for a single, time-sorted trajectory.
    Returns four arrays of shape (N,): COG, SOG and alt_rate are NaN at index 0;
    ROT is NaN at indices 0 and 1 because it needs two consecutive COG values.
    """
    timestamps = np.asarray(timestamps, dtype=float)
    lats = np.asarray(lats, dtype=float)
    lons = np.asarray(lons, dtype=float)
    alts = np.asarray(alts, dtype=float)

    dt = np.diff(timestamps)              # seconds
    dt = np.where(dt > 0, dt, np.nan)     # repeated/out-of-order times -> NaN, not huge rates

    # --- Course Over Ground (COG): initial great-circle bearing between fixes ---
    lat1, lat2 = np.radians(lats[:-1]), np.radians(lats[1:])
    dlon = np.radians(np.diff(lons))
    x = np.sin(dlon) * np.cos(lat2)
    y = np.cos(lat1) * np.sin(lat2) - np.sin(lat1) * np.cos(lat2) * np.cos(dlon)
    COG = np.degrees(np.arctan2(x, y)) % 360   # [0, 360); 0 if the position did not change

    # --- Speed Over Ground (SOG): haversine distance / elapsed time ---
    dlat = np.radians(np.diff(lats))
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    a = np.clip(a, 0.0, 1.0)                   # guard against rounding just above 1
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    distance_nm = EARTH_RADIUS_NM * c
    SOG = distance_nm / (dt / 3600)            # knots

    # --- Rate of Turn (ROT) ---
    dCOG = np.diff(COG)
    dCOG = (dCOG + 180) % 360 - 180            # shortest signed turn, in [-180, 180)
    ROT = np.full(len(lats), np.nan)
    ROT[2:] = dCOG / dt[1:]                    # degrees per second

    # --- Rate of Altitude Change ---
    dalt = np.diff(alts)                       # feet
    alt_rate = dalt / (dt / 60)                # feet per minute

    # Pad index 0 so every output aligns with the input timestamps
    COG_full = np.concatenate([[np.nan], COG])
    SOG_full = np.concatenate([[np.nan], SOG])
    alt_rate_full = np.concatenate([[np.nan], alt_rate])

    return COG_full, SOG_full, ROT, alt_rate_full
```

### 4.3 Feature Discretization

| Feature       | Range             | Bin Width  | N_bins | Notes             |
|---------------|-------------------|------------|--------|-------------------|
| COG           | [0, 360)          | 5°         | 72     | Circular          |
| SOG           | [0, 600] kts      | 5 knots    | 121    | Capped at ~Mach 1 |
| ROT           | [-6, 6] °/s       | 0.25 °/s   | 49     | Capped ±6°/s      |
| Altitude Rate | [-6000, 6000] fpm | 200 ft/min | 61     | Capped ±6000 fpm  |

Outliers beyond the caps are clipped to the boundary bin.
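
A NumPy sketch of the binning follows. It reads the N_bins = range/width + 1 counts in the table as bins centered on multiples of the bin width, so 0 kt and 600 kt each get their own bin. It uses floor-based bins for the circular COG feature. Both readings and the function names are assumptions.

```python
import numpy as np

def discretize(x, low, high, width):
    """Nearest-bin index for bins centered on low, low + width, ..., high; outliers clipped."""
    n_bins = int(round((high - low) / width)) + 1
    idx = np.floor((np.clip(x, low, high) - low) / width + 0.5).astype(int)
    return np.clip(idx, 0, n_bins - 1)

def discretize_cog(cog_deg, width=5.0):
    """Circular bins [0, 5), [5, 10), ..., [355, 360) after wrapping angles into [0, 360)."""
    n_bins = int(round(360.0 / width))   # 72
    # The final modulo guards against np.mod returning 360.0 for tiny negative inputs
    return np.floor(np.mod(cog_deg, 360.0) / width).astype(int) % n_bins

print(discretize(np.array([0.0, 2.4, 2.6, 600.0, 700.0]), 0.0, 600.0, 5.0).tolist())
# [0, 0, 1, 120, 120]
print(discretize_cog(np.array([4.9, 359.9, -1.0])).tolist())   # [0, 71, 71]
```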

### 4.4 Trajectory Preprocessing Pipeline

```
1. Segment raw ADS-B by ICAO24 + temporal gaps > 15 min → individual flights
2. Resample to fixed Δt = 60 seconds (linear interpolation for position and altitude;
   circular interpolation for any raw heading field)
3. Derive features (COG, SOG, ROT, alt_rate)
4. Drop first 2 points per trajectory (NaN from derivation)
5. Filter: remove trajectories with < 20 points (< 20 minutes)
6. Compute H3 cell (res 5) + altitude band for each point
7. Discretize all continuous features into bins
8. Compute uncertainty scores (sliding window k=5)
9. Extract temporal features (hour, dow, month)
10. Construct prompt tokens from metadata (if available)
```
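
Steps 1-2 can be sketched for one aircraft's time-sorted reports as shown below. The functions are hypothetical. Longitude is unwrapped before interpolation so that tracks crossing the antimeridian are not interpolated the long way round.

```python
import numpy as np

GAP_S = 15 * 60   # a gap longer than 15 min starts a new flight
DT_S = 60         # resampling period in seconds

def split_on_gaps(t, gap_s=GAP_S):
    """Index slices of one aircraft's time-sorted reports, cut wherever the gap exceeds gap_s."""
    cuts = np.flatnonzero(np.diff(t) > gap_s) + 1
    bounds = [0, *cuts.tolist(), len(t)]
    return [slice(a, b) for a, b in zip(bounds[:-1], bounds[1:])]

def resample(t, lat, lon, alt, dt_s=DT_S):
    """Linearly interpolate one flight onto a regular grid starting at its first report."""
    t = np.asarray(t, dtype=float)
    grid = np.arange(t[0], t[-1] + 1e-9, dt_s)
    lon_unwrapped = np.degrees(np.unwrap(np.radians(lon)))   # remove ±360° jumps
    lon_i = np.interp(grid, t, lon_unwrapped)
    lon_i = (lon_i + 180.0) % 360.0 - 180.0                  # back to [-180, 180)
    return grid, np.interp(grid, t, lat), lon_i, np.interp(grid, t, alt)

t = np.array([0, 30, 90, 2000, 2060, 2120], dtype=float)
print([(s.start, s.stop) for s in split_on_gaps(t)])   # [(0, 3), (3, 6)]
```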
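
---

## 5. Model Hyperparameters

### 5.1 Model Dimensions

| Parameter        | Value  | Rationale                                                    |
|------------------|--------|--------------------------------------------------------------|
| d_model          | 256    | H3-CLM found 256-1024 effective                              |
| n_heads          | 8      | head_dim = 32                                                |
| n_layers         | 8      | Moderate depth (≈ 6.3M parameters in the transformer blocks) |
| d_ff             | 1024   | 4× d_model (standard)                                        |
| max_seq_len      | 128    | 128 states × 60 s ≈ 2.1 hours of flight; with the 5 prompt tokens, 133 positions |
| n_prompt_tokens  | 5      | [BOS, TASK, AIRCRAFT, PHASE, REGION]                         |
| dropout          | 0.1    |                                                              |

**Total parameters**: ≈ 32M with these settings, and most of them sit in two places:
- the H3 embedding table (50,000 × 256 = 12.8M);
- the H3 output head (256 × 50,000 weights plus biases ≈ 12.85M).

The eight transformer blocks hold only ≈ 6.3M. Tying the H3 output head to the H3 embedding table reduces the total to ≈ 19M. Either size remains trainable on a single GPU in hours.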
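
The parameter estimate can be reproduced with a short calculation. It assumes the following:
- standard transformer blocks (Q/K/V/output projections with biases, a two-layer FFN, two LayerNorms);
- untied output heads, with one `Linear(d_model, V)` per prediction head;
- no contribution from the optional d_model → d_model input projection (≈ 66K), which is ignored.

```python
def count_params(d_model=256, n_layers=8, d_ff=1024, n_h3=50_000,
                 small_vocabs=(46, 72, 121, 49, 61, 24, 7, 12, 16, 24),
                 head_vocabs=(50_000, 72, 121, 49, 61, 46)):
    """Return (transformer blocks, embedding tables, output heads) parameter counts."""
    d = d_model
    per_layer = 4 * d * d + 4 * d            # Q, K, V, output projections + biases
    per_layer += 2 * d * d_ff + d_ff + d     # feed-forward: d -> d_ff -> d
    per_layer += 2 * 2 * d                   # two LayerNorms (scale + shift)
    blocks = n_layers * per_layer
    embeddings = (n_h3 + sum(small_vocabs)) * d   # H3 + altitude/kinematic/temporal/uncertainty/prompt
    heads = sum(v * d + v for v in head_vocabs)   # H3, COG, SOG, ROT, alt_rate, alt_band heads
    return blocks, embeddings, heads

blocks, emb, heads = count_params()
total = blocks + emb + heads
print(f"blocks {blocks / 1e6:.1f}M | embeddings {emb / 1e6:.1f}M | "
      f"heads {heads / 1e6:.1f}M | total {total / 1e6:.1f}M")
# blocks 6.3M | embeddings 12.9M | heads 12.9M | total 32.2M
```

Most of the budget is spatial vocabulary, so shrinking the hash table or tying the H3 head affects the total far more than changing depth.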

### 5.2 Vocabulary Sizes

| Embedding        | Vocab  | Dim |
|------------------|--------|-----|
| H3 cells         | 50,000 | 256 |
| Altitude bands   | 46     | 256 |
| COG bins         | 72     | 256 |
| SOG bins         | 121    | 256 |
| ROT bins         | 49     | 256 |
| Alt rate bins    | 61     | 256 |
| Hour of day      | 24     | 256 |
| Day of week      | 7      | 256 |
| Month            | 12     | 256 |
| Uncertainty bins | 16     | 256 |
| Prompt tokens    | 24     | 256 |

### 5.3 State Token Composition

Each timestep → single state token via additive fusion:

```
E_state_t = E_h3[h3_id_t]     + E_alt_band[alt_band_t]            # Geohash (3D position)
          + E_COG[cog_bin_t]  + E_SOG[sog_bin_t]                  # Kinematics
          + E_ROT[rot_bin_t]  + E_alt_rate[alt_rate_bin_t]        # Dynamics
          + E_hour[hour_t]    + E_dow[dow_t] + E_month[month_t]   # Temporal
          + E_uncert[uncert_bin_t]                                # Uncertainty

E_state_t ∈ R^{d_model}
```
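
This additive fusion follows BERT, which sums token, segment, and position embeddings. TrAISformer instead concatenates per-feature embeddings whose widths add up to the model width. Both approaches keep the input width fixed.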
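
Below is a NumPy sketch of the fusion step, the prompt prefix, and the causal mask. For illustration:
- the tables are randomly initialized stand-ins for learned embeddings;
- d_model is shrunk to 8 for readability;
- the prompt ids in the usage line are arbitrary;
- all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8  # toy d_model (the plan uses 256)

# Vocabulary sizes from Section 5.2; random tables stand in for learned embeddings
VOCABS = {"h3": 50_000, "alt_band": 46, "cog": 72, "sog": 121, "rot": 49,
          "alt_rate": 61, "hour": 24, "dow": 7, "month": 12, "uncert": 16}
TABLES = {name: rng.normal(scale=0.02, size=(v, D)) for name, v in VOCABS.items()}
PROMPT_TABLE = rng.normal(scale=0.02, size=(24, D))

def embed_states(ids):
    """ids: dict feature -> int array of shape (T,). Returns (T, D) summed embeddings."""
    return sum(TABLES[name][np.asarray(ids[name])] for name in VOCABS)

def build_input(prompt_ids, ids):
    """Stack [PROMPT_1..PROMPT_k | STATE_1..STATE_T] into a (k + T, D) matrix."""
    return np.concatenate([PROMPT_TABLE[prompt_ids], embed_states(ids)], axis=0)

def causal_mask(n):
    """Boolean (n, n) mask: position i may attend to positions 0..i only."""
    return np.tril(np.ones((n, n), dtype=bool))

T = 3
ids = {name: np.zeros(T, dtype=int) for name in VOCABS}
x = build_input([20, 17, 0, 8, 14], ids)
print(x.shape)   # (8, 8)
```

Row 0 of `causal_mask(8)` has a single allowed entry, which is why the BOS output cannot summarize the trajectory.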

---

## 6. Training Recipe

### 6.1 Pretraining: Next-State Prediction (Causal LM)

**Objective**: Given states 1..t, predict the state at t+1. With causal masking and teacher forcing, this objective is applied at every position of a sequence in a single forward pass.

**Loss**:
```
L = Σ_{t=1}^{T-1} [ λ_geo  · CE(ŷ_geo_t,      y_geo_{t+1})
                  + λ_COG  · CE(ŷ_COG_t,      y_COG_{t+1})
                  + λ_SOG  · CE(ŷ_SOG_t,      y_SOG_{t+1})
                  + λ_ROT  · CE(ŷ_ROT_t,      y_ROT_{t+1})
                  + λ_alt  · CE(ŷ_alt_rate_t, y_alt_rate_{t+1})
                  + λ_altb · CE(ŷ_alt_band_t, y_alt_band_{t+1}) ]

geo = H3 cell; λ values default to 1.0 (equal weighting).
```
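
A NumPy sketch of the per-head, shifted cross-entropy follows. It averages over positions rather than summing, so the loss scale does not depend on sequence length. That averaging, the dictionary layout, and the function names are assumptions.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def next_state_loss(logits, targets, weights=None):
    """
    logits:  dict head -> (T, V_head) scores produced at positions 1..T
    targets: dict head -> (T,) integer bins of the same trajectory
    The output at position t is scored against the target at t+1, so each head
    contributes T-1 terms. Returns the λ-weighted sum of per-head mean CE.
    """
    weights = weights or {}
    total = 0.0
    for head, z in logits.items():
        logp = log_softmax(z[:-1])                # predictions made at t = 1..T-1
        y = np.asarray(targets[head])[1:]         # true bins at t+1
        ce = -logp[np.arange(len(y)), y].mean()
        total += weights.get(head, 1.0) * ce
    return total

T = 4
logits = {"cog": np.zeros((T, 72)), "sog": np.zeros((T, 121))}
targets = {"cog": np.array([0, 1, 2, 3]), "sog": np.array([5, 5, 5, 5])}
print(round(next_state_loss(logits, targets), 2))   # uniform logits: log(72) + log(121) ≈ 9.07
```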

**Training hyperparameters** (based on FTP-LLM + H3-CLM):

| Parameter             | Value                           |
|-----------------------|---------------------------------|
| Optimizer             | AdamW                           |
| Learning rate         | 5e-4                            |
| LR schedule           | Cosine + 5% warmup              |
| Batch size (per GPU)  | 64                              |
| Gradient accumulation | 4 (effective batch = 256)       |
| Max epochs            | 30 (early stopping, patience 5) |
| Weight decay          | 0.01                            |
| Gradient clipping     | 1.0                             |
| Mixed precision       | bf16                            |

**Data windowing**: Sliding windows of 128 states with stride 64 (50% overlap).
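
The window start indices can be sketched as follows. Two behaviors go beyond the plan and are assumptions: one extra end-aligned window is emitted so the tail of a trajectory is not dropped, and trajectories shorter than 128 states get a single (padded) window.

```python
def window_starts(n_states, size=128, stride=64):
    """Start indices of length-`size` windows over a trajectory of n_states states."""
    if n_states <= size:
        return [0]                              # one window, padded downstream
    starts = list(range(0, n_states - size + 1, stride))
    if starts[-1] + size < n_states:            # tail not covered by the regular stride
        starts.append(n_states - size)
    return starts

print(window_starts(300))   # [0, 64, 128, 172]
```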

### 6.2 Downstream: Activity Classification

After pretraining, attach a classification head to the output at the last trajectory position, h_T. Do not use h_BOS, which under causal masking attends only to itself:
```
h_T → Linear(256, 128) → GELU → Dropout(0.1) → Linear(128, N_classes)
```
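
The two pooling options can be sketched in NumPy for right-padded batches laid out as [prompt | states | padding]. The padding layout and the function names are assumptions.

```python
import numpy as np

def last_state_pool(h, lengths):
    """
    h:       (B, L, D) backbone outputs
    lengths: (B,) number of non-padding positions per sequence (prompt + states)
    Returns (B, D): the output at each sequence's last real position, the only
    position that has attended to every state under causal masking.
    """
    idx = np.asarray(lengths) - 1
    return h[np.arange(h.shape[0]), idx]

def mean_state_pool(h, lengths, n_prompt=5):
    """Mean over state positions only; prompt tokens and padding are excluded."""
    pos = np.arange(h.shape[1])[None, :]
    mask = (pos >= n_prompt) & (pos < np.asarray(lengths)[:, None])
    return (h * mask[..., None]).sum(axis=1) / mask.sum(axis=1, keepdims=True)

h = np.arange(2 * 10 * 3, dtype=float).reshape(2, 10, 3)
print(last_state_pool(h, [8, 10]).shape)   # (2, 3)
```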

**Fine-tuning options**:
- **A**: Freeze backbone, train head only (fast, suits small datasets)
- **B**: Full fine-tune, backbone lr=1e-5, head lr=1e-3

---

## 7. Dataset Strategy

### 7.1 Prototyping: `traffic` Python Library

```python
from traffic.data.datasets import landing_zurich_2019

# ~2,000 flights near Zurich
# Columns: timestamp, icao24, callsign, latitude, longitude, altitude,
#          groundspeed, track, vertical_rate, ...
```

This dataset is quick to set up, clean, and well documented. It covers a single airport, however, so diversity is limited.

### 7.2 Training: OpenSky Network

```python
from pyopensky.trino import Trino

trino = Trino()
df = trino.query("""
    SELECT time, icao24, lat, lon, baroaltitude, velocity, heading, vertrate
    FROM state_vectors_data4
    WHERE hour >= '2024-01-15 00:00:00'
      AND hour <  '2024-01-15 12:00:00'
      AND lat BETWEEN 40 AND 55
      AND lon BETWEEN -10 AND 20
    ORDER BY icao24, time
""")
```
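
The state-vector fields selected here are in SI units: `baroaltitude` in meters, `velocity` and `vertrate` in m/s. The altitude bands and bin tables in this plan assume feet, knots, and ft/min. Below is a minimal conversion sketch using the exact definitions 1 ft = 0.3048 m and 1 kt = 1,852 m/h. The function name is hypothetical.

```python
import numpy as np

FT_PER_M = 1 / 0.3048      # 1 ft = 0.3048 m exactly
KT_PER_MPS = 3600 / 1852   # 1 kt = 1,852 m per hour exactly

def to_aviation_units(baroaltitude_m, velocity_mps, vertrate_mps):
    """Convert OpenSky SI fields to (altitude ft, ground speed kt, vertical rate ft/min)."""
    alt_ft = np.asarray(baroaltitude_m, dtype=float) * FT_PER_M
    gs_kt = np.asarray(velocity_mps, dtype=float) * KT_PER_MPS
    vr_fpm = np.asarray(vertrate_mps, dtype=float) * FT_PER_M * 60
    return alt_ft, gs_kt, vr_fpm

alt_ft, gs_kt, vr_fpm = to_aviation_units(10_668.0, 257.2, 5.08)
print(round(float(alt_ft)), round(float(gs_kt)), round(float(vr_fpm)))   # 35000 500 1000
```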

**Target**:
- **Region A** (train): Europe, 1 month → ~500K-1M flights
- **Region B** (OOD test): US CONUS, 1 week → ~200K flights
- **Region C** (far test): East Asia, 1 week → ~100K flights

### 7.3 Alternative: SCAT Dataset

~170K en-route flights over Sweden, available on Zenodo. Pre-segmented, clean.

### 7.4 Data Split

```
Training:    70% of Region A flights
Validation:  15% of Region A flights
Test (IID):  15% of Region A flights
Test (OOD):  100% of Region B flights
Test (Far):  100% of Region C flights
```

Split by **flight** (not time window) to avoid data leakage.
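
One way to make the flight-level split deterministic and reproducible is to hash a per-flight identifier. The identifier format (for example ICAO24 plus the flight's first timestamp) and the function are assumptions. The fractions match the 70/15/15 split above.

```python
import hashlib

def split_of(flight_id: str, val_frac=0.15, test_frac=0.15) -> str:
    """Assign a whole flight to 'train', 'val' or 'test' from a stable hash of its id."""
    h = int.from_bytes(hashlib.sha256(flight_id.encode("utf-8")).digest()[:8], "big")
    u = h / 2**64                    # approximately uniform in [0, 1)
    if u < test_frac:
        return "test"
    if u < test_frac + val_frac:
        return "val"
    return "train"

print(split_of("4b1814_1705276800"))   # same answer on every run and machine
```

Every state of a flight inherits its flight's split, so no windows from one flight leak across train and test.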

---

## 8. Ablation Study: Geohash Geographic Dependency

### 8.1 Hypothesis

> Geohash embeddings encode **absolute geographic position**, causing the model to memorize region-specific patterns (airways, approach paths, airspace structure). This improves in-distribution performance but degrades transfer to unseen regions.

### 8.2 Experimental Variants

| Variant | Geohash Type | Description |
|---------|--------------|-------------|
| **V1: Full Model** | H3 absolute | Complete architecture as described |
| **V2: No Geohash** | None | Remove geohash entirely; the model sees only kinematics + temporal + uncertainty |
| **V3: Relative Geohash** | H3 relative | H3 cell of the offset (Δlat, Δlon) from the trajectory start, treated as a point near (0°, 0°). This makes the encoding translation-invariant, although one degree of Δlon spans fewer km at higher latitudes |
| **V4: Multi-Resolution** | H3 res 3+5+7 | Embeddings at 3 resolutions summed (coarse → fine) |
| **V5: Continuous Position** | Linear projection | `Linear([lat, lon, alt] → d_model)`, no discretization |

### 8.3 Evaluation Metrics

For each variant × each test set (IID, OOD, Far):

| Metric | Description |
|--------|-------------|
| Geo Accuracy | % of next-step H3 cells predicted correctly |
| Position MAE | Mean great-circle position error in km |
| COG MAE | Mean absolute heading error in degrees (difference wrapped to ±180°) |
| SOG MAE | Mean absolute speed error in knots |
| Multi-step ADE | Average displacement error over 5 predicted steps |
| Multi-step FDE | Final displacement error at step 5 |
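
A sketch of ADE and FDE with great-circle distances follows. It assumes predicted positions are already (lat, lon) pairs; for H3-token predictions, these could come from the predicted cell's center, which h3-py v4 exposes as `h3.cell_to_latlng`. The mean Earth radius of 6,371.0088 km and the function names are assumptions.

```python
import numpy as np

R_EARTH_KM = 6371.0088   # mean Earth radius

def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between points given in degrees."""
    p1, p2 = np.radians(lat1), np.radians(lat2)
    dphi = p2 - p1
    dlmb = np.radians(np.asarray(lon2) - np.asarray(lon1))
    a = np.sin(dphi / 2) ** 2 + np.cos(p1) * np.cos(p2) * np.sin(dlmb / 2) ** 2
    return 2 * R_EARTH_KM * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))

def ade_fde(pred_latlon, true_latlon):
    """pred/true: (N, H, 2) arrays of (lat, lon) over H predicted steps -> (ADE, FDE) in km."""
    d = haversine_km(pred_latlon[..., 0], pred_latlon[..., 1],
                     true_latlon[..., 0], true_latlon[..., 1])   # (N, H)
    return d.mean(), d[:, -1].mean()

true = np.zeros((1, 5, 2))
pred = true.copy()
pred[0, 4, 0] = 1.0                 # only the last step is off, by 1° of latitude
print(ade_fde(pred, true))          # FDE ≈ 111.2 km, ADE = FDE / 5
```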

### 8.4 Key Comparisons

| Comparison | Tests |
|------------|-------|
| V1 vs V2 (IID) | How much geohash helps when the test region matches the training region |
| V1 vs V2 (OOD) | If V2 beats V1 on OOD, geohash is causing geographic overfitting |
| V1 vs V3 (OOD) | If V3 does well on both IID and OOD, relative geohash is the sweet spot |
| V4 (all) | Multi-resolution: do coarse cells transfer while fine cells specialize? |
| V5 (all) | Does continuous encoding avoid discretization issues? |

### 8.5 Expected Outcomes

- **V1**: Best IID, worst OOD (hypothesis)
- **V3**: Best compromise; predicted winner
- **V5**: May struggle (it loses the discrete token structure that transformers handle well)
- **V2**: Strong OOD baseline, sacrifices IID

### 8.6 Additional Analysis

- **Attention visualization**: V1 vs V3 attention patterns
- **Embedding clustering**: t-SNE of geohash embeddings colored by region
- **Learning curves**: IID vs OOD performance as a function of training data size

---

## 9. Implementation Phases

### Phase 1: Data Pipeline (Week 1)
- Set up the `traffic` library, extract sample trajectories
- Implement feature derivation (COG, SOG, ROT, alt_rate)
- Implement H3 geohash encoding + altitude banding
- Implement feature discretization (binning)
- Implement uncertainty score computation
- Build PyTorch Dataset class with sliding window
- Unit tests for all derivation functions

### Phase 2: Model Architecture (Week 1-2)
- Implement all embedding tables
- Implement additive fusion layer
- Implement prompt token prepending
- Implement decoder-only transformer backbone
- Implement multi-head output (6 prediction heads)
- Implement classification head (for downstream)
- Forward pass test with dummy data

### Phase 3: Pretraining (Week 2-3)
- Implement training loop with multi-task loss
- Prototyping run on `traffic` data (small, fast iteration)
- Scale to OpenSky data
- Monitor loss curves, validate convergence
- Save best checkpoint

### Phase 4: Downstream Adaptation (Week 3-4)
- Implement classification fine-tuning pipeline
- Test on activity classification task
- Compare frozen vs. fine-tuned backbone

### Phase 5: Ablation Study (Week 4-5)
- Implement all 5 geohash variants
- Train each variant with identical hyperparameters
- Evaluate on IID, OOD, and Far test sets
- Generate comparison tables and visualizations
- Write analysis of geographic dependency findings

---

## 10. Key Design Decisions & Rationale

| Decision | Choice | Why |
|----------|--------|-----|
| Custom model vs. pretrained LLM | Custom transformer (≈ 6M backbone parameters, ≈ 32M including embeddings and heads) | FTP-LLM showed that text-tokenized LLMs work, but a custom model allows proper multi-feature fusion and still trains in hours. |
| H3 vs. traditional geohash | H3 | Near-uniform hexagonal cells with far less polar distortion, hierarchical. Proven by H3-CLM. |
| Additive vs. concatenative fusion | Additive | BERT paradigm; keeps d_model constant. Concatenating full-width embeddings would need d_model × N_features dimensions (TrAISformer avoids this by giving each feature a slice of the model width). |
| 60s time resolution | 60 seconds | FTP-LLM validated 1-min aggregation. 128 steps ≈ 2+ hours. |
| Factored geohash (H3 + alt) | Separate tables, summed | Avoids combinatorial explosion (9.2M → 50K + 46). |
| Multi-head output | Separate softmax per feature | More interpretable, allows per-feature analysis. |
| Uncertainty from smoothness | Variance-based | Computable during preprocessing, no inference overhead. |

---

## 11. Risk Analysis

| Risk | Likelihood | Impact | Mitigation |
|------|------------|--------|------------|
| Geohash overfits to region | High | High | Ablation study; V3 (relative) is fallback |
| OpenSky access issues | Medium | High | Fallback: `traffic` samples + SCAT |
| 60s too coarse for terminal | Medium | Low | Separate terminal model at 10s |
| Model too small | Low | Medium | Scale: d_model→512, n_layers→16 (≈ 50M parameters in the blocks alone, with d_ff = 4 × d_model) |
| Alt discretization too coarse | Low | Low | Refine to 500 ft bands (92) |

---

## 12. Monitoring & Evaluation

**During training** (Trackio):
- Total loss + per-feature loss curves
- Validation loss each epoch
- LR schedule, GPU utilization

**After training**:
- Next-state accuracy (top-1, top-5 per feature)
- Position error in km
- Multi-step prediction (1, 5, 10, 20 steps ahead)
- Downstream classification F1/precision/recall

---

*Grounded in: FTP-LLM, H3-CLM, GeoFormer, TrAISformer, and LLM4STP (reconstructed). Ready for implementation upon approval.*