# wherenext / app.py — trajectory "where next" demo (LSTM forecaster + Gemini explanations).
# Provenance (Hugging Face Space page residue): mdAmin313, "Update app.py", commit 27c3239 (verified).
import os
import io
import math
import time
import json
import requests
import datetime
import numpy as np
import pandas as pd
from collections import defaultdict
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import folium
from folium.plugins import MarkerCluster
# Gemini / Google GenAI SDK imports
from google import genai
from google.genai import types
# ---------------------
# Config / hyperparams
# ---------------------
# Prefer GPU when available; all tensors/models are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEQ_LEN = 4 # Number of previous points used as input
BATCH_SIZE = 32
EPOCHS = 8 # Small for demo; increase for real training
LR = 1e-3  # Adam learning rate
# ---------------------
# API keys / models
# ---------------------
# Both keys are optional: missing keys switch the app to degraded/fallback paths.
GOOGLE_GEOCODING_API_KEY = os.environ.get("GOOGLE_GEOCODING_API_KEY")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GEMINI_MODEL = "gemini-2.0-flash" # Adjust as per your access
# Initialize Gemini client
# client stays None without a key; llm_explain checks this before calling Gemini.
client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
# ---------------------
# Utilities
# ---------------------
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Return the great-circle distance in meters between two lat/lon points."""
    earth_radius_m = 6371000.0
    rlat1 = math.radians(lat1)
    rlat2 = math.radians(lat2)
    half_dlat = math.radians(lat2 - lat1) / 2.0
    half_dlon = math.radians(lon2 - lon1) / 2.0
    # Standard haversine term: sin^2 of half-latitude delta plus the
    # longitude contribution scaled by both cosines.
    h = math.sin(half_dlat) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(half_dlon) ** 2
    return 2.0 * earth_radius_m * math.asin(math.sqrt(h))
def coords_to_feature(seq: list[tuple[float, float, str]]) -> np.ndarray:
    """Convert (lat, lon, iso-timestamp) points into a per-step feature matrix.

    Each row is [lat, lon, dt_seconds, dlat, dlon, speed_m_per_s]; the first
    row's delta features are zero since it has no predecessor.
    """
    rows = []
    prev = None  # (lat, lon, parsed datetime) of previous point
    for lat, lon, tstr in seq:
        stamp = datetime.datetime.fromisoformat(tstr)
        if prev is None:
            dt = dlat = dlon = speed = 0.0
        else:
            prev_lat, prev_lon, prev_stamp = prev
            dt = (stamp - prev_stamp).total_seconds()
            dlat = lat - prev_lat
            dlon = lon - prev_lon
            # Epsilon guards against zero/near-zero time deltas.
            speed = haversine_distance(prev_lat, prev_lon, lat, lon) / (dt + 1e-9)
        rows.append([lat, lon, dt, dlat, dlon, speed])
        prev = (lat, lon, stamp)
    return np.array(rows, dtype=np.float32)
# ---------------------
# Dataset
# ---------------------
class TrajDataset(Dataset):
    """Sliding-window trajectory dataset.

    Each sample is a (seq_len, 6) feature window and a 2-vector target
    holding the (dlat, dlon) step from the window's last point to the
    following point of the trace.
    """

    def __init__(self, traces: list[list[tuple[float, float, str]]], seq_len: int = SEQ_LEN):
        xs = []
        ys = []
        for trace in traces:
            # A trace must supply seq_len inputs plus one target point;
            # shorter traces contribute zero windows.
            for start in range(max(0, len(trace) - seq_len)):
                window = trace[start:start + seq_len + 1]
                target_lat, target_lon, _ = window[-1]
                anchor_lat, anchor_lon, _ = window[-2]
                xs.append(coords_to_feature(window[:-1]))
                ys.append([target_lat - anchor_lat, target_lon - anchor_lon])
        if xs:
            self.X = np.stack(xs, axis=0)
            self.Y = np.array(ys, dtype=np.float32)
        else:
            # Empty-but-correctly-shaped arrays keep DataLoader code happy.
            self.X = np.zeros((0, seq_len, 6), dtype=np.float32)
            self.Y = np.zeros((0, 2), dtype=np.float32)

    def __len__(self) -> int:
        return self.Y.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]
# ---------------------
# LSTM Forecast Model
# ---------------------
class LSTMForecast(nn.Module):
    """LSTM encoder over step features with a small MLP head.

    Consumes (batch, seq, in_dim) feature windows and predicts an
    out_dim offset vector from the final hidden step.
    """

    def __init__(self, in_dim=6, hid_dim=64, num_layers=2, out_dim=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        head_width = hid_dim // 2
        self.fc = nn.Sequential(
            nn.Linear(hid_dim, head_width),
            nn.ReLU(),
            nn.Linear(head_width, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, seq, in_dim) -> (batch, out_dim) via the last step."""
        sequence_out, _ = self.lstm(x)
        final_step = sequence_out[:, -1, :]
        return self.fc(final_step)
# ---------------------
# Training / loading
# ---------------------
MODEL_PATH = "traj_lstm_gemini.pt"
def train_model_on_traces(traces: list[list[tuple[float, float, str]]], epochs: int = EPOCHS) -> LSTMForecast | None:
    """Train an LSTMForecast on the given traces and checkpoint it.

    Returns the trained model (weights also saved to MODEL_PATH), or None
    when the traces produce no training windows.
    """
    dataset = TrajDataset(traces)
    n_samples = len(dataset)
    if n_samples == 0:
        print("Not enough data to train.")
        return None
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    model = LSTMForecast().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for features, targets in loader:
            features = features.to(DEVICE)
            targets = targets.to(DEVICE)
            loss = loss_fn(model(features), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample.
            running_loss += loss.item() * features.size(0)
        print(f"[Train] Epoch {epoch + 1}/{epochs}, Loss: {running_loss / n_samples:.6f}")
    torch.save(model.state_dict(), MODEL_PATH)
    return model
def load_model_if_exists() -> LSTMForecast | None:
    """Load the trained model checkpoint from MODEL_PATH.

    Returns:
        The loaded model in eval mode, or None when no checkpoint exists or
        it cannot be deserialized.

    Bug fix: the previous version always returned a model instance — a
    randomly initialized one when the checkpoint was missing or corrupt —
    which made the caller's `if model is None` guard dead code and allowed
    silent predictions from untrained weights.
    """
    if not os.path.exists(MODEL_PATH):
        return None
    model = LSTMForecast().to(DEVICE)
    try:
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    except Exception as e:
        # An unloadable checkpoint is as useless as a missing one.
        print(f"Error loading model: {e}")
        return None
    model.eval()
    print(f"Loaded model from {MODEL_PATH}")
    return model
# ---------------------
# Reverse geocoding (Google Maps API)
# ---------------------
def reverse_geocode(lat: float, lon: float) -> dict:
    """Resolve a lat/lon pair to address components via Google's Geocoding API.

    Returns a dict with country/state/city/street/formatted keys on success;
    any failure (no key, no results, network/parse error) yields a dict
    containing an "error" key instead.
    """
    if not GOOGLE_GEOCODING_API_KEY:
        return {"error": "No Google Geocoding API key provided"}
    try:
        resp = requests.get(
            "https://maps.googleapis.com/maps/api/geocode/json",
            params={"latlng": f"{lat},{lon}", "key": GOOGLE_GEOCODING_API_KEY},
            timeout=10,
        )
        j = resp.json()
        if j.get("status") != "OK" or not j.get("results"):
            return {"error": "no-results", "raw": j}
        first = j["results"][0]
        result = {
            "country": None, "state": None, "city": None,
            "street": None, "formatted": first.get("formatted_address"),
        }
        for component in first["address_components"]:
            kinds = component.get("types", [])
            name = component.get("long_name")
            if "country" in kinds:
                result["country"] = name
            if "administrative_area_level_1" in kinds:
                result["state"] = name
            if "locality" in kinds or "postal_town" in kinds:
                result["city"] = name
            if "route" in kinds or "street_address" in kinds or "premise" in kinds:
                # Accumulate multiple street-level components, comma-joined.
                result["street"] = (result["street"] + ", " + name) if result["street"] else name
        return result
    except Exception as e:
        return {"error": str(e)}
# ---------------------
# Gemini LLM explanation
# ---------------------
def llm_explain(predictions: list[dict], recent_seq_summary: str) -> str:
    """Explain prediction candidates in natural language.

    Uses the Gemini client when configured; otherwise builds a plain-text
    fallback summary locally. Never raises — Gemini failures degrade to a
    JSON dump of the top candidates.
    """
    if not GEMINI_API_KEY or not client:
        # Fallback explanation without Gemini
        top = predictions[0]
        pieces = [
            f"Top predicted location: ({top['lat']:.5f}, {top['lon']:.5f}) with probability {top['prob']:.2f}.\n",
            "Other candidates:\n",
        ]
        for cand in predictions[1:]:
            pieces.append(f" - ({cand['lat']:.5f}, {cand['lon']:.5f}) prob {cand['prob']:.2f}\n")
        pieces.append(f"Recent trace summary: {recent_seq_summary}\n")
        pieces.append("Note: no Gemini key provided; used fallback explanation.")
        return "".join(pieces)
    try:
        prompt_parts = [
            "You are an assistant that explains location prediction results.\n",
            f"Recent trace summary: {recent_seq_summary}\n",
            "Model predictions (lat, lon, prob):\n",
        ]
        for cand in predictions:
            prompt_parts.append(f"- {cand['lat']:.6f}, {cand['lon']:.6f}, prob={cand['prob']:.3f}\n")
        prompt_parts.append(
            "\nPlease produce a concise explanation (2‑4 sentences) that: "
            "(1) states the top predicted place, (2) gives reasons based on recent movement and time, "
            "(3) mentions uncertainty.\n"
        )
        resp = client.models.generate_content(
            model=GEMINI_MODEL,
            contents="".join(prompt_parts),
            config=types.GenerateContentConfig(
                temperature=0.2,
                max_output_tokens=200,
            ),
        )
        return resp.text.strip()
    except Exception as e:
        return f"(Gemini explanation failed: {e})\nFallback: " + json.dumps(predictions[:3], default=str)
# ---------------------
# Prediction pipeline
# ---------------------
def predict_next_from_history(model: LSTMForecast, recent_seq: list[tuple[float, float, str]], top_k: int = 3) -> list[dict]:
    """Predict candidate next locations from the trailing SEQ_LEN points.

    Runs the model once to get the most likely (dlat, dlon) offset, then
    synthesizes two nearby heuristic alternates with lower weight. The three
    probabilities are renormalized to sum to 1 before slicing to top_k.
    """
    model.eval()
    window = recent_seq[-SEQ_LEN:]
    features = coords_to_feature(window)
    batch = torch.tensor(features[None, :, :], dtype=torch.float32).to(DEVICE)
    with torch.no_grad():
        delta = model(batch).cpu().numpy()[0]
    anchor_lat, anchor_lon, _ = window[-1]
    best_lat = anchor_lat + float(delta[0])
    best_lon = anchor_lon + float(delta[1])
    # Candidate set: model output plus two fixed-offset alternates (demo-style
    # stand-in for a real multi-modal predictor).
    candidates = [
        {"lat": best_lat, "lon": best_lon, "prob": 0.6},
        {"lat": best_lat + 0.01, "lon": best_lon - 0.005, "prob": 0.25},
        {"lat": best_lat - 0.01, "lon": best_lon + 0.006, "prob": 0.15},
    ]
    mass = sum(c["prob"] for c in candidates)
    for c in candidates:
        c["prob"] /= mass
    return candidates[:top_k]
# ---------------------
# Map visualization
# ---------------------
def make_map_html(last_lat: float, last_lon: float, predictions: list[dict]) -> str:
    """Render a folium map as embeddable HTML: a red "Last Seen" marker plus
    clustered blue markers for each prediction candidate."""
    fmap = folium.Map(location=[last_lat, last_lon], zoom_start=8, tiles="OpenStreetMap")
    folium.Marker(
        [last_lat, last_lon],
        popup="Last Seen",
        icon=folium.Icon(color="red"),
    ).add_to(fmap)
    cluster = MarkerCluster().add_to(fmap)
    for cand in predictions:
        marker = folium.Marker(
            [cand["lat"], cand["lon"]],
            popup=f"prob={cand['prob']:.2f}",
            icon=folium.Icon(color="blue"),
        )
        marker.add_to(cluster)
    return fmap._repr_html_()
# ---------------------
# Input parsing / helpers
# ---------------------
# In-memory trace for the current UI session: list of (lat, lon, iso-timestamp).
SESSION_POINTS: list[tuple[float, float, str]] = []

def add_point(date: str, time_inp: str, loc_text: str) -> str:
    """Parse and append one point to the session trace.

    Args:
        date: "YYYY-MM-DD" date string.
        time_inp: "HH:MM:SS" time string.
        loc_text: "lat, lon" comma-separated coordinate pair.

    Returns:
        A status string starting with "✅" on success or "❌" on parse failure.

    Bug fix: the timestamp is now validated with fromisoformat up front;
    previously a malformed date/time was appended unchecked and only crashed
    later inside coords_to_feature.
    """
    global SESSION_POINTS
    try:
        lat_str, lon_str = loc_text.split(",")
        lat = float(lat_str.strip())
        lon = float(lon_str.strip())
        timestamp = f"{date}T{time_inp}"
        # Reject malformed timestamps here instead of at prediction time.
        datetime.datetime.fromisoformat(timestamp)
        SESSION_POINTS.append((lat, lon, timestamp))
        return f"✅ Added point: ({lat}, {lon}) at {timestamp}\nTotal points: {len(SESSION_POINTS)}"
    except Exception as e:
        return f"❌ Error parsing input: {e}"
def parse_trace_text(text: str) -> list[tuple[float, float, str]]:
"""
Parse input text containing trajectory points.
Supports CSV with headers or whitespace/comma separated lines with lat, lon, timestamp in any order.
Returns list of (lat, lon, iso-timestamp) tuples.
"""
if not text or not text.strip():
return []
# Try reading as CSV first
try:
df = pd.read_csv(io.StringIO(text.strip()))
# Try to infer columns
lat_col = None
lon_col = None
time_col = None
for col in df.columns:
c = col.lower()
if c in ("lat", "latitude"):
lat_col = col
elif c in ("lon", "lng", "longitude"):
lon_col = col
elif c in ("time", "timestamp", "datetime", "date", "iso"):
time_col = col
if lat_col and lon_col and time_col:
# Convert timestamp to ISO format if needed
def to_iso(t):
try:
dt = pd.to_datetime(t)
return dt.isoformat()
except Exception:
return str(t)
return [
(float(row[lat_col]), float(row[lon_col]), to_iso(row[time_col]))
for _, row in df.iterrows()
]
except Exception:
pass
# Fallback: parse line by line with heuristic
lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
parsed = []
for ln in lines:
parts = [p.strip() for p in ln.replace(",", " ").split()]
if len(parts) < 3:
continue
# Attempt lat lon timestamp
try:
lat = float(parts[0])
lon = float(parts[1])
ts = parts[2]
# Validate timestamp format loosely
datetime.datetime.fromisoformat(ts)
parsed.append((lat, lon, ts))
continue
except Exception:
pass
# Attempt timestamp lat lon
try:
ts = parts[0]
datetime.datetime.fromisoformat(ts)
lat = float(parts[1])
lon = float(parts[2])
parsed.append((lat, lon, ts))
continue
except Exception:
pass
# Attempt lat lon with date and time separate
if len(parts) >= 4:
try:
lat = float(parts[0])
lon = float(parts[1])
date = parts[2]
time_part = parts[3]
ts = f"{date}T{time_part}"
datetime.datetime.fromisoformat(ts)
parsed.append((lat, lon, ts))
continue
except Exception:
pass
return parsed
# ---------------------
# Example: clear session points
# ---------------------
def clear_session() -> str:
    """Drop every point accumulated in the current session and confirm."""
    global SESSION_POINTS
    SESSION_POINTS = []
    return "Session cleared."
# ---------------------
# Gradio UI definition (simplified)
# ---------------------
def main_ui():
    """Build and return the Gradio Blocks UI.

    Wires four actions: add a single point to the session, train a model on
    a pasted trace, predict the next location (explanation + map), and clear
    the session.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Trajectory Prediction with LSTM and Gemini Explanation")
        date_input = gr.Textbox(label="Date (YYYY-MM-DD)", value=datetime.date.today().isoformat())
        time_input = gr.Textbox(label="Time (HH:MM:SS)", value="12:00:00")
        loc_input = gr.Textbox(label="Location (lat, lon)", placeholder="e.g. 37.7749, -122.4194")
        add_btn = gr.Button("Add Point")
        add_status = gr.Textbox(label="Status", interactive=False)
        trace_text = gr.TextArea(label="Paste Trace (lat, lon, timestamp per line or CSV)")
        train_btn = gr.Button("Train Model on Trace")
        train_status = gr.Textbox(label="Training Status", interactive=False)
        predict_btn = gr.Button("Predict Next Location")
        predict_output = gr.Textbox(label="Prediction Explanation", interactive=False)
        map_output = gr.HTML(label="Prediction Map")
        clear_btn = gr.Button("Clear Session")
        point_counter = gr.Textbox(label="Current # of Points", interactive=False)

        def on_add_point(date, time_inp, loc_text):
            # Append one manual point; echo status and running point count.
            msg = add_point(date, time_inp, loc_text)
            return msg, f"{len(SESSION_POINTS)}"

        def on_train(trace_str):
            global SESSION_POINTS
            traces = [parse_trace_text(trace_str)]
            # Keep the parsed trace as the active session only when it is
            # long enough to support a prediction window.
            SESSION_POINTS = traces[0] if traces and len(traces[0]) >= SEQ_LEN else []
            model = train_model_on_traces(traces)
            if model:
                return "Training completed."
            else:
                return "Training failed or insufficient data."

        def on_predict():
            if not SESSION_POINTS or len(SESSION_POINTS) < SEQ_LEN:
                return "Not enough points for prediction.", ""
            # Bug fix: require a saved checkpoint before predicting. Without
            # this guard, load_model_if_exists could hand back a freshly
            # initialized model when no checkpoint was on disk, making the
            # None check below unreachable and silently predicting from
            # untrained weights.
            if not os.path.exists(MODEL_PATH):
                return "Model not found. Train first.", ""
            model = load_model_if_exists()
            if model is None:
                return "Model not found. Train first.", ""
            preds = predict_next_from_history(model, SESSION_POINTS)
            explanation = llm_explain(preds, f"{len(SESSION_POINTS)} points ending at {SESSION_POINTS[-1][2]}")
            last_lat, last_lon, _ = SESSION_POINTS[-1]
            map_html = make_map_html(last_lat, last_lon, preds)
            return explanation, map_html

        add_btn.click(on_add_point, inputs=[date_input, time_input, loc_input], outputs=[add_status, point_counter])
        train_btn.click(on_train, inputs=[trace_text], outputs=train_status)
        predict_btn.click(on_predict, outputs=[predict_output, map_output])
        clear_btn.click(lambda: (clear_session(), "0"), outputs=[add_status, point_counter])
    return demo
# Launch the Gradio app only when executed directly (not on import).
if __name__ == "__main__":
    main_ui().launch()