|
|
import os |
|
|
import io |
|
|
import math |
|
|
import time |
|
|
import json |
|
|
import requests |
|
|
import datetime |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from collections import defaultdict |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
import gradio as gr |
|
|
import folium |
|
|
from folium.plugins import MarkerCluster |
|
|
|
|
|
|
|
|
from google import genai |
|
|
from google.genai import types |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Training hyper-parameters ----------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available

SEQ_LEN = 4      # number of past points fed to the LSTM per prediction
BATCH_SIZE = 32  # mini-batch size used by the DataLoader
EPOCHS = 8       # default number of training epochs
LR = 1e-3        # Adam learning rate

# --- External service configuration -----------------------------------------
# API keys come from the environment; downstream code degrades gracefully
# (fallback messages) when they are absent.
GOOGLE_GEOCODING_API_KEY = os.environ.get("GOOGLE_GEOCODING_API_KEY")

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

GEMINI_MODEL = "gemini-2.0-flash"

# Gemini client is created only when a key is configured; otherwise None.
client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in meters between two (lat, lon) points.

    Uses the haversine formula with a mean Earth radius of 6,371 km.
    """
    earth_radius_m = 6371000.0
    lat1_rad = math.radians(lat1)
    lat2_rad = math.radians(lat2)
    half_dlat = math.radians(lat2 - lat1) / 2.0
    half_dlon = math.radians(lon2 - lon1) / 2.0
    hav = (
        math.sin(half_dlat) ** 2
        + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(half_dlon) ** 2
    )
    return 2.0 * earth_radius_m * math.asin(math.sqrt(hav))
|
|
|
|
|
|
|
|
def coords_to_feature(seq: list[tuple[float, float, str]]) -> np.ndarray:
    """Convert a sequence of (lat, lon, iso-timestamp) into a (T, 6) feature array.

    Per-step features are [lat, lon, dt_seconds, dlat, dlon, speed_m_per_s];
    the first step has all deltas and speed equal to 0.
    """
    rows = []
    prev = None  # (lat, lon, parsed datetime) of the previous point
    for lat, lon, tstr in seq:
        ts = datetime.datetime.fromisoformat(tstr)
        if prev is None:
            dt = dlat = dlon = speed = 0.0
        else:
            prev_lat, prev_lon, prev_ts = prev
            dt = (ts - prev_ts).total_seconds()
            dlat = lat - prev_lat
            dlon = lon - prev_lon
            meters = haversine_distance(prev_lat, prev_lon, lat, lon)
            speed = meters / (dt + 1e-9)  # epsilon guards against dt == 0
        rows.append([lat, lon, dt, dlat, dlon, speed])
        prev = (lat, lon, ts)
    return np.array(rows, dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrajDataset(Dataset):
    """Sliding-window trajectory dataset.

    Each sample pairs the features of `seq_len` consecutive points with the
    [dlat, dlon] displacement from the last input point to the next point.
    """

    def __init__(self, traces: list[list[tuple[float, float, str]]], seq_len: int = SEQ_LEN):
        xs = []
        ys = []
        for trace in traces:
            # A trace must supply seq_len inputs plus one target point.
            if len(trace) <= seq_len:
                continue
            for start in range(len(trace) - seq_len):
                window = trace[start:start + seq_len + 1]
                features = coords_to_feature(window[:-1])
                next_lat, next_lon, _ = window[-1]
                anchor_lat, anchor_lon, _ = window[-2]
                xs.append(features)
                ys.append([next_lat - anchor_lat, next_lon - anchor_lon])
        if xs:
            self.X = np.stack(xs, axis=0)
            self.Y = np.array(ys, dtype=np.float32)
        else:
            # Empty but correctly shaped arrays so len()/indexing still work.
            self.X = np.zeros((0, seq_len, 6), dtype=np.float32)
            self.Y = np.zeros((0, 2), dtype=np.float32)

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LSTMForecast(nn.Module):
    """LSTM sequence encoder with a small MLP regression head.

    Maps a (batch, seq, in_dim) feature tensor to a (batch, out_dim)
    prediction — by default the [dlat, dlon] displacement.
    """

    def __init__(self, in_dim=6, hid_dim=64, num_layers=2, out_dim=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim // 2),
            nn.ReLU(),
            nn.Linear(hid_dim // 2, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the sequence and regress from the final time step's output."""
        encoded, _ = self.lstm(x)
        last_step = encoded[:, -1, :]
        return self.fc(last_step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# File where trained LSTM weights are persisted and reloaded from.
MODEL_PATH = "traj_lstm_gemini.pt"
|
|
|
|
|
|
|
|
def train_model_on_traces(traces: list[list[tuple[float, float, str]]], epochs: int = EPOCHS) -> LSTMForecast | None:
    """Train an LSTMForecast on the given traces and save weights to MODEL_PATH.

    Returns the trained model, or None when the traces yield no samples.
    """
    dataset = TrajDataset(traces)
    if len(dataset) == 0:
        print("Not enough data to train.")
        return None

    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    model = LSTMForecast().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.MSELoss()

    for epoch in range(epochs):
        model.train()
        running = 0.0
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            loss = loss_fn(model(batch_x), batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a true sample mean.
            running += loss.item() * batch_x.size(0)
        print(f"[Train] Epoch {epoch + 1}/{epochs}, Loss: {running / len(dataset):.6f}")

    torch.save(model.state_dict(), MODEL_PATH)
    return model
|
|
|
|
|
|
|
|
def load_model_if_exists() -> LSTMForecast:
    """Return an LSTMForecast, loading weights from MODEL_PATH when present.

    Always returns a model instance: if the checkpoint is missing or fails
    to load, the model keeps its freshly initialized (untrained) weights.
    """
    model = LSTMForecast().to(DEVICE)
    if os.path.exists(MODEL_PATH):
        try:
            state = torch.load(MODEL_PATH, map_location=DEVICE)
            model.load_state_dict(state)
            model.eval()
            print(f"Loaded model from {MODEL_PATH}")
        except Exception as e:
            # Corrupt / incompatible checkpoint: report and fall back to fresh weights.
            print(f"Error loading model: {e}")
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reverse_geocode(lat: float, lon: float) -> dict:
    """Resolve (lat, lon) to address parts via the Google Geocoding API.

    Returns a dict with country/state/city/street/formatted keys on success,
    or a dict containing an "error" key on any failure (missing key, HTTP
    problem, or no geocoding results).
    """
    if not GOOGLE_GEOCODING_API_KEY:
        return {"error": "No Google Geocoding API key provided"}

    try:
        response = requests.get(
            "https://maps.googleapis.com/maps/api/geocode/json",
            params={"latlng": f"{lat},{lon}", "key": GOOGLE_GEOCODING_API_KEY},
            timeout=10,
        )
        payload = response.json()
        if payload.get("status") != "OK" or not payload.get("results"):
            return {"error": "no-results", "raw": payload}

        best = payload["results"][0]
        out = {
            "country": None,
            "state": None,
            "city": None,
            "street": None,
            "formatted": best.get("formatted_address"),
        }
        for component in best["address_components"]:
            kinds = component.get("types", [])
            name = component.get("long_name")
            if "country" in kinds:
                out["country"] = name
            if "administrative_area_level_1" in kinds:
                out["state"] = name
            if "locality" in kinds or "postal_town" in kinds:
                out["city"] = name
            if "route" in kinds or "street_address" in kinds or "premise" in kinds:
                # Street-level components accumulate in API order.
                out["street"] = name if not out["street"] else out["street"] + ", " + name
        return out
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def llm_explain(predictions: list[dict], recent_seq_summary: str) -> str:
    """Produce a natural-language explanation of the model's predictions.

    Uses Gemini when configured; otherwise builds a deterministic text
    summary.  `predictions` are dicts with "lat", "lon" and "prob" keys,
    ordered most-probable first.
    """
    # Guard: the fallback path below indexes predictions[0]; an empty list
    # would otherwise raise IndexError.
    if not predictions:
        return "No predictions available to explain."

    if not GEMINI_API_KEY or not client:
        # Deterministic fallback when Gemini is not configured.
        top = predictions[0]
        s = f"Top predicted location: ({top['lat']:.5f}, {top['lon']:.5f}) with probability {top['prob']:.2f}.\n"
        s += "Other candidates:\n"
        for p in predictions[1:]:
            s += f" - ({p['lat']:.5f}, {p['lon']:.5f}) prob {p['prob']:.2f}\n"
        s += f"Recent trace summary: {recent_seq_summary}\n"
        s += "Note: no Gemini key provided; used fallback explanation."
        return s

    try:
        prompt = (
            "You are an assistant that explains location prediction results.\n"
            f"Recent trace summary: {recent_seq_summary}\n"
            "Model predictions (lat, lon, prob):\n"
        )
        for p in predictions:
            prompt += f"- {p['lat']:.6f}, {p['lon']:.6f}, prob={p['prob']:.3f}\n"
        prompt += (
            "\nPlease produce a concise explanation (2‑4 sentences) that: "
            "(1) states the top predicted place, (2) gives reasons based on recent movement and time, "
            "(3) mentions uncertainty.\n"
        )

        resp = client.models.generate_content(
            model=GEMINI_MODEL,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=0.2,
                max_output_tokens=200
            )
        )
        # resp.text can be None (e.g. a blocked response); avoid
        # AttributeError on .strip() and return an empty explanation instead.
        return (resp.text or "").strip()

    except Exception as e:
        return f"(Gemini explanation failed: {e})\nFallback: " + json.dumps(predictions[:3], default=str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_next_from_history(model: LSTMForecast, recent_seq: list[tuple[float, float, str]], top_k: int = 3) -> list[dict]:
    """Predict candidate next locations from the most recent points.

    Runs the model on the last SEQ_LEN points to obtain a [dlat, dlon]
    displacement, then builds a small heuristic spread of candidate
    locations (with normalized probabilities) around the point prediction.
    """
    model.eval()
    window = recent_seq[-SEQ_LEN:]
    features = coords_to_feature(window)
    batch = torch.tensor(features[None, :, :], dtype=torch.float32).to(DEVICE)
    with torch.no_grad():
        delta = model(batch).cpu().numpy()[0]

    anchor_lat, anchor_lon, _ = window[-1]
    center_lat = anchor_lat + float(delta[0])
    center_lon = anchor_lon + float(delta[1])

    # Heuristic candidate spread around the point prediction.
    candidates = [
        {"lat": center_lat, "lon": center_lon, "prob": 0.6},
        {"lat": center_lat + 0.01, "lon": center_lon - 0.005, "prob": 0.25},
        {"lat": center_lat - 0.01, "lon": center_lon + 0.006, "prob": 0.15},
    ]
    norm = sum(c["prob"] for c in candidates)
    for c in candidates:
        c["prob"] /= norm

    return candidates[:top_k]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_map_html(last_lat: float, last_lon: float, predictions: list[dict]) -> str:
    """Render a folium map as embeddable HTML: last seen point plus predictions."""
    fmap = folium.Map(location=[last_lat, last_lon], zoom_start=8, tiles="OpenStreetMap")
    folium.Marker(
        [last_lat, last_lon],
        popup="Last Seen",
        icon=folium.Icon(color="red"),
    ).add_to(fmap)
    cluster = MarkerCluster().add_to(fmap)
    for pred in predictions:
        marker = folium.Marker(
            [pred["lat"], pred["lon"]],
            popup=f"prob={pred['prob']:.2f}",
            icon=folium.Icon(color="blue"),
        )
        marker.add_to(cluster)
    return fmap._repr_html_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level mutable store of (lat, lon, iso-timestamp) points for the
# current UI session; mutated by add_point, on_train and clear_session.
SESSION_POINTS = []
|
|
|
|
|
|
|
|
def add_point(date: str, time_inp: str, loc_text: str) -> str:
    """Append one (lat, lon, timestamp) point to SESSION_POINTS.

    `loc_text` must be "lat, lon"; the timestamp is built as "{date}T{time}".
    Returns a human-readable success or error message for the UI.
    """
    global SESSION_POINTS
    try:
        # Exactly one comma is required; extra/missing parts raise and are reported.
        lat_part, lon_part = loc_text.split(",")
        parsed_lat = float(lat_part.strip())
        parsed_lon = float(lon_part.strip())
        timestamp = f"{date}T{time_inp}"
        SESSION_POINTS.append((parsed_lat, parsed_lon, timestamp))
        return f"✅ Added point: ({parsed_lat}, {parsed_lon}) at {timestamp}\nTotal points: {len(SESSION_POINTS)}"
    except Exception as e:
        return f"❌ Error parsing input: {e}"
|
|
|
|
|
def parse_trace_text(text: str) -> list[tuple[float, float, str]]: |
|
|
""" |
|
|
Parse input text containing trajectory points. |
|
|
Supports CSV with headers or whitespace/comma separated lines with lat, lon, timestamp in any order. |
|
|
Returns list of (lat, lon, iso-timestamp) tuples. |
|
|
""" |
|
|
if not text or not text.strip(): |
|
|
return [] |
|
|
|
|
|
|
|
|
try: |
|
|
df = pd.read_csv(io.StringIO(text.strip())) |
|
|
|
|
|
lat_col = None |
|
|
lon_col = None |
|
|
time_col = None |
|
|
for col in df.columns: |
|
|
c = col.lower() |
|
|
if c in ("lat", "latitude"): |
|
|
lat_col = col |
|
|
elif c in ("lon", "lng", "longitude"): |
|
|
lon_col = col |
|
|
elif c in ("time", "timestamp", "datetime", "date", "iso"): |
|
|
time_col = col |
|
|
if lat_col and lon_col and time_col: |
|
|
|
|
|
def to_iso(t): |
|
|
try: |
|
|
dt = pd.to_datetime(t) |
|
|
return dt.isoformat() |
|
|
except Exception: |
|
|
return str(t) |
|
|
return [ |
|
|
(float(row[lat_col]), float(row[lon_col]), to_iso(row[time_col])) |
|
|
for _, row in df.iterrows() |
|
|
] |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()] |
|
|
parsed = [] |
|
|
for ln in lines: |
|
|
parts = [p.strip() for p in ln.replace(",", " ").split()] |
|
|
if len(parts) < 3: |
|
|
continue |
|
|
|
|
|
|
|
|
try: |
|
|
lat = float(parts[0]) |
|
|
lon = float(parts[1]) |
|
|
ts = parts[2] |
|
|
|
|
|
datetime.datetime.fromisoformat(ts) |
|
|
parsed.append((lat, lon, ts)) |
|
|
continue |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
try: |
|
|
ts = parts[0] |
|
|
datetime.datetime.fromisoformat(ts) |
|
|
lat = float(parts[1]) |
|
|
lon = float(parts[2]) |
|
|
parsed.append((lat, lon, ts)) |
|
|
continue |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
if len(parts) >= 4: |
|
|
try: |
|
|
lat = float(parts[0]) |
|
|
lon = float(parts[1]) |
|
|
date = parts[2] |
|
|
time_part = parts[3] |
|
|
ts = f"{date}T{time_part}" |
|
|
datetime.datetime.fromisoformat(ts) |
|
|
parsed.append((lat, lon, ts)) |
|
|
continue |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
return parsed |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear_session() -> str:
    """Drop all accumulated session points and report the reset to the UI."""
    global SESSION_POINTS
    SESSION_POINTS = []
    return "Session cleared."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main_ui():
    """Build the Gradio Blocks UI wiring point entry, training and prediction."""
    with gr.Blocks() as demo:
        gr.Markdown("# Trajectory Prediction with LSTM and Gemini Explanation")

        # Manual single-point entry widgets.
        date_input = gr.Textbox(label="Date (YYYY-MM-DD)", value=datetime.date.today().isoformat())
        time_input = gr.Textbox(label="Time (HH:MM:SS)", value="12:00:00")
        loc_input = gr.Textbox(label="Location (lat, lon)", placeholder="e.g. 37.7749, -122.4194")

        add_btn = gr.Button("Add Point")
        add_status = gr.Textbox(label="Status", interactive=False)

        # Bulk trace input used for training the model.
        trace_text = gr.TextArea(label="Paste Trace (lat, lon, timestamp per line or CSV)")

        train_btn = gr.Button("Train Model on Trace")
        train_status = gr.Textbox(label="Training Status", interactive=False)

        predict_btn = gr.Button("Predict Next Location")
        predict_output = gr.Textbox(label="Prediction Explanation", interactive=False)
        map_output = gr.HTML(label="Prediction Map")

        clear_btn = gr.Button("Clear Session")
        point_counter = gr.Textbox(label="Current # of Points", interactive=False)

        def on_add_point(date, time_inp, loc_text):
            # Append one manual point and refresh the counter display.
            msg = add_point(date, time_inp, loc_text)
            return msg, f"{len(SESSION_POINTS)}"

        def on_train(trace_str):
            # Parse the pasted trace, adopt it as the active session when it
            # is long enough for prediction, then train and save the model.
            global SESSION_POINTS
            traces = [parse_trace_text(trace_str)]
            SESSION_POINTS = traces[0] if traces and len(traces[0]) >= SEQ_LEN else []
            model = train_model_on_traces(traces)
            if model:
                return "Training completed."
            else:
                return "Training failed or insufficient data."

        def on_predict():
            # Predict from the current session, explain via the LLM, and
            # render the candidates on a folium map.
            if not SESSION_POINTS or len(SESSION_POINTS) < SEQ_LEN:
                return "Not enough points for prediction.", ""
            model = load_model_if_exists()
            if model is None:
                # NOTE(review): load_model_if_exists always returns a model
                # instance, so this branch looks unreachable — confirm intent.
                return "Model not found. Train first.", ""
            preds = predict_next_from_history(model, SESSION_POINTS)
            explanation = llm_explain(preds, f"{len(SESSION_POINTS)} points ending at {SESSION_POINTS[-1][2]}")
            last_lat, last_lon, _ = SESSION_POINTS[-1]
            map_html = make_map_html(last_lat, last_lon, preds)
            return explanation, map_html

        # Event wiring: each button feeds its handler's outputs to widgets.
        add_btn.click(on_add_point, inputs=[date_input, time_input, loc_input], outputs=[add_status, point_counter])
        train_btn.click(on_train, inputs=[trace_text], outputs=train_status)
        predict_btn.click(on_predict, outputs=[predict_output, map_output])
        clear_btn.click(lambda: (clear_session(), "0"), outputs=[add_status, point_counter])

    return demo
|
|
|
|
|
|
|
|
# Script entry point: build and launch the Gradio app.
if __name__ == "__main__":
    main_ui().launch()
|
|
|