import os import io import math import time import json import requests import datetime import numpy as np import pandas as pd from collections import defaultdict import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader import gradio as gr import folium from folium.plugins import MarkerCluster # Gemini / Google GenAI SDK imports from google import genai from google.genai import types # --------------------- # Config / hyperparams # --------------------- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") SEQ_LEN = 4 # Number of previous points used as input BATCH_SIZE = 32 EPOCHS = 8 # Small for demo; increase for real training LR = 1e-3 # --------------------- # API keys / models # --------------------- GOOGLE_GEOCODING_API_KEY = os.environ.get("GOOGLE_GEOCODING_API_KEY") GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") GEMINI_MODEL = "gemini-2.0-flash" # Adjust as per your access # Initialize Gemini client client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None # --------------------- # Utilities # --------------------- def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float: """Calculate haversine distance in meters between two lat/lon points.""" R = 6371000.0 # Earth radius in meters phi1, phi2 = math.radians(lat1), math.radians(lat2) dphi = math.radians(lat2 - lat1) dlambda = math.radians(lon2 - lon1) a = math.sin(dphi / 2.0) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2.0) ** 2 return 2 * R * math.asin(math.sqrt(a)) def coords_to_feature(seq: list[tuple[float, float, str]]) -> np.ndarray: """Convert a sequence of (lat, lon, iso-timestamp) to feature array.""" feats = [] for i, (lat, lon, tstr) in enumerate(seq): t = datetime.datetime.fromisoformat(tstr) if i == 0: dt = dlat = dlon = speed = 0.0 else: lat0, lon0, tstr0 = seq[i - 1] t0 = datetime.datetime.fromisoformat(tstr0) dt = (t - t0).total_seconds() dlat = lat - lat0 dlon = lon - lon0 dist = haversine_distance(lat0, lon0, lat, lon) speed = dist / (dt + 1e-9) feats.append([lat, lon, dt, dlat, dlon, speed]) return np.array(feats, dtype=np.float32) # --------------------- # Dataset # --------------------- class TrajDataset(Dataset): """Trajectory Dataset for training.""" def __init__(self, traces: list[list[tuple[float, float, str]]], seq_len: int = SEQ_LEN): self.X = [] self.Y = [] for seq in traces: if len(seq) < seq_len + 1: continue for i in range(len(seq) - seq_len): window = seq[i:i + seq_len + 1] inp = coords_to_feature(window[:-1]) tgt_lat, tgt_lon, _ = window[-1] last_lat, last_lon, _ = window[-2] dlat = tgt_lat - last_lat dlon = tgt_lon - last_lon self.X.append(inp) self.Y.append([dlat, dlon]) if self.X: self.X = np.stack(self.X, axis=0) self.Y = np.array(self.Y, dtype=np.float32) else: self.X = np.zeros((0, seq_len, 6), dtype=np.float32) self.Y = np.zeros((0, 2), dtype=np.float32) def __len__(self): return len(self.Y) def __getitem__(self, idx): return self.X[idx], self.Y[idx] # --------------------- # LSTM Forecast Model # --------------------- class LSTMForecast(nn.Module): def __init__(self, in_dim=6, hid_dim=64, num_layers=2, out_dim=2): super().__init__() self.lstm = nn.LSTM(input_size=in_dim, hidden_size=hid_dim, num_layers=num_layers, batch_first=True) self.fc = nn.Sequential( nn.Linear(hid_dim, hid_dim // 2), nn.ReLU(), nn.Linear(hid_dim // 2, out_dim) ) def forward(self, x: torch.Tensor) -> torch.Tensor: out, _ = self.lstm(x) hT = out[:, -1, :] return self.fc(hT) # --------------------- # Training / loading # --------------------- MODEL_PATH = "traj_lstm_gemini.pt" def train_model_on_traces(traces: list[list[tuple[float, float, str]]], epochs: int = EPOCHS) -> LSTMForecast | None: dataset = TrajDataset(traces) if len(dataset) == 0: print("Not enough data to train.") return None loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) model = LSTMForecast().to(DEVICE) optimizer = torch.optim.Adam(model.parameters(), lr=LR) loss_fn = nn.MSELoss() for epoch in range(epochs): model.train() total_loss = 0.0 for xb, yb in loader: xb = xb.to(DEVICE) yb = yb.to(DEVICE) pred = model(xb) loss = loss_fn(pred, yb) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() * xb.size(0) print(f"[Train] Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataset):.6f}") torch.save(model.state_dict(), MODEL_PATH) return model def load_model_if_exists() -> LSTMForecast: model = LSTMForecast().to(DEVICE) if os.path.exists(MODEL_PATH): try: model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE)) model.eval() print(f"Loaded model from {MODEL_PATH}") except Exception as e: print(f"Error loading model: {e}") return model # --------------------- # Reverse geocoding (Google Maps API) # --------------------- def reverse_geocode(lat: float, lon: float) -> dict: if not GOOGLE_GEOCODING_API_KEY: return {"error": "No Google Geocoding API key provided"} try: url = "https://maps.googleapis.com/maps/api/geocode/json" params = {"latlng": f"{lat},{lon}", "key": GOOGLE_GEOCODING_API_KEY} r = requests.get(url, params=params, timeout=10) j = r.json() if j.get("status") == "OK" and j.get("results"): comp = j["results"][0]["address_components"] result = { "country": None, "state": None, "city": None, "street": None, "formatted": j["results"][0].get("formatted_address") } for c in comp: types_c = c.get("types", []) if "country" in types_c: result["country"] = c.get("long_name") if "administrative_area_level_1" in types_c: result["state"] = c.get("long_name") if "locality" in types_c or "postal_town" in types_c: result["city"] = c.get("long_name") if "route" in types_c or "street_address" in types_c or "premise" in types_c: if result["street"]: result["street"] += ", " + c.get("long_name") else: result["street"] = c.get("long_name") return result else: return {"error": "no-results", "raw": j} except Exception as e: return {"error": str(e)} # --------------------- # Gemini LLM explanation # --------------------- def llm_explain(predictions: list[dict], recent_seq_summary: str) -> str: if not GEMINI_API_KEY or not client: # Fallback explanation without Gemini top = predictions[0] s = f"Top predicted location: ({top['lat']:.5f}, {top['lon']:.5f}) with probability {top['prob']:.2f}.\n" s += "Other candidates:\n" for p in predictions[1:]: s += f" - ({p['lat']:.5f}, {p['lon']:.5f}) prob {p['prob']:.2f}\n" s += f"Recent trace summary: {recent_seq_summary}\n" s += "Note: no Gemini key provided; used fallback explanation." return s try: prompt = ( "You are an assistant that explains location prediction results.\n" f"Recent trace summary: {recent_seq_summary}\n" "Model predictions (lat, lon, prob):\n" ) for p in predictions: prompt += f"- {p['lat']:.6f}, {p['lon']:.6f}, prob={p['prob']:.3f}\n" prompt += ( "\nPlease produce a concise explanation (2‑4 sentences) that: " "(1) states the top predicted place, (2) gives reasons based on recent movement and time, " "(3) mentions uncertainty.\n" ) resp = client.models.generate_content( model=GEMINI_MODEL, contents=prompt, config=types.GenerateContentConfig( temperature=0.2, max_output_tokens=200 ) ) explanation = resp.text.strip() return explanation except Exception as e: return f"(Gemini explanation failed: {e})\nFallback: " + json.dumps(predictions[:3], default=str) # --------------------- # Prediction pipeline # --------------------- def predict_next_from_history(model: LSTMForecast, recent_seq: list[tuple[float, float, str]], top_k: int = 3) -> list[dict]: model.eval() window = recent_seq[-SEQ_LEN:] inp = coords_to_feature(window) x = torch.tensor(inp[None, :, :], dtype=torch.float32).to(DEVICE) with torch.no_grad(): out = model(x).cpu().numpy()[0] last_lat, last_lon, _ = window[-1] pred_lat = last_lat + float(out[0]) pred_lon = last_lon + float(out[1]) # Generate candidate predictions with probabilities candidates = [ {"lat": pred_lat, "lon": pred_lon, "prob": 0.6}, {"lat": pred_lat + 0.01, "lon": pred_lon - 0.005, "prob": 0.25}, {"lat": pred_lat - 0.01, "lon": pred_lon + 0.006, "prob": 0.15}, ] total_prob = sum(c["prob"] for c in candidates) for c in candidates: c["prob"] /= total_prob return candidates[:top_k] # --------------------- # Map visualization # --------------------- def make_map_html(last_lat: float, last_lon: float, predictions: list[dict]) -> str: m = folium.Map(location=[last_lat, last_lon], zoom_start=8, tiles="OpenStreetMap") folium.Marker([last_lat, last_lon], popup="Last Seen", icon=folium.Icon(color="red")).add_to(m) cluster = MarkerCluster().add_to(m) for p in predictions: folium.Marker([p["lat"], p["lon"]], popup=f"prob={p['prob']:.2f}", icon=folium.Icon(color="blue")).add_to(cluster) return m._repr_html_() # --------------------- # Input parsing / helpers # --------------------- SESSION_POINTS = [] def add_point(date: str, time_inp: str, loc_text: str) -> str: global SESSION_POINTS try: lat_str, lon_str = loc_text.split(",") lat = float(lat_str.strip()) lon = float(lon_str.strip()) timestamp = f"{date}T{time_inp}" SESSION_POINTS.append((lat, lon, timestamp)) return f"✅ Added point: ({lat}, {lon}) at {timestamp}\nTotal points: {len(SESSION_POINTS)}" except Exception as e: return f"❌ Error parsing input: {e}" def parse_trace_text(text: str) -> list[tuple[float, float, str]]: """ Parse input text containing trajectory points. Supports CSV with headers or whitespace/comma separated lines with lat, lon, timestamp in any order. Returns list of (lat, lon, iso-timestamp) tuples. """ if not text or not text.strip(): return [] # Try reading as CSV first try: df = pd.read_csv(io.StringIO(text.strip())) # Try to infer columns lat_col = None lon_col = None time_col = None for col in df.columns: c = col.lower() if c in ("lat", "latitude"): lat_col = col elif c in ("lon", "lng", "longitude"): lon_col = col elif c in ("time", "timestamp", "datetime", "date", "iso"): time_col = col if lat_col and lon_col and time_col: # Convert timestamp to ISO format if needed def to_iso(t): try: dt = pd.to_datetime(t) return dt.isoformat() except Exception: return str(t) return [ (float(row[lat_col]), float(row[lon_col]), to_iso(row[time_col])) for _, row in df.iterrows() ] except Exception: pass # Fallback: parse line by line with heuristic lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()] parsed = [] for ln in lines: parts = [p.strip() for p in ln.replace(",", " ").split()] if len(parts) < 3: continue # Attempt lat lon timestamp try: lat = float(parts[0]) lon = float(parts[1]) ts = parts[2] # Validate timestamp format loosely datetime.datetime.fromisoformat(ts) parsed.append((lat, lon, ts)) continue except Exception: pass # Attempt timestamp lat lon try: ts = parts[0] datetime.datetime.fromisoformat(ts) lat = float(parts[1]) lon = float(parts[2]) parsed.append((lat, lon, ts)) continue except Exception: pass # Attempt lat lon with date and time separate if len(parts) >= 4: try: lat = float(parts[0]) lon = float(parts[1]) date = parts[2] time_part = parts[3] ts = f"{date}T{time_part}" datetime.datetime.fromisoformat(ts) parsed.append((lat, lon, ts)) continue except Exception: pass return parsed # --------------------- # Example: clear session points # --------------------- def clear_session() -> str: global SESSION_POINTS SESSION_POINTS = [] return "Session cleared." # --------------------- # Gradio UI definition (simplified) # --------------------- def main_ui(): with gr.Blocks() as demo: gr.Markdown("# Trajectory Prediction with LSTM and Gemini Explanation") date_input = gr.Textbox(label="Date (YYYY-MM-DD)", value=datetime.date.today().isoformat()) time_input = gr.Textbox(label="Time (HH:MM:SS)", value="12:00:00") loc_input = gr.Textbox(label="Location (lat, lon)", placeholder="e.g. 37.7749, -122.4194") add_btn = gr.Button("Add Point") add_status = gr.Textbox(label="Status", interactive=False) trace_text = gr.TextArea(label="Paste Trace (lat, lon, timestamp per line or CSV)") train_btn = gr.Button("Train Model on Trace") train_status = gr.Textbox(label="Training Status", interactive=False) predict_btn = gr.Button("Predict Next Location") predict_output = gr.Textbox(label="Prediction Explanation", interactive=False) map_output = gr.HTML(label="Prediction Map") clear_btn = gr.Button("Clear Session") point_counter = gr.Textbox(label="Current # of Points", interactive=False) def on_add_point(date, time_inp, loc_text): msg = add_point(date, time_inp, loc_text) return msg, f"{len(SESSION_POINTS)}" def on_train(trace_str): global SESSION_POINTS traces = [parse_trace_text(trace_str)] SESSION_POINTS = traces[0] if traces and len(traces[0]) >= SEQ_LEN else [] model = train_model_on_traces(traces) if model: return "Training completed." else: return "Training failed or insufficient data." def on_predict(): if not SESSION_POINTS or len(SESSION_POINTS) < SEQ_LEN: return "Not enough points for prediction.", "" model = load_model_if_exists() if model is None: return "Model not found. Train first.", "" preds = predict_next_from_history(model, SESSION_POINTS) explanation = llm_explain(preds, f"{len(SESSION_POINTS)} points ending at {SESSION_POINTS[-1][2]}") last_lat, last_lon, _ = SESSION_POINTS[-1] map_html = make_map_html(last_lat, last_lon, preds) return explanation, map_html add_btn.click(on_add_point, inputs=[date_input, time_input, loc_input], outputs=[add_status, point_counter]) train_btn.click(on_train, inputs=[trace_text], outputs=train_status) predict_btn.click(on_predict, outputs=[predict_output, map_output]) clear_btn.click(lambda: (clear_session(), "0"), outputs=[add_status, point_counter]) return demo if __name__ == "__main__": main_ui().launch()