# HiCoTraj / app.py
# (residue from the hosting file view: author ginnyxxxxxxx, commit 64743fe "cot",
#  links "raw" / "history" / "blame", file size 32.6 kB)
import gradio as gr
import pandas as pd
import folium
import numpy as np
import os
import re
import json
# Directory holding this script (falls back to the CWD when __file__ is absent,
# e.g. when the code is pasted into an interactive session).
if "__file__" in globals():
    BASE = os.path.dirname(os.path.abspath(__file__))
else:
    BASE = os.getcwd()

# Input files, all expected under <BASE>/data/.
STAY_POINTS = os.path.join(BASE, "data", "stay_points_inference_sample.csv")
POI_PATH = os.path.join(BASE, "data", "poi_inference_sample.csv")
DEMO_PATH = os.path.join(BASE, "data", "demographics_inference_sample.csv")
COT_PATH = os.path.join(BASE, "data", "inference_results_sample.json")

# Code -> label lookups for the demographics columns (negative codes are
# non-response markers).
SEX_MAP = {
    1: "Male",
    2: "Female",
    -8: "Unknown",
    -7: "Prefer not to answer",
}
EDU_MAP = {
    1: "Less than HS",
    2: "HS Graduate/GED",
    3: "Some College/Associate",
    4: "Bachelor's Degree",
    5: "Graduate/Professional Degree",
    -1: "N/A",
    -7: "Prefer not to answer",
    -8: "Unknown",
}
INC_MAP = {
    1: "<$10,000",
    2: "$10,000–$14,999",
    3: "$15,000–$24,999",
    4: "$25,000–$34,999",
    5: "$35,000–$49,999",
    6: "$50,000–$74,999",
    7: "$75,000–$99,999",
    8: "$100,000–$124,999",
    9: "$125,000–$149,999",
    10: "$150,000–$199,999",
    11: "$200,000+",
    -7: "Prefer not to answer",
    -8: "Unknown",
    -9: "Not ascertained",
}
RACE_MAP = {
    1: "White",
    2: "Black or African American",
    3: "Asian",
    4: "American Indian or Alaska Native",
    5: "Native Hawaiian or Other Pacific Islander",
    6: "Multiple races",
    97: "Other",
    -7: "Prefer not to answer",
    -8: "Unknown",
}
# Activity-type codes found in the stay-point "act_types" column.
ACT_MAP = {
    0: "Transportation",
    1: "Home",
    2: "Work",
    3: "School",
    4: "ChildCare",
    5: "BuyGoods",
    6: "Services",
    7: "EatOut",
    8: "Errands",
    9: "Recreation",
    10: "Exercise",
    11: "Visit",
    12: "HealthCare",
    13: "Religious",
    14: "SomethingElse",
    15: "DropOff",
}
print("Loading data...")
sp, poi, demo = (pd.read_csv(path) for path in (STAY_POINTS, POI_PATH, DEMO_PATH))

# Attach POI attributes to every stay point (left join: stay points whose
# poi_id has no POI row keep NaN in the POI columns).
# NOTE(review): presumably this is where `name`, `latitude` and `longitude` come from — confirm.
sp = sp.merge(poi, on="poi_id", how="left")

# Timestamps are parsed as timezone-aware UTC values.
sp = sp.assign(
    start_datetime=pd.to_datetime(sp["start_datetime"], utc=True),
    end_datetime=pd.to_datetime(sp["end_datetime"], utc=True),
)
# Stay length in minutes, rounded to one decimal place.
sp["duration_min"] = (
    sp["end_datetime"]
    .sub(sp["start_datetime"])
    .dt.total_seconds()
    .div(60)
    .round(1)
)
def parse_act_types(x, act_map=None):
    """Translate an encoded list of activity codes into readable labels.

    Accepts the textual list forms of the ``act_types`` column, e.g. ``"[1 2]"``
    (space separated) and also ``"[1, 2]"`` (comma separated). Each code is
    looked up in *act_map* (``ACT_MAP`` when omitted); codes missing from the
    map are kept as their number.

    Returns:
        The labels joined by ``", "`` (``""`` for an empty list), or ``str(x)``
        unchanged when any token is not an integer (e.g. NaN or malformed cells).
    """
    labels = ACT_MAP if act_map is None else act_map
    # Commas are treated as whitespace so both list spellings split the same way.
    tokens = str(x).strip("[]").replace(",", " ").split()
    try:
        codes = [int(tok) for tok in tokens]
    except ValueError:
        return str(x)
    return ", ".join(labels.get(code, str(code)) for code in codes)
# Human-readable activity labels for every stay point.
sp["act_label"] = sp["act_types"].apply(parse_act_types)

# Load CoT JSON (optional): agent_id -> full inference record.
cot_by_agent = {}
if os.path.exists(COT_PATH):
    # JSON text is UTF-8; say so explicitly so non-ASCII model output decodes
    # identically regardless of the platform's default locale encoding.
    with open(COT_PATH, "r", encoding="utf-8") as f:
        cot_raw = json.load(f)
    # Support both list and {"inference_results": [...]} formats
    records = cot_raw if isinstance(cot_raw, list) else cot_raw.get("inference_results", [])
    for result in records:
        cot_by_agent[int(result["agent_id"])] = result
    print(f"Loaded CoT for {len(cot_by_agent)} agents.")

# Every agent that has stay points, in ascending id order.
sample_agents = sorted(sp["agent_id"].unique().tolist())
print(f"Ready. {len(sample_agents)} agents loaded.")
# ── Mock CoT (fallback when agent not in JSON) ────────────────────────────────
# Placeholder responses for the three CoT stages. get_cot() falls back to these
# when an agent has no usable CoT record. The text is fixed illustrative output;
# it is not derived from any agent's data.

# Stage 1 (factual feature extraction): LOCATION INVENTORY / TEMPORAL PATTERNS /
# SEQUENCE headings, each followed by "- " bullets.
MOCK_S1 = """LOCATION INVENTORY:
- Top venues: residence (36 visits), Clinton Mobile Estates (9 visits), 7-Eleven (8 visits)
- Price level: budget (7-Eleven, car wash) and mid-range (Euro Caffe, Pepper Shaker Cafe)
- Neighborhood: residential and commercial urban mix
TEMPORAL PATTERNS:
- Active hours: 09:00-23:00
- Weekday/Weekend: 66% weekday, 34% weekend
- Routine: consistent morning start times
SEQUENCE:
- Typical chain: Home to Exercise/Work to Home
- Notable pattern: weekend religious visits every Sunday morning"""
# Stage 2 (behavioral pattern analysis): one upper-case "KEY: value" line per dimension.
MOCK_S2 = """SCHEDULE: Fixed weekday routine with flexible afternoon activities
ECONOMIC: Budget-conscious with occasional mid-range dining
SOCIAL: Community-engaged through regular religious attendance
LIFESTYLE: Urban working-class with active recreational habits
STABILITY: Highly consistent 4-week pattern with minimal deviation"""
# Stage 3 (demographic inference): INCOME_* key/value lines plus ALTERNATIVES.
MOCK_S3 = """INCOME_PREDICTION: Middle ($35k-$75k)
INCOME_CONFIDENCE: 4
INCOME_REASONING: Frequent budget venue visits (7-Eleven, self-service car wash) signal cost awareness, while occasional mid-range dining and stable employment-like patterns at Clinton Mobile Estates suggest a steady middle income. No luxury venue signals detected.
ALTERNATIVES: Low ($15k-$35k) | Upper-Middle ($75k-$125k)"""
def get_cot(agent_id):
    """Return the (stage1, stage2, stage3) response texts for *agent_id*.

    Each stage comes from the agent's record in ``cot_by_agent``
    (``step1_response`` / ``step2_response`` / ``step3_response``). A stage
    falls back to the matching ``MOCK_S*`` text when the agent has no record,
    the key is absent, or the stored value is empty/None. Without that last
    case, a null response would leave its card stuck on "Waiting..." even in
    the finished state.
    """
    record = cot_by_agent.get(agent_id) or {}
    return (
        record.get("step1_response") or MOCK_S1,
        record.get("step2_response") or MOCK_S2,
        record.get("step3_response") or MOCK_S3,
    )
# ── Mobility text builders ────────────────────────────────────────────────────
def build_mobility_summary(agent_sp):
    """Summarise one agent's stay points as a short multi-line text report.

    Expects the columns ``name``, ``duration_min``, ``start_datetime``,
    ``end_datetime`` (datetime-like) and ``act_label``.

    The report contains:
    - the observation period;
    - stay-point and unique-location counts;
    - the weekday/weekend split (percent, truncated to an int) and the peak
      time of day;
    - the three most frequent activity labels;
    - the five most visited locations with their mean stay duration.

    Hours and weekdays are read straight from the timestamps, so they reflect
    whatever timezone those timestamps carry.
    """
    # Top locations by visit count; "stable" keeps ties in groupby (name) order
    # instead of an arbitrary quicksort order.
    top5 = (
        agent_sp.groupby("name")["duration_min"]
        .agg(visits="count", avg_dur="mean")
        .sort_values("visits", ascending=False, kind="stable")
        .head(5)
    )
    first_start = agent_sp["start_datetime"].min()
    last_end = agent_sp["end_datetime"].max()
    obs_start = first_start.strftime("%Y-%m-%d")
    obs_end = last_end.strftime("%Y-%m-%d")
    days = (last_end - first_start).days  # whole days, fractional part dropped

    # Top activity types
    act_counts = agent_sp["act_label"].value_counts().head(3)
    top_acts = ", ".join(f"{a} ({n})" for a, n in act_counts.items())

    def tod(h):
        """Bucket an hour (0-23) into a coarse time-of-day label."""
        if 5 <= h < 12:
            return "Morning"
        if 12 <= h < 17:
            return "Afternoon"
        if 17 <= h < 21:
            return "Evening"
        return "Night"

    starts = agent_sp["start_datetime"]
    peak_tod = starts.dt.hour.map(tod).value_counts().idxmax()
    wd_pct = int((starts.dt.dayofweek < 5).mean() * 100)  # Mon-Fri share

    lines = [
        f"Period: {obs_start} ~ {obs_end} ({days} days)",
        f"Stay points: {len(agent_sp)} | Unique locations: {agent_sp['name'].nunique()}",
        f"Weekday/Weekend: {wd_pct}% / {100 - wd_pct}% | Peak time: {peak_tod}",
        f"Top activities: {top_acts}",
        "",
        "Top Locations:",
    ]
    for i, (name, row) in enumerate(top5.iterrows(), 1):
        lines.append(f" {i}. {name} — {int(row['visits'])} visits, avg {int(row['avg_dur'])} min")
    return "\n".join(lines)
def build_weekly_checkin(agent_sp, max_days=None):
    """Render an agent's stay points as a day-by-day check-in log.

    Days come from ``start_datetime`` and are listed in ascending order. Each
    day block starts with a blank line, then:
    - the weekday name, date and a Weekday/Weekend tag (taken from the day's
      first row);
    - the number of stay points that day;
    - one ``- HH:MM-HH:MM (N mins): <name> - <act_label>`` line per stay point.

    When *max_days* is truthy, only the first *max_days* days are written and a
    final ``... (K more days)`` line reports how many were left out.
    """
    frame = agent_sp.copy()
    frame["date"] = frame["start_datetime"].dt.date
    every_date = sorted(frame["date"].unique())
    shown = every_date[:max_days] if max_days else every_date

    out = ["WEEKLY CHECK-IN SUMMARY", "======================="]
    for day in shown:
        day_rows = frame[frame["date"] == day]
        first = day_rows["start_datetime"].iloc[0]
        kind = "Weekday" if first.dayofweek < 5 else "Weekend"
        out.append(f"\n--- {first.strftime('%A')}, {day} ({kind}) ---")
        out.append(f"Total activities: {len(day_rows)}")
        for _, rec in day_rows.iterrows():
            begin = rec["start_datetime"].strftime("%H:%M")
            finish = rec["end_datetime"].strftime("%H:%M")
            out.append(
                f"- {begin}-{finish} ({int(rec['duration_min'])} mins): "
                f"{rec['name']} - {rec['act_label']}"
            )

    hidden = len(every_date) - max_days if max_days else 0
    if hidden > 0:
        out.append(f"\n... ({hidden} more days)")
    return "\n".join(out)
# ── HTML reasoning chain ──────────────────────────────────────────────────────
# Stylesheet for the reasoning-chain panel; render_chain() prepends it to the
# HTML it returns. The @import pulls IBM Plex Mono/Sans from Google Fonts; the
# font-family rules fall back to generic monospace/sans-serif. Class groups:
#   .stage-card.s1/.s2/.s3 + .dim/.active  - the three stage cards and their highlight state
#   .stage-header / .stage-badge / .stage-title - card header row
#   .tag*, .behavior-row/.bkey/.bval, .pred-* - Stage 1 tags, Stage 2 key/value grid, Stage 3 prediction
#   .confidence-*, .alternatives             - confidence meter and alternatives styles
#   .chain-arrow / .arrow-*                  - connectors between cards
#   .prompt-*, .resp-label                   - "Prompt" / "Response" captions
#   .wd + @keyframes wd-pulse                - animated "thinking" dots
CHAIN_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=IBM+Plex+Sans:wght@300;400;600&display=swap');
.hicotraj-chain { font-family: 'IBM Plex Sans', sans-serif; padding: 12px 4px; }
.stage-card {
border-radius: 10px; padding: 16px 18px; margin-bottom: 0;
transition: opacity 0.4s, filter 0.4s;
}
.stage-card.dim { opacity: 0.32; filter: grayscale(0.5); }
.stage-card.active { opacity: 1; filter: none; }
.stage-card.s1 { background: #f8f9fc; border: 1.5px solid #c8d0e0; }
.stage-card.s2 { background: #fdf6f0; border: 1.5px solid #e8c9a8; }
.stage-card.s3 { background: #fff8f8; border: 2px solid #c0392b; }
.stage-header { display: flex; align-items: center; gap: 10px; margin-bottom: 10px; }
.stage-badge {
font-family: 'IBM Plex Mono', monospace;
font-size: 10px; font-weight: 600; letter-spacing: 0.08em;
padding: 3px 8px; border-radius: 4px; text-transform: uppercase;
}
.s1 .stage-badge { background: #dde3f0; color: #3a4a6b; }
.s2 .stage-badge { background: #f0dcc8; color: #7a4010; }
.s3 .stage-badge { background: #c0392b; color: #fff; }
.stage-title { font-size: 13px; font-weight: 600; color: #1a1a2e; }
.tag-row { display: flex; flex-wrap: wrap; gap: 6px; margin-top: 4px; }
.tag {
font-family: 'IBM Plex Mono', monospace; font-size: 11px;
background: #e8ecf5; color: #2c3e60;
padding: 3px 8px; border-radius: 4px; white-space: nowrap;
}
.behavior-row {
display: grid; grid-template-columns: 100px 1fr;
gap: 4px 10px; margin-top: 2px; font-size: 12px; line-height: 1.6;
}
.bkey {
font-family: 'IBM Plex Mono', monospace; font-size: 11px;
font-weight: 600; color: #9b6a3a; padding-top: 1px;
}
.bval { color: #3a2a1a; }
.pred-block { margin-top: 4px; }
.pred-label {
font-size: 11px; font-family: 'IBM Plex Mono', monospace; color: #888;
text-transform: uppercase; letter-spacing: 0.06em; margin-bottom: 4px;
}
.pred-value { font-size: 22px; font-weight: 600; color: #c0392b; margin-bottom: 8px; }
.confidence-bar-wrap { display: flex; align-items: center; gap: 10px; margin-bottom: 10px; }
.confidence-bar-bg { flex: 1; height: 6px; background: #f0d0cf; border-radius: 3px; overflow: hidden; }
.confidence-bar-fill { height: 100%; background: linear-gradient(90deg, #e74c3c, #8b0000); border-radius: 3px; }
.confidence-label { font-family: 'IBM Plex Mono', monospace; font-size: 11px; color: #c0392b; font-weight: 600; white-space: nowrap; }
.reasoning-text { font-size: 12px; color: #4a2a2a; line-height: 1.6; border-left: 3px solid #e8c0be; padding-left: 10px; margin-top: 6px; }
.alternatives { margin-top: 10px; font-size: 11px; font-family: 'IBM Plex Mono', monospace; color: #999; }
.alternatives span { color: #c0392b; opacity: 0.7; }
.chain-arrow { display: flex; flex-direction: column; align-items: center; padding: 4px 0; transition: opacity 0.4s; }
.arrow-line { width: 2px; height: 16px; background: #d0c0b0; }
.arrow-label {
font-family: 'IBM Plex Mono', monospace; font-size: 10px; color: #aaa;
letter-spacing: 0.06em; text-transform: uppercase;
background: white; padding: 2px 8px; border: 1px solid #e0e0e0; border-radius: 10px; margin: 2px 0;
}
.arrow-tip { width: 0; height: 0; border-left: 5px solid transparent; border-right: 5px solid transparent; border-top: 7px solid #d0c0b0; }
.thinking { font-size: 13px; color: #888; padding: 8px 0; }
.empty-hint { font-size: 12px; color: #ccc; padding: 6px 0; }
.temporal-line { font-size: 11px; color: #666; margin-top: 8px; font-family: 'IBM Plex Mono', monospace; }
.prompt-snippet {
font-size: 11px; color: #888; line-height: 1.5;
background: rgba(0,0,0,0.03); border-left: 2px solid #ddd;
padding: 6px 10px; border-radius: 0 4px 4px 0;
margin-bottom: 8px; font-family: 'IBM Plex Mono', monospace;
}
.prompt-label {
display: inline-block; font-size: 9px; font-weight: 600;
text-transform: uppercase; letter-spacing: 0.08em;
color: #aaa; margin-right: 6px;
background: #eee; padding: 1px 5px; border-radius: 3px;
}
.resp-label {
font-size: 9px; font-weight: 600; text-transform: uppercase;
letter-spacing: 0.08em; color: #aaa; margin-bottom: 4px;
display: inline-block; background: #eee; padding: 1px 5px; border-radius: 3px;
}
.wd {
display: inline-block; width: 6px; height: 6px; border-radius: 50%;
background: currentColor; margin: 0 2px; opacity: 0.3;
animation: wd-pulse 1.2s ease-in-out infinite;
}
.wd:nth-child(2) { animation-delay: 0.2s; }
.wd:nth-child(3) { animation-delay: 0.4s; }
@keyframes wd-pulse {
0%, 100% { opacity: 0.2; transform: scale(0.8); }
50% { opacity: 1; transform: scale(1.1); }
}
</style>
"""
def _dots():
return '<span class="wd"></span><span class="wd"></span><span class="wd"></span>'
# Prompt excerpts shown above each stage's response (display text only).
_PROMPT_SNIPPETS = {
    "s1": "You are an expert mobility analyst. Given the trajectory data below, extract: (1) LOCATION INVENTORY — list all POI categories visited and visit frequency; (2) TEMPORAL PATTERNS — weekly distribution, peak hours; (3) SEQUENCE — typical activity chains...",
    "s2": "Based on the trajectory features identified: {Response 1}. Now analyze what these mobility patterns reveal about lifestyle: (1) SCHEDULE — work/activity routine type; (2) ECONOMIC — spending venue tiers; (3) SOCIAL — social engagement; (4) STABILITY — consistency of routine...",
    "s3": "Based on feature analysis {Response 1} and behavioral analysis {Response 2}, predict income level. Output — INCOME_PREDICTION: [range]; INCOME_REASONING: [detailed reasoning]...",
}

# Stage-2 display rows: (row label, section-title keywords that may feed the row).
_BEHAVIOR_KEYS = [
    ("SCHEDULE", ["ROUTINE", "SCHEDULE"]),
    ("ECONOMIC", ["ECONOMIC", "SPENDING", "FINANCIAL"]),
    ("SOCIAL", ["SOCIAL", "COMMUNITY", "LIFESTYLE"]),
    ("STABILITY", ["STABILITY", "CONSISTENCY", "REGULARITY"]),
]

# "2. ECONOMIC INDICATORS:" - numbered section title whose "-" bullets follow on later lines.
_NUMBERED_SECTION_RE = re.compile(
    r'^\d+\.\s+(.+?)(?:\s+ANALYSIS)?(?:\s+PATTERNS)?(?:\s+INDICATORS)?:\s*$',
    re.IGNORECASE,
)
# "ECONOMIC: Budget-conscious ..." - upper-case key (optionally numbered) with an
# optional same-line value; this is the format of the mock Stage-2 text.
_INLINE_SECTION_RE = re.compile(r'^(?:\d+\.\s+)?([A-Z][A-Z /&_-]*?):\s*(.*)$')
_SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')
_NUMBERED_ITEM_RE = re.compile(r'^\d+\.')
_VISITS_RE = re.compile(r'(.+?):\s*(\d+)\s*visit', re.IGNORECASE)


def _clip(text, limit):
    """Return *text* if it fits in *limit* chars, else its first limit-3 chars plus '...'."""
    return text if len(text) <= limit else text[: limit - 3] + "..."


def _first_sentences(text, count=2):
    """Return the first *count* sentences of *text* (split after '.', '!' or '?' + whitespace)."""
    return " ".join(_SENTENCE_SPLIT_RE.split(text)[:count])


def _stage1_content(s1_text, status):
    """Build Stage-1 HTML: venue tags from LOCATION INVENTORY plus one temporal line.

    Returns (html, is_response); is_response is False for the spinner/placeholder.
    """
    from html import escape

    if status == "running1":
        return f'<div class="thinking">Extracting features {_dots()}</div>', False
    if not s1_text:
        return '<div class="empty-hint">Press ▶ to start</div>', False

    lines = [ln.strip() for ln in s1_text.splitlines()]
    tags = []
    in_inventory = False
    for line in lines:
        if "LOCATION INVENTORY" in line.upper():
            in_inventory = True
            continue
        if not in_inventory:
            continue
        is_bullet = line.startswith("-") or line.startswith("* ")
        if not is_bullet:
            # The next section heading (plain, "**bold**" or "# markdown") ends the inventory...
            if line.lstrip("*# ").upper().startswith(("TEMPORAL", "SEQUENCE")):
                break
            # ...as does any long prose line.
            if line and not line.startswith("*") and len(line) > 40:
                break
            continue
        clean = line.lstrip("-*").strip()
        # Shorten "Name: N visits, ..." to "Name · Nx"; otherwise keep a prefix of the bullet.
        m = _VISITS_RE.match(clean)
        if m:
            tags.append(f"{m.group(1).strip()} · {m.group(2)}x")
        elif clean:
            tags.append(clean[:55])
        if len(tags) >= 8:
            break

    # First line anywhere in the response that mentions the weekly distribution.
    temporal = next(
        (ln.lstrip("-").strip()[:70] for ln in lines
         if "weekly distribution" in ln.lower() or "weekday" in ln.lower()),
        "",
    )
    tag_html = "".join(f'<span class="tag">{escape(t)}</span>' for t in tags)
    temp_html = f'<div class="temporal-line">⏱ {escape(temporal)}</div>' if temporal else ""
    return f'<div class="tag-row">{tag_html}</div>{temp_html}', True


def _stage2_content(s2_text, status):
    """Build Stage-2 HTML: one SCHEDULE/ECONOMIC/SOCIAL/STABILITY row each.

    Understands both numbered sections with "-" bullets and single-line
    "KEY: value" entries. A row shows the first bullet/value (at most two
    sentences, 100 chars) of the first section whose title contains one of the
    row's keywords, or "—" when none matches. Returns (html, is_response).
    """
    from html import escape

    if status == "running2":
        return f'<div class="thinking" style="color:#a06030">Analyzing behavior {_dots()}</div>', False
    if not s2_text:
        return '<div class="empty-hint">Waiting...</div>', False

    sections = {}
    current = None
    for raw in s2_text.splitlines():
        line = raw.strip()
        numbered = _NUMBERED_SECTION_RE.match(line)
        if numbered:
            current = numbered.group(1).upper()
            sections[current] = []
            continue
        inline = _INLINE_SECTION_RE.match(line)
        if inline:
            current = inline.group(1).strip().upper()
            value = inline.group(2).strip()
            sections[current] = [value] if value else []
            continue
        if current is not None and line.startswith("-"):
            bullet = line.lstrip("-").strip()
            if bullet:
                sections[current].append(bullet)

    rows = []
    for label, keywords in _BEHAVIOR_KEYS:
        value = "—"
        for key, bullets in sections.items():
            if bullets and any(word in key for word in keywords):
                value = _clip(_first_sentences(bullets[0]), 100)
                break
        rows.append(f'<div class="bkey">{label}</div><div class="bval">{escape(value)}</div>')
    return f'<div class="behavior-row">{"".join(rows)}</div>', True


def _stage3_content(s3_text, status):
    """Build Stage-3 HTML from INCOME_PREDICTION / INCOME_REASONING lines.

    Reasoning may continue on following lines until a blank line, another
    INCOME_* key or a numbered item; it is shortened to two sentences / 160
    chars. Returns (html, is_response).
    """
    from html import escape

    if status == "running3":
        return f'<div class="thinking" style="color:#c0392b">Inferring demographics {_dots()}</div>', False
    if not s3_text:
        return '<div class="empty-hint">Waiting...</div>', False

    pred = reasoning = ""
    lines = s3_text.splitlines()
    i = 0
    while i < len(lines):
        line = lines[i].strip()
        i += 1
        if line.startswith("INCOME_PREDICTION:"):
            pred = line[len("INCOME_PREDICTION:"):].strip()
        elif line.startswith("INCOME_REASONING:"):
            reasoning = line[len("INCOME_REASONING:"):].strip()
            while i < len(lines):
                nxt = lines[i].strip()
                if not nxt or nxt.startswith("INCOME_") or _NUMBERED_ITEM_RE.match(nxt):
                    break
                reasoning += " " + nxt
                i += 1

    short_reasoning = _clip(_first_sentences(reasoning.strip()), 160)
    pred_html = escape(pred) if pred else "—"
    return (
        '<div class="pred-block">'
        '<div class="pred-label">Income Prediction</div>'
        f'<div class="pred-value">{pred_html}</div>'
        f'<div class="reasoning-text">{escape(short_reasoning)}</div>'
        '</div>'
    ), True


def _stage_card(cls, badge, title, content_html, active, is_response):
    """Wrap one stage's content in a card with badge, title, prompt excerpt and optional "Response" caption."""
    state = "active" if active else "dim"
    prompt = _PROMPT_SNIPPETS.get(cls, "")
    prompt_html = (
        f'<div class="prompt-snippet"><span class="prompt-label">Prompt</span>{prompt}</div>'
        if prompt else ""
    )
    resp_label = '<div class="resp-label">Response</div>' if active and is_response else ""
    return (
        f'<div class="stage-card {cls} {state}">'
        '<div class="stage-header">'
        f'<span class="stage-badge">{badge}</span>'
        f'<span class="stage-title">{title}</span>'
        '</div>'
        f'{prompt_html}{resp_label}{content_html}'
        '</div>'
    )


def _chain_arrow(label, active):
    """Return the labelled connector drawn between two stage cards (faded when inactive)."""
    opacity = "1" if active else "0.25"
    return (
        f'<div class="chain-arrow" style="opacity:{opacity}">'
        '<div class="arrow-line"></div>'
        f'<div class="arrow-label">{label}</div>'
        '<div class="arrow-line"></div>'
        '<div class="arrow-tip"></div>'
        '</div>'
    )


def render_chain(s1_text, s2_text, s3_text, status="done"):
    """Render the three-stage reasoning chain as one HTML string.

    Args:
        s1_text, s2_text, s3_text: raw responses for each stage; "" means "not yet available".
        status: "running1" / "running2" / "running3" shows a spinner in that
            stage's card (and highlights it and every earlier stage); "done"
            highlights all cards; any other value (e.g. "idle") dims every card.

    Returns:
        ``CHAIN_CSS`` followed by the chain markup. Response text is HTML-escaped.
    """
    s1_active = status in ("running1", "running2", "running3", "done")
    s2_active = status in ("running2", "running3", "done")
    s3_active = status in ("running3", "done")

    s1_html, s1_resp = _stage1_content(s1_text, status)
    s2_html, s2_resp = _stage2_content(s2_text, status)
    s3_html, s3_resp = _stage3_content(s3_text, status)

    return "".join([
        CHAIN_CSS,
        '<div class="hicotraj-chain">',
        _stage_card("s1", "Stage 1", "Factual Feature Extraction", s1_html, s1_active, s1_resp),
        _chain_arrow("behavioral abstraction", s2_active),
        _stage_card("s2", "Stage 2", "Behavioral Pattern Analysis", s2_html, s2_active, s2_resp),
        _chain_arrow("demographic inference", s3_active),
        _stage_card("s3", "Stage 3", "Demographic Inference", s3_html, s3_active, s3_resp),
        "</div>",
    ])
def build_map(agent_sp):
    """Draw an agent's stay points on a folium map and return the map's HTML.

    Marker positions are:
    - jittered by up to ±0.0003° (global NumPy RNG) so repeated visits don't
      stack exactly;
    - coloured from #ffcc00 (first row) to #8b0000 (last row);
    - joined by a faint polyline when there is more than one point.

    Each marker's popup lists its order, name, start time, duration and
    activity label. A fixed-position legend is added to the map's root HTML.
    """
    pts = agent_sp.reset_index(drop=True).copy()
    count = len(pts)
    pts["latitude"] += np.random.uniform(-0.0003, 0.0003, count)
    pts["longitude"] += np.random.uniform(-0.0003, 0.0003, count)

    fmap = folium.Map(
        location=[pts["latitude"].mean(), pts["longitude"].mean()],
        zoom_start=12,
        tiles="CartoDB positron",
    )

    path = list(zip(pts["latitude"], pts["longitude"]))
    if len(path) > 1:
        folium.PolyLine(path, color="#cc000055", weight=1.5, opacity=0.4).add_to(fmap)

    span = max(count - 1, 1)
    for idx, rec in pts.iterrows():
        t = idx / span  # 0.0 for the first stay point, 1.0 for the last
        red = int(255 - t * (255 - 139))
        green = int(204 * (1 - t) ** 2)
        hex_color = f"#{red:02x}{green:02x}00"
        popup_html = (
            f"<b>#{idx+1} {rec['name']}</b><br>"
            f"{rec['start_datetime'].strftime('%a %m/%d %H:%M')}<br>"
            f"{int(rec['duration_min'])} min<br>{rec['act_label']}"
        )
        folium.CircleMarker(
            location=[rec["latitude"], rec["longitude"]],
            radius=7,
            color=hex_color,
            fill=True,
            fill_color=hex_color,
            fill_opacity=0.9,
            popup=folium.Popup(popup_html, max_width=220),
        ).add_to(fmap)

    legend_html = """
<div style="
position:fixed; bottom:18px; left:18px; z-index:9999;
background:rgba(255,255,255,0.92); border-radius:8px;
padding:8px 12px; font-size:11px; font-family:sans-serif;
box-shadow:0 1px 5px rgba(0,0,0,0.2); line-height:1.8;
">
<div style="font-weight:600;margin-bottom:4px;">Stay Point Legend</div>
<div style="display:flex;align-items:center;gap:6px;">
<svg width="60" height="10">
<defs><linearGradient id="lg" x1="0" x2="1" y1="0" y2="0">
<stop offset="0%" stop-color="#ffcc00"/>
<stop offset="100%" stop-color="#8b0000"/>
</linearGradient></defs>
<rect width="60" height="10" rx="4" fill="url(#lg)"/>
</svg>
<span>Earlier &rarr; Later</span>
</div>
<div style="display:flex;align-items:center;gap:6px;margin-top:2px;">
<svg width="14" height="14"><circle cx="7" cy="7" r="5" fill="#cc4444" opacity="0.5"/></svg>
<span>Movement path</span>
</div>
<div style="color:#999;font-size:10px;margin-top:2px;">Click dot for details</div>
</div>
"""
    root = fmap.get_root()
    root.html.add_child(folium.Element(legend_html))
    root.width = "100%"
    root.height = "420px"
    return fmap._repr_html_()
def build_demo_text(row):
    """Format one demographics row as a single "Key: value | ..." line.

    Coded columns are decoded through the *_MAP lookups; a code missing from
    its map is shown as the raw value. Non-positive ages print as "Unknown".
    ``int()`` on a coded column raises if that value is NaN.
    """
    def decoded(column, mapping):
        raw = row[column]
        return mapping.get(int(raw), raw)

    age = int(row["age"]) if row["age"] > 0 else "Unknown"
    fields = [
        ("Age", age),
        ("Sex", decoded("sex", SEX_MAP)),
        ("Race", decoded("race", RACE_MAP)),
        ("Education", decoded("education", EDU_MAP)),
        ("Income", decoded("hh_income", INC_MAP)),
    ]
    return " | ".join(f"{key}: {value}" for key, value in fields)
# ── Callbacks ─────────────────────────────────────────────────────────────────
def on_select(agent_id):
    """Build the map, full raw text, demographics line and an idle chain for one agent.

    Returns:
        (map_html, raw_text, demo_text, chain_html), where raw_text is the
        mobility summary, a blank line, then the complete check-in log.

    Raises IndexError if *agent_id* has no row in ``demo``.
    NOTE(review): the UI callbacks in this file use on_select_reset instead.
    """
    aid = int(agent_id)
    trips = sp.loc[sp["agent_id"] == aid].sort_values("start_datetime")
    profile = demo.loc[demo["agent_id"] == aid].iloc[0]
    map_html = build_map(trips)
    demo_text = build_demo_text(profile)
    raw_text = "\n\n".join([build_mobility_summary(trips), build_weekly_checkin(trips)])
    return map_html, raw_text, demo_text, render_chain("", "", "", status="idle")
def run_step(agent_id, step):
    """Reveal one more stage of the agent's CoT per click.

    Args:
        agent_id: agent id (anything ``int()`` accepts).
        step: number of stages already revealed (0, 1 or 2).

    Returns:
        (chain_html, new_step, button_update).
        - Step 0 reveals Stage 1 and shows Stage 2's "thinking" spinner.
        - Step 1 reveals Stage 2 and shows Stage 3's spinner.
        - Any later step reveals everything and returns new_step -1, which
          handle_btn treats as "reset on next click".
    """
    s1, s2, s3 = get_cot(int(agent_id))
    next_step = step + 1
    if next_step == 1:
        html = render_chain(s1, "", "", status="running2")
        return html, 1, gr.update(value="▶ Stage 2: Behavioral Analysis")
    if next_step == 2:
        html = render_chain(s1, s2, "", status="running3")
        return html, 2, gr.update(value="▶ Stage 3: Demographic Inference")
    html = render_chain(s1, s2, s3, status="done")
    return html, -1, gr.update(value="↺ Reset")
def handle_btn(agent_id, step):
    """Click handler for the stage button.

    When *step* is -1 (all stages shown), reset to an idle chain, step 0 and
    the Stage-1 label. Otherwise delegate to run_step to reveal the next stage.
    Returns (chain_html, new_step, button_update).
    """
    if step == -1:
        html = render_chain("", "", "", status="idle")
        return html, 0, gr.update(value="▶ Stage 1: Feature Extraction")
    return run_step(agent_id, step)
def _first_day_preview(raw_full):
    """Collapse a check-in log to its header plus the first day block.

    Day blocks are recognised by a blank line followed by "--- ". When more
    than one day exists, a "... (N more days)" line is appended. Text with no
    day blocks is returned unchanged.
    """
    sep = "\n\n--- "
    parts = raw_full.split(sep)
    if len(parts) == 1:
        return raw_full
    preview = parts[0] + sep + parts[1]
    hidden = len(parts) - 2
    if hidden > 0:
        preview += "\n\n... ({} more days)".format(hidden)
    return preview


def on_select_reset(agent_id):
    """Load everything the UI shows for a newly selected agent and reset the CoT panel.

    The summary and check-in text come from the agent's CoT record
    ("text_representation" / "weekly_checkin") when present and non-empty.
    Otherwise they are built from the stay points. The check-in log is
    collapsed to its first day.

    Returns:
        (map_html, summary, raw_text, demo_text, chain_html, step=0, button_update)
    """
    agent_id_int = int(agent_id)
    agent_sp = sp[sp["agent_id"] == agent_id_int].sort_values("start_datetime")
    agent_demo = demo[demo["agent_id"] == agent_id_int].iloc[0]
    map_html = build_map(agent_sp)
    demo_text = build_demo_text(agent_demo)
    cot_entry = cot_by_agent.get(agent_id_int, {})
    summary = cot_entry.get("text_representation") or build_mobility_summary(agent_sp)
    raw_full = cot_entry.get("weekly_checkin") or build_weekly_checkin(agent_sp)
    raw_text = _first_day_preview(raw_full)
    chain_html = render_chain("", "", "", status="idle")
    return (
        map_html,
        summary,
        raw_text,
        demo_text,
        chain_html,
        0,
        gr.update(value="▶ Stage 1: Feature Extraction"),
    )
# Agents offered as clickable selector cards: the first six ids in ascending order.
SHOWCASE_AGENTS = sample_agents[:6]


def build_agent_cards(selected_id):
    """Return HTML for the 3-column grid of agent selector cards.

    Each card shows the agent's age, sex, education and income from ``demo``.
    Codes missing from the lookup maps, and non-positive ages, show as "?".
    The card whose id equals *selected_id* is highlighted and prefixed with a
    red dot.

    Clicking a card runs inline JS that writes the agent id into the textarea
    inside ``#agent_hidden_input`` and dispatches a bubbling ``input`` event on
    it.
    """
    chosen_id = int(selected_id)
    grid_open = "<div style='display:grid;grid-template-columns:repeat(3,1fr);gap:10px;padding:4px 0;'>"
    selected_css = "border:2px solid #c0392b;background:#fff8f8;box-shadow:0 2px 8px rgba(192,57,43,0.15);"
    default_css = "border:1.5px solid #ddd;background:#fafafa;box-shadow:0 1px 3px rgba(0,0,0,0.06);"
    red_dot = "<span style='display:inline-block;width:8px;height:8px;border-radius:50%;background:#c0392b;margin-right:5px;'></span>"
    click_js = "var t=document.querySelector('#agent_hidden_input textarea');t.value='AID';t.dispatchEvent(new Event('input',{bubbles:true}));"

    html_parts = [grid_open]
    for aid in SHOWCASE_AGENTS:
        person = demo[demo["agent_id"] == aid].iloc[0]
        age = int(person["age"]) if person["age"] > 0 else "?"
        sex = SEX_MAP.get(int(person["sex"]), "?")
        edu = EDU_MAP.get(int(person["education"]), "?")
        inc = INC_MAP.get(int(person["hh_income"]), "?")
        is_chosen = aid == chosen_id
        html_parts.append("".join([
            '<div onclick="',
            click_js.replace("AID", str(aid)),
            '" style="cursor:pointer;border-radius:10px;padding:10px 13px;transition:all 0.2s;',
            selected_css if is_chosen else default_css,
            '">',
            "<div style='font-size:11px;font-weight:700;color:#c0392b;margin-bottom:3px;font-family:monospace;'>",
            red_dot if is_chosen else "",
            "Agent #",
            str(aid),
            "</div>",
            "<div style='font-size:11px;color:#333;line-height:1.6;'>",
            "<b>Age:</b> ",
            str(age),
            " &nbsp; <b>Sex:</b> ",
            sex,
            "<br>",
            "<b>Edu:</b> ",
            edu,
            "<br>",
            "<b>Income:</b> ",
            inc,
            "</div></div>",
        ]))
    html_parts.append("</div>")
    return "".join(html_parts)
# ── UI ────────────────────────────────────────────────────────────────────────
# Layout: agent selector cards on top; trajectory map + raw data on the left;
# the hierarchical CoT panel on the right.
with gr.Blocks(title="HiCoTraj Demo", theme=gr.themes.Soft()) as app:
    gr.Markdown("## HiCoTraj — Trajectory Visualization & Hierarchical CoT Demo")
    gr.Markdown("*Zero-Shot Demographic Reasoning via Hierarchical Chain-of-Thought Prompting from Trajectory* · ACM SIGSPATIAL GeoGenAgent 2025")
    gr.Markdown("""
**Dataset:** NUMOSIM — a synthetic mobility dataset with realistic activity patterns across 6,000 agents.
> Stanford C, Adari S, Liao X, et al. *NUMoSim: A Synthetic Mobility Dataset with Anomaly Detection Benchmarks.* ACM SIGSPATIAL Workshop on Geospatial Anomaly Detection, 2024.
""")

    gr.Markdown("### Select Agent")
    agent_cards = gr.HTML(value=build_agent_cards(SHOWCASE_AGENTS[0]))
    # Holds the selected agent id; the cards' onclick JS writes into its <textarea>.
    # NOTE(review): presumably kept visible=True so the element exists in the DOM
    # for that querySelector, and hidden with CSS instead — confirm.
    agent_hidden = gr.Textbox(
        value=str(SHOWCASE_AGENTS[0]),
        visible=True,
        elem_id="agent_hidden_input",
        elem_classes=["hidden-input"],
    )
    gr.HTML("<style>.hidden-input { display:none !important; }</style>")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Trajectory Map")
            map_out = gr.HTML()
            gr.Markdown("### NUMOSIM Raw Data")
            with gr.Tabs():
                with gr.Tab("Summary"):
                    summary_out = gr.Textbox(lines=10, interactive=False, label="", show_label=False)
                with gr.Tab("Raw Data"):
                    raw_out = gr.Textbox(lines=10, interactive=False, label="", show_label=False)
                    show_all_btn = gr.Button("Show All Days", size="sm", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown("### Hierarchical Chain-of-Thought Reasoning")
            step_state = gr.State(value=0)  # stages revealed so far; -1 once complete
            run_btn = gr.Button("▶ Stage 1: Feature Extraction", variant="primary")
            chain_out = gr.HTML(value=render_chain("", "", "", status="idle"))

    def on_agent_click(agent_id):
        """Re-render cards and all panels for the newly selected agent."""
        cards_html = build_agent_cards(agent_id)
        map_html, summary, raw_text, _demo_text, chain_html, step, btn = on_select_reset(agent_id)
        # The raw data is collapsed again for the new agent, so the toggle must
        # go back to "Show All Days" (it could still read "Show Less").
        return (cards_html, map_html, summary, raw_text, chain_html, step, btn,
                gr.update(value="Show All Days"))

    agent_hidden.change(
        fn=on_agent_click, inputs=agent_hidden,
        outputs=[agent_cards, map_out, summary_out, raw_out, chain_out, step_state, run_btn, show_all_btn],
    )

    def on_load(agent_id):
        """Populate the panels for the initially selected agent on page load."""
        map_html, summary, raw_text, _demo_text, chain_html, step, btn = on_select_reset(agent_id)
        return map_html, summary, raw_text, chain_html, step, btn

    app.load(
        fn=on_load, inputs=agent_hidden,
        outputs=[map_out, summary_out, raw_out, chain_out, step_state, run_btn],
    )

    run_btn.click(
        fn=handle_btn, inputs=[agent_hidden, step_state],
        outputs=[chain_out, step_state, run_btn],
    )

    def toggle_raw(agent_id, current_text):
        """Switch the raw-data box between the first-day preview and the full log.

        A "more days" marker in the current text means it is collapsed, so the
        full log is shown; otherwise the log is collapsed to its first day.
        """
        agent_id_int = int(agent_id)
        cot_entry = cot_by_agent.get(agent_id_int, {})
        raw_full = cot_entry.get("weekly_checkin")
        if not raw_full:
            agent_sp = sp[sp["agent_id"] == agent_id_int].sort_values("start_datetime")
            raw_full = build_weekly_checkin(agent_sp)
        if "more days" in current_text:
            return raw_full, gr.update(value="Show Less")
        # Collapse: header + first "\n\n--- " day block + count of hidden days.
        sep = "\n\n--- "
        parts = raw_full.split(sep)
        if len(parts) == 1:
            return raw_full, gr.update(value="Show All Days")
        short = parts[0] + sep + parts[1]
        if len(parts) > 2:
            short += "\n\n... ({} more days)".format(len(parts) - 2)
        return short, gr.update(value="Show All Days")

    show_all_btn.click(
        fn=toggle_raw, inputs=[agent_hidden, raw_out],
        outputs=[raw_out, show_all_btn],
    )

app.launch(show_error=True)