File size: 16,788 Bytes
a213697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64ee096
 
 
 
a213697
 
 
 
17f8ed4
a213697
1967594
a213697
 
64ee096
 
 
 
 
1967594
64ee096
1967594
 
64ee096
a213697
 
 
1967594
 
 
 
64ee096
 
1967594
64ee096
a213697
1967594
 
 
a213697
1967594
a213697
a4b107d
64ee096
a213697
1967594
a213697
 
 
 
 
a4b107d
a213697
a4b107d
a213697
1967594
a213697
a4b107d
a213697
 
1967594
 
 
a213697
 
 
 
 
 
64ee096
a4b107d
1967594
 
a213697
 
 
 
64ee096
 
 
 
 
 
a213697
 
 
 
 
 
 
1967594
a213697
64ee096
a213697
 
 
 
 
 
a4b107d
a213697
a4b107d
a213697
 
1967594
a4b107d
 
 
a213697
1967594
a213697
64ee096
a213697
64ee096
a213697
1967594
 
a213697
a4b107d
a213697
 
1967594
a213697
 
64ee096
a213697
1967594
a213697
 
 
 
 
 
 
 
64ee096
a4b107d
64ee096
a213697
1967594
 
a213697
 
 
1967594
 
a213697
 
 
 
 
1967594
a213697
1967594
a213697
 
1967594
a213697
64ee096
a213697
1967594
64ee096
 
1967594
a213697
64ee096
 
 
 
 
 
 
 
 
a213697
64ee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a213697
 
 
1967594
a213697
64ee096
a213697
1967594
 
 
a213697
 
 
 
 
 
64ee096
a213697
 
 
1967594
 
 
 
 
a213697
 
1967594
 
 
 
 
64ee096
 
 
 
 
 
 
 
a213697
64ee096
 
 
a213697
64ee096
a213697
1967594
a213697
 
 
1967594
a213697
 
a4b107d
 
a213697
a4b107d
1967594
a213697
 
 
 
1967594
 
 
 
 
 
 
64ee096
1967594
 
 
 
a213697
 
64ee096
a213697
1967594
a213697
 
 
 
 
 
 
1967594
a213697
64ee096
a213697
476ec61
 
1967594
 
476ec61
 
a4b107d
64ee096
 
 
476ec61
1967594
476ec61
 
 
1967594
 
 
 
 
 
 
a213697
1967594
 
a213697
 
1967594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a213697
1967594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4b107d
1967594
 
 
 
 
 
 
 
 
a213697
1967594
a213697
1967594
 
 
 
 
 
a213697
1967594
 
 
 
a213697
1967594
 
 
 
 
 
 
 
 
a213697
 
1967594
 
a213697
 
1967594
a213697
1967594
 
 
 
a4b107d
 
1967594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27c3239
 
1967594
27c3239
 
1967594
 
27c3239
1967594
27c3239
1967594
 
 
 
 
 
27c3239
1967594
 
 
 
 
 
 
 
 
 
 
 
27c3239
1967594
 
27c3239
1967594
 
 
a213697
 
1967594
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
import os
import io
import math
import time
import json
import requests
import datetime
import numpy as np
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import gradio as gr
import folium
from folium.plugins import MarkerCluster

# Gemini / Google GenAI SDK imports
from google import genai
from google.genai import types

# ---------------------
# Config / hyperparams
# ---------------------
# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEQ_LEN = 4               # Number of previous points used as input
BATCH_SIZE = 32
EPOCHS = 8                # Small for demo; increase for real training
LR = 1e-3

# ---------------------
# API keys / models
# ---------------------
# Both keys are optional: without the geocoding key reverse_geocode returns an
# error dict, and without the Gemini key llm_explain uses a template fallback.
GOOGLE_GEOCODING_API_KEY = os.environ.get("GOOGLE_GEOCODING_API_KEY")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
# Override via the GEMINI_MODEL environment variable to match your access;
# an unset or empty value keeps the default.
GEMINI_MODEL = os.environ.get("GEMINI_MODEL") or "gemini-2.0-flash"

# Initialize the Gemini client only when a key is configured; None signals
# "Gemini unavailable" to llm_explain.
client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None

# ---------------------
# Utilities
# ---------------------
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Return the great-circle distance in meters between two lat/lon points.

    Uses the haversine formula on a sphere of radius 6,371,000 m; inputs are
    decimal degrees.
    """
    R = 6371000.0  # Earth radius in meters
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlambda = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2.0) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2.0) ** 2
    # Rounding can leave `a` a hair above 1 (near-antipodal points) or below 0
    # (out-of-range latitudes); clamp so sqrt/asin never hit a domain error.
    a = min(1.0, max(0.0, a))
    return 2 * R * math.asin(math.sqrt(a))


def coords_to_feature(seq: list[tuple[float, float, str]]) -> np.ndarray:
    """Convert a sequence of (lat, lon, iso-timestamp) points to a feature matrix.

    Each row is ``[lat, lon, dt, dlat, dlon, speed]``. ``dt`` (seconds), ``dlat``
    and ``dlon`` are deltas from the previous point, and ``speed`` is the
    haversine distance divided by ``dt`` (m/s). All four are 0 for the first point.

    Returns:
        float32 array of shape ``(len(seq), 6)``; ``(0, 6)`` for an empty sequence.

    Raises:
        ValueError: if a timestamp is not accepted by ``datetime.fromisoformat``.
    """
    feats = []
    prev = None  # (lat, lon, parsed datetime) of the previous point
    for lat, lon, tstr in seq:
        t = datetime.datetime.fromisoformat(tstr)
        if prev is None:
            dt = dlat = dlon = speed = 0.0
        else:
            lat0, lon0, t0 = prev
            dt = (t - t0).total_seconds()
            dlat = lat - lat0
            dlon = lon - lon0
            dist = haversine_distance(lat0, lon0, lat, lon)
            # Duplicate or out-of-order timestamps (dt <= 0) would give huge or
            # negative speeds that swamp the other features; report 0 instead.
            speed = dist / dt if dt > 0 else 0.0
        feats.append([lat, lon, dt, dlat, dlon, speed])
        prev = (lat, lon, t)
    # reshape keeps the documented (n, 6) shape even when seq is empty.
    return np.array(feats, dtype=np.float32).reshape(-1, 6)


# ---------------------
# Dataset
# ---------------------
class TrajDataset(Dataset):
    """Sliding-window samples for next-step displacement regression.

    Every run of ``seq_len + 1`` consecutive points in a trace becomes one
    sample:
      - input: ``coords_to_feature`` of the first ``seq_len`` points;
      - target: the ``[dlat, dlon]`` step from the last input point to the
        point that follows it.
    Traces shorter than ``seq_len + 1`` points contribute nothing.

    Attributes:
        X: stacked feature windows, shape ``(num_samples, seq_len, 6)``.
        Y: float32 targets, shape ``(num_samples, 2)``.
    """

    def __init__(self, traces: list[list[tuple[float, float, str]]], seq_len: int = SEQ_LEN):
        windows_x = []
        windows_y = []
        for trace in traces:
            # Need seq_len inputs plus one target point.
            if len(trace) <= seq_len:
                continue
            for start in range(len(trace) - seq_len):
                window = trace[start:start + seq_len + 1]
                features = coords_to_feature(window[:-1])
                next_lat, next_lon, _ = window[-1]
                prev_lat, prev_lon, _ = window[-2]
                windows_x.append(features)
                windows_y.append([next_lat - prev_lat, next_lon - prev_lon])

        if not windows_x:
            # Keep well-formed (empty) arrays so len() and indexing still work.
            self.X = np.zeros((0, seq_len, 6), dtype=np.float32)
            self.Y = np.zeros((0, 2), dtype=np.float32)
        else:
            self.X = np.stack(windows_x, axis=0)
            self.Y = np.array(windows_y, dtype=np.float32)

    def __len__(self):
        return self.Y.shape[0]

    def __getitem__(self, idx):
        features, target = self.X[idx], self.Y[idx]
        return features, target


# ---------------------
# LSTM Forecast Model
# ---------------------
class LSTMForecast(nn.Module):
    """Stacked LSTM encoder with a two-layer MLP regression head.

    ``forward`` takes a batch-first tensor ``(batch, seq, in_dim)`` and returns
    ``(batch, out_dim)``, computed from the top LSTM layer's output at the
    final time step. In this file the two outputs are used as a (dlat, dlon)
    step.

    The attribute names ``lstm`` / ``fc`` and the layer order inside ``fc``
    define the ``state_dict`` keys of saved checkpoints; keep them stable.
    """

    def __init__(self, in_dim=6, hid_dim=64, num_layers=2, out_dim=2):
        super().__init__()
        bottleneck = hid_dim // 2
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        head_layers = [
            nn.Linear(hid_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, out_dim),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        per_step_out = self.lstm(x)[0]
        return self.fc(per_step_out[:, -1])


# ---------------------
# Training / loading
# ---------------------
# Checkpoint file, relative to the current working directory. It is written by
# train_model_on_traces (state_dict via torch.save) and read back by
# load_model_if_exists.
MODEL_PATH: str = "traj_lstm_gemini.pt"


def train_model_on_traces(traces: list[list[tuple[float, float, str]]], epochs: int = EPOCHS) -> LSTMForecast | None:
    """Fit a fresh LSTMForecast on ``traces`` and checkpoint it to MODEL_PATH.

    Steps:
      1. Build sliding-window samples with TrajDataset.
      2. Optimise an MSE loss on the (dlat, dlon) targets with Adam at LR, in
         shuffled mini-batches of BATCH_SIZE, for ``epochs`` passes.
      3. Print the per-sample mean loss after each epoch.
      4. Save the ``state_dict``.

    Returns:
        The trained model (still in training mode), or None when the traces
        yield no training windows.
    """
    dataset = TrajDataset(traces)
    sample_count = len(dataset)
    if not sample_count:
        print("Not enough data to train.")
        return None

    batches = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    net = LSTMForecast().to(DEVICE)
    opt = torch.optim.Adam(net.parameters(), lr=LR)
    criterion = nn.MSELoss()

    for epoch_no in range(1, epochs + 1):
        net.train()
        running = 0.0
        for features, targets in batches:
            features, targets = features.to(DEVICE), targets.to(DEVICE)
            batch_loss = criterion(net(features), targets)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            # Weight by batch size so a short final batch does not skew the mean.
            running += batch_loss.item() * features.size(0)
        print(f"[Train] Epoch {epoch_no}/{epochs}, Loss: {running / sample_count:.6f}")

    torch.save(net.state_dict(), MODEL_PATH)
    return net


def load_model_if_exists() -> LSTMForecast | None:
    """Load the checkpoint at MODEL_PATH into a fresh LSTMForecast.

    Returns:
        The model on DEVICE in eval mode, or None when no checkpoint exists or
        it fails to load. The UI's predict handler treats None as "train
        first". Handing back randomly initialised weights instead would
        silently produce meaningless predictions.
    """
    if not os.path.exists(MODEL_PATH):
        print(f"No model checkpoint found at {MODEL_PATH}")
        return None

    model = LSTMForecast().to(DEVICE)
    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # this app wrote. Newer torch versions support weights_only=True to
        # restrict loading to tensors.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    model.eval()
    print(f"Loaded model from {MODEL_PATH}")
    return model


# ---------------------
# Reverse geocoding (Google Maps API)
# ---------------------
def reverse_geocode(lat: float, lon: float) -> dict:
    """Look up an address for a coordinate with the Google Geocoding API.

    On success returns a dict with keys "country", "state", "city", "street"
    and "formatted", taken from the first result (any value may be None).
    Failures are reported in the returned dict rather than raised:
      - ``{"error": <message>}`` for a missing API key or any exception;
      - ``{"error": "no-results", "raw": <response json>}`` when the API
        status is not "OK" or the result list is empty.
    """
    if not GOOGLE_GEOCODING_API_KEY:
        return {"error": "No Google Geocoding API key provided"}

    # (field, component types that populate it); a later match overwrites an earlier one.
    single_value_fields = (
        ("country", ("country",)),
        ("state", ("administrative_area_level_1",)),
        ("city", ("locality", "postal_town")),
    )
    street_types = ("route", "street_address", "premise")

    try:
        response = requests.get(
            "https://maps.googleapis.com/maps/api/geocode/json",
            params={"latlng": f"{lat},{lon}", "key": GOOGLE_GEOCODING_API_KEY},
            timeout=10,
        )
        payload = response.json()
        if payload.get("status") != "OK" or not payload.get("results"):
            return {"error": "no-results", "raw": payload}

        best = payload["results"][0]
        components = best["address_components"]
        address = dict.fromkeys(("country", "state", "city", "street"))
        address["formatted"] = best.get("formatted_address")

        for component in components:
            kinds = component.get("types", [])
            name = component.get("long_name")
            for field, markers in single_value_fields:
                if any(marker in kinds for marker in markers):
                    address[field] = name
            if any(marker in kinds for marker in street_types):
                # Multiple street-level parts are joined with ", ".
                current = address["street"]
                address["street"] = current + ", " + name if current else name
        return address
    except Exception as e:
        return {"error": str(e)}


# ---------------------
# Gemini LLM explanation
# ---------------------
def _fallback_explanation(
    predictions: list[dict],
    recent_seq_summary: str,
    note: str = "no Gemini key provided; used fallback explanation.",
) -> str:
    """Render predictions as plain text without calling Gemini.

    Each prediction dict needs numeric "lat", "lon" and "prob"; the first
    entry is reported as the top candidate. An empty list is handled.
    """
    if predictions:
        top, *others = predictions
        lines = [
            f"Top predicted location: ({top['lat']:.5f}, {top['lon']:.5f}) with probability {top['prob']:.2f}.",
            "Other candidates:",
        ]
        lines.extend(f" - ({p['lat']:.5f}, {p['lon']:.5f}) prob {p['prob']:.2f}" for p in others)
    else:
        lines = ["No predicted locations to explain."]
    lines.append(f"Recent trace summary: {recent_seq_summary}")
    lines.append(f"Note: {note}")
    return "\n".join(lines)


def _build_explain_prompt(predictions: list[dict], recent_seq_summary: str) -> str:
    """Assemble the Gemini prompt: context, one line per candidate, then instructions."""
    prompt = (
        "You are an assistant that explains location prediction results.\n"
        f"Recent trace summary: {recent_seq_summary}\n"
        "Model predictions (lat, lon, prob):\n"
    )
    for p in predictions:
        prompt += f"- {p['lat']:.6f}, {p['lon']:.6f}, prob={p['prob']:.3f}\n"
    prompt += (
        "\nPlease produce a concise explanation (2‑4 sentences) that: "
        "(1) states the top predicted place, (2) gives reasons based on recent movement and time, "
        "(3) mentions uncertainty.\n"
    )
    return prompt


def llm_explain(predictions: list[dict], recent_seq_summary: str) -> str:
    """Explain location predictions in natural language.

    Uses Gemini (GEMINI_MODEL, temperature 0.2, at most 200 output tokens)
    when a key and client are configured; otherwise renders a template. The
    function never raises for Gemini failures.

    Args:
        predictions: Candidate dicts with "lat", "lon", "prob", best first.
        recent_seq_summary: Short description of the input trace.

    Returns:
        The explanation text, one of the template fallbacks, or
        "(Gemini explanation failed: ...)" followed by a JSON dump of the
        top three predictions.
    """
    if not GEMINI_API_KEY or not client:
        return _fallback_explanation(predictions, recent_seq_summary)
    if not predictions:
        # Nothing to explain; skip the API call.
        return _fallback_explanation(
            predictions, recent_seq_summary,
            note="no predictions were produced; Gemini was not called.",
        )

    try:
        resp = client.models.generate_content(
            model=GEMINI_MODEL,
            contents=_build_explain_prompt(predictions, recent_seq_summary),
            config=types.GenerateContentConfig(
                temperature=0.2,
                max_output_tokens=200
            )
        )
        # NOTE(review): the SDK may report no text (None) for blocked or empty
        # responses; treat that like an empty explanation.
        explanation = (resp.text or "").strip()
    except Exception as e:
        return f"(Gemini explanation failed: {e})\nFallback: " + json.dumps(predictions[:3], default=str)

    if not explanation:
        return _fallback_explanation(
            predictions, recent_seq_summary,
            note="Gemini returned an empty response; used fallback explanation.",
        )
    return explanation


# ---------------------
# Prediction pipeline
# ---------------------
def predict_next_from_history(model: LSTMForecast, recent_seq: list[tuple[float, float, str]], top_k: int = 3) -> list[dict]:
    """Predict the next location from the last SEQ_LEN points of ``recent_seq``.

    The model outputs a (dlat, dlon) step, which is added to the last point.

    NOTE(review): only the first candidate comes from the model. The other two
    are fixed offsets (about 0.01 degrees) around it, and the "prob" values are
    hard-coded weights 0.6 / 0.25 / 0.15 normalised to sum to 1. They are
    heuristic, not calibrated model probabilities.

    Returns:
        Up to ``top_k`` dicts with "lat", "lon", "prob", highest prob first.
    """
    model.eval()
    history = recent_seq[-SEQ_LEN:]
    batch = torch.tensor(coords_to_feature(history)[None, :, :], dtype=torch.float32).to(DEVICE)
    with torch.no_grad():
        step = model(batch).cpu().numpy()[0]

    anchor_lat, anchor_lon, _ = history[-1]
    center_lat = anchor_lat + float(step[0])
    center_lon = anchor_lon + float(step[1])

    # (lat offset, lon offset, unnormalised weight) for each candidate.
    spread = (
        (0.0, 0.0, 0.6),
        (0.01, -0.005, 0.25),
        (-0.01, 0.006, 0.15),
    )
    weight_sum = sum(weight for *_, weight in spread)
    candidates = [
        {"lat": center_lat + d_lat, "lon": center_lon + d_lon, "prob": weight / weight_sum}
        for d_lat, d_lon, weight in spread
    ]
    return candidates[:top_k]


# ---------------------
# Map visualization
# ---------------------
def make_map_html(last_lat: float, last_lon: float, predictions: list[dict]) -> str:
    """Render a folium map as an embeddable HTML string.

    The last observed point gets a red "Last Seen" marker. Each prediction
    gets a blue marker, placed in a MarkerCluster, whose popup shows its
    probability.
    """
    fmap = folium.Map(location=[last_lat, last_lon], zoom_start=8, tiles="OpenStreetMap")
    last_seen = folium.Marker([last_lat, last_lon], popup="Last Seen", icon=folium.Icon(color="red"))
    last_seen.add_to(fmap)

    group = MarkerCluster()
    group.add_to(fmap)
    for cand in predictions:
        marker = folium.Marker(
            [cand["lat"], cand["lon"]],
            popup=f"prob={cand['prob']:.2f}",
            icon=folium.Icon(color="blue"),
        )
        marker.add_to(group)

    return fmap._repr_html_()


# ---------------------
# Input parsing / helpers
# ---------------------
# (lat, lon, iso-timestamp) points for the running app, stored as one
# module-level list shared by all users of the process:
#   - appended to by add_point;
#   - replaced by the training handler and clear_session;
#   - read by the predict handler.
SESSION_POINTS: list[tuple[float, float, str]] = []


def add_point(date: str, time_inp: str, loc_text: str) -> str:
    """Validate one manually entered point and append it to SESSION_POINTS.

    Args:
        date: Date part, e.g. "2024-05-01".
        time_inp: Time part, e.g. "12:00:00".
        loc_text: "lat, lon" in decimal degrees.

    Returns:
        A status string starting with "✅" on success or "❌" on failure.
        Nothing is appended on failure.
    """
    global SESSION_POINTS
    try:
        lat_str, lon_str = loc_text.split(",")
        lat = float(lat_str.strip())
        lon = float(lon_str.strip())
        # The negated form also rejects NaN, whose comparisons are always False.
        if not (-90.0 <= lat <= 90.0 and -180.0 <= lon <= 180.0):
            raise ValueError(f"coordinates out of range: ({lat}, {lon})")
        timestamp = f"{date.strip()}T{time_inp.strip()}"
        # Validate now: coords_to_feature parses timestamps with
        # datetime.fromisoformat, so a malformed value stored here would
        # otherwise only surface as a crash during training or prediction.
        datetime.datetime.fromisoformat(timestamp)
        SESSION_POINTS.append((lat, lon, timestamp))
        return f"✅ Added point: ({lat}, {lon}) at {timestamp}\nTotal points: {len(SESSION_POINTS)}"
    except Exception as e:
        # UI boundary: surface any parse/validation problem as a status message.
        return f"❌ Error parsing input: {e}"

def parse_trace_text(text: str) -> list[tuple[float, float, str]]:
    """
    Parse input text containing trajectory points.
    Supports CSV with headers or whitespace/comma separated lines with lat, lon, timestamp in any order.
    Returns list of (lat, lon, iso-timestamp) tuples.
    """
    if not text or not text.strip():
        return []

    # Try reading as CSV first
    try:
        df = pd.read_csv(io.StringIO(text.strip()))
        # Try to infer columns
        lat_col = None
        lon_col = None
        time_col = None
        for col in df.columns:
            c = col.lower()
            if c in ("lat", "latitude"):
                lat_col = col
            elif c in ("lon", "lng", "longitude"):
                lon_col = col
            elif c in ("time", "timestamp", "datetime", "date", "iso"):
                time_col = col
        if lat_col and lon_col and time_col:
            # Convert timestamp to ISO format if needed
            def to_iso(t):
                try:
                    dt = pd.to_datetime(t)
                    return dt.isoformat()
                except Exception:
                    return str(t)
            return [
                (float(row[lat_col]), float(row[lon_col]), to_iso(row[time_col]))
                for _, row in df.iterrows()
            ]
    except Exception:
        pass

    # Fallback: parse line by line with heuristic
    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
    parsed = []
    for ln in lines:
        parts = [p.strip() for p in ln.replace(",", " ").split()]
        if len(parts) < 3:
            continue

        # Attempt lat lon timestamp
        try:
            lat = float(parts[0])
            lon = float(parts[1])
            ts = parts[2]
            # Validate timestamp format loosely
            datetime.datetime.fromisoformat(ts)
            parsed.append((lat, lon, ts))
            continue
        except Exception:
            pass

        # Attempt timestamp lat lon
        try:
            ts = parts[0]
            datetime.datetime.fromisoformat(ts)
            lat = float(parts[1])
            lon = float(parts[2])
            parsed.append((lat, lon, ts))
            continue
        except Exception:
            pass

        # Attempt lat lon with date and time separate
        if len(parts) >= 4:
            try:
                lat = float(parts[0])
                lon = float(parts[1])
                date = parts[2]
                time_part = parts[3]
                ts = f"{date}T{time_part}"
                datetime.datetime.fromisoformat(ts)
                parsed.append((lat, lon, ts))
                continue
            except Exception:
                pass

    return parsed


# ---------------------
# Example: clear session points
# ---------------------
def clear_session() -> str:
    """Discard every collected point by rebinding SESSION_POINTS to a new empty list.

    Returns the status message shown in the UI.
    """
    global SESSION_POINTS
    SESSION_POINTS = list()
    status = "Session cleared."
    return status


# ---------------------
# Gradio UI definition (simplified)
# ---------------------
def main_ui():
    """Build (but do not launch) the Gradio Blocks app.

    Four actions operate on the module-level SESSION_POINTS list:
      - add a single point;
      - train on a pasted trace, or on the points added so far when the
        trace box is empty;
      - predict the next location from the current points;
      - clear the points.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Trajectory Prediction with LSTM and Gemini Explanation")

        date_input = gr.Textbox(label="Date (YYYY-MM-DD)", value=datetime.date.today().isoformat())
        time_input = gr.Textbox(label="Time (HH:MM:SS)", value="12:00:00")
        loc_input = gr.Textbox(label="Location (lat, lon)", placeholder="e.g. 37.7749, -122.4194")

        add_btn = gr.Button("Add Point")
        add_status = gr.Textbox(label="Status", interactive=False)

        trace_text = gr.TextArea(label="Paste Trace (lat, lon, timestamp per line or CSV)")

        train_btn = gr.Button("Train Model on Trace")
        train_status = gr.Textbox(label="Training Status", interactive=False)

        predict_btn = gr.Button("Predict Next Location")
        predict_output = gr.Textbox(label="Prediction Explanation", interactive=False)
        map_output = gr.HTML(label="Prediction Map")

        clear_btn = gr.Button("Clear Session")
        point_counter = gr.Textbox(label="Current # of Points", interactive=False)

        def on_add_point(date, time_inp, loc_text):
            msg = add_point(date, time_inp, loc_text)
            return msg, f"{len(SESSION_POINTS)}"

        def on_train(trace_str):
            global SESSION_POINTS
            parsed = parse_trace_text(trace_str)
            if parsed:
                # A pasted trace replaces the current session points.
                SESSION_POINTS = parsed
            # With an empty trace box, train on the points entered via
            # "Add Point" instead of discarding them.
            model = train_model_on_traces([SESSION_POINTS])
            count = f"{len(SESSION_POINTS)}"
            if model is not None:
                return "Training completed.", count
            return "Training failed or insufficient data.", count

        def on_predict():
            if not SESSION_POINTS or len(SESSION_POINTS) < SEQ_LEN:
                return "Not enough points for prediction.", ""
            model = load_model_if_exists()
            if model is None:
                return "Model not found. Train first.", ""
            try:
                preds = predict_next_from_history(model, SESSION_POINTS)
                explanation = llm_explain(preds, f"{len(SESSION_POINTS)} points ending at {SESSION_POINTS[-1][2]}")
                last_lat, last_lon, _ = SESSION_POINTS[-1]
                map_html = make_map_html(last_lat, last_lon, preds)
            except Exception as e:
                # UI boundary: report bad data (e.g. unparseable timestamps)
                # instead of failing the event handler.
                return f"Prediction failed: {e}", ""
            return explanation, map_html

        add_btn.click(on_add_point, inputs=[date_input, time_input, loc_input], outputs=[add_status, point_counter])
        train_btn.click(on_train, inputs=[trace_text], outputs=[train_status, point_counter])
        predict_btn.click(on_predict, outputs=[predict_output, map_output])
        clear_btn.click(lambda: (clear_session(), "0"), outputs=[add_status, point_counter])

    return demo


if __name__ == "__main__":
    # Script entry point: build the Blocks app and serve it with Gradio's launch().
    app = main_ui()
    app.launch()