HiCoTraj / app.py
ginnyxxxxxxx's picture
clear
144e51b
raw
history blame
24.2 kB
import gradio as gr
import pandas as pd
import folium
import numpy as np
import os
import re
from huggingface_hub import InferenceClient
# Data files live in ./data next to this script.
BASE = os.path.dirname(os.path.abspath(__file__))
STAY_POINTS = os.path.join(BASE, "data", "stay_points_sampled.csv")
POI_PATH = os.path.join(BASE, "data", "poi_sampled.csv")
DEMO_PATH = os.path.join(BASE, "data", "demographics_sampled.csv")

# Chat model used for all three reasoning stages.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"

# Integer survey codes -> display labels (negative codes are "no answer" variants).
SEX_MAP = {
    1: "Male",
    2: "Female",
    -8: "Unknown",
    -7: "Prefer not to answer",
}
EDU_MAP = {
    1: "Less than HS",
    2: "HS Graduate/GED",
    3: "Some College/Associate",
    4: "Bachelor's Degree",
    5: "Graduate/Professional Degree",
    -1: "N/A",
    -7: "Prefer not to answer",
    -8: "Unknown",
}
INC_MAP = {
    1: "<$10,000",
    2: "$10,000–$14,999",
    3: "$15,000–$24,999",
    4: "$25,000–$34,999",
    5: "$35,000–$49,999",
    6: "$50,000–$74,999",
    7: "$75,000–$99,999",
    8: "$100,000–$124,999",
    9: "$125,000–$149,999",
    10: "$150,000–$199,999",
    11: "$200,000+",
    -7: "Prefer not to answer",
    -8: "Unknown",
    -9: "Not ascertained",
}
RACE_MAP = {
    1: "White",
    2: "Black or African American",
    3: "Asian",
    4: "American Indian or Alaska Native",
    5: "Native Hawaiian or Other Pacific Islander",
    6: "Multiple races",
    97: "Other",
    -7: "Prefer not to answer",
    -8: "Unknown",
}

# Activity-type codes found in the stay-point ``act_types`` column.
ACT_MAP = {
    0: "Transportation",
    1: "Home",
    2: "Work",
    3: "School",
    4: "ChildCare",
    5: "BuyGoods",
    6: "Services",
    7: "EatOut",
    8: "Errands",
    9: "Recreation",
    10: "Exercise",
    11: "Visit",
    12: "HealthCare",
    13: "Religious",
    14: "SomethingElse",
    15: "DropOff",
}
# Load the three sampled tables once at import time.
print("Loading data...")
sp = pd.read_csv(STAY_POINTS)
poi = pd.read_csv(POI_PATH)
demo = pd.read_csv(DEMO_PATH)

# Attach POI attributes to every stay point; stays without a matching POI are kept.
sp = sp.merge(poi, on="poi_id", how="left")

# Parse both timestamps as UTC-aware datetimes.
for _ts_col in ("start_datetime", "end_datetime"):
    sp[_ts_col] = pd.to_datetime(sp[_ts_col], utc=True)

# Stay length in minutes, rounded to one decimal place.
_stay_seconds = (sp["end_datetime"] - sp["start_datetime"]).dt.total_seconds()
sp["duration_min"] = (_stay_seconds / 60).round(1)
def parse_act_types(x):
    """Turn a stringified list of activity codes into readable labels.

    Args:
        x: Value such as ``"[1 7]"`` or ``"[1, 7]"``; anything is accepted
            because it is converted with ``str`` first.

    Returns:
        The labels from ``ACT_MAP`` joined with ", " (unknown codes are shown
        as their number). If any token is not an integer, ``str(x)`` is
        returned unchanged.
    """
    # Accept both whitespace- and comma-separated code lists.
    tokens = str(x).strip("[]").replace(",", " ").split()
    try:
        codes = [int(tok) for tok in tokens]
    except ValueError:
        # Only a failed integer parse falls back to the raw text; other
        # errors are real bugs and must not be hidden.
        return str(x)
    return ", ".join(ACT_MAP.get(c, str(c)) for c in codes)
# Human-readable activity labels for every stay point.
sp["act_label"] = sp["act_types"].apply(parse_act_types)

# Agents offered in the dropdown, in ascending id order.
_agent_ids = sp["agent_id"].unique().tolist()
sample_agents = sorted(_agent_ids)
print(f"Ready. {len(sample_agents)} agents loaded.")
# ── Mobility text builders ────────────────────────────────────────────────────
def build_mobility_summary(agent_sp):
    """Build the plain-text mobility overview for one agent.

    Args:
        agent_sp: DataFrame of the agent's stay points. The columns read are
            ``name``, ``duration_min``, ``start_datetime`` and ``end_datetime``
            (the latter two must be datetime columns).

    Returns:
        A multi-line string with the observation period, the five most visited
        locations, the share of stays per time of day and the weekday/weekend
        split. For an empty frame only the header and a zero count are returned.
    """
    if len(agent_sp) == 0:
        # min()/max() would be NaT here and strftime on NaT raises.
        return "\n".join([
            "MOBILITY TRAJECTORY DATA",
            "===========================",
            "Total Stay Points: 0",
        ])

    top5 = (agent_sp.groupby("name")["duration_min"]
            .agg(visits="count", avg_dur="mean")
            .sort_values("visits", ascending=False)
            .head(5))

    first_start = agent_sp["start_datetime"].min()
    last_end = agent_sp["end_datetime"].max()
    obs_start = first_start.strftime("%Y-%m-%d")
    obs_end = last_end.strftime("%Y-%m-%d")
    days = (last_end - first_start).days

    lines = [
        "MOBILITY TRAJECTORY DATA",
        "===========================",
        f"Observation Period: {obs_start} to {obs_end} ({days} days)",
        f"Total Stay Points: {len(agent_sp)}",
        f"Unique Locations: {agent_sp['name'].nunique()}",
        "",
        "LOCATION PATTERNS",
        "----------------",
    ]
    for i, (name, row) in enumerate(top5.iterrows(), 1):
        lines += [f"{i}. {name}",
                  f" Visits: {int(row['visits'])} times",
                  f" Average Duration: {int(row['avg_dur'])} minutes", ""]

    def tod(h):
        """Bucket an hour of day (0-23) into a named part of the day."""
        if 5 <= h < 12:
            return "morning"
        if 12 <= h < 17:
            return "afternoon"
        if 17 <= h < 21:
            return "evening"
        return "night"

    start_hours = agent_sp["start_datetime"].dt.hour
    tod_pct = (start_hours.apply(tod).value_counts(normalize=True) * 100).round(0).astype(int)

    is_weekend = agent_sp["start_datetime"].dt.dayofweek >= 5
    # Round rather than truncate: 0.58 * 100 is 57.999... in floating point,
    # and the time-of-day shares above are rounded as well.
    wd_pct = int(round((~is_weekend).mean() * 100))

    lines += ["TEMPORAL PATTERNS", "----------------", "Activity by Time of Day:"]
    for k, v in tod_pct.items():
        lines.append(f"- {k}: {v}%")
    lines += ["", "Weekday vs Weekend:",
              f"- weekday: {wd_pct}%", f"- weekend: {100 - wd_pct}%"]
    return "\n".join(lines)
def build_weekly_checkin(agent_sp):
    """Build a day-by-day text log of an agent's stays.

    Args:
        agent_sp: DataFrame of stay points with ``start_datetime``,
            ``end_datetime``, ``duration_min``, ``name`` and ``act_label``.

    Returns:
        A multi-line string: a header, then for every calendar day a heading
        (weekday name, date, Weekday/Weekend), the number of activities and one
        bullet per stay.
    """
    report = ["WEEKLY CHECK-IN SUMMARY", "======================="]
    dated = agent_sp.assign(date=agent_sp["start_datetime"].dt.date)

    for day, visits in dated.groupby("date"):
        # The first stay of the day decides the weekday name and day type.
        first_start = visits["start_datetime"].iloc[0]
        day_kind = "Weekday" if first_start.dayofweek < 5 else "Weekend"
        report.append(f"\n--- {first_start.strftime('%A')}, {day} ({day_kind}) ---")
        report.append(f"Total activities: {len(visits)}")

        stays = zip(
            visits["start_datetime"],
            visits["end_datetime"],
            visits["duration_min"],
            visits["name"],
            visits["act_label"],
        )
        report.extend(
            f"- {begin.strftime('%H:%M')}-"
            f"{finish.strftime('%H:%M')} "
            f"({int(minutes)} mins): "
            f"{place} - {activity}"
            for begin, finish, minutes, place, activity in stays
        )
    return "\n".join(report)
# ── Prompts ───────────────────────────────────────────────────────────────────
# System prompts for the three reasoning stages. Each one pins an exact output
# layout because the HTML renderer parses the replies line by line.
# The arrows and dashes below were mis-encoded ("β†’", "β€”") and reached the
# model as garbage; they are restored to the intended characters.

# Stage 1: objective feature extraction, no interpretation.
STEP1_SYSTEM = """You are an expert mobility analyst. Extract objective features from the trajectory data.
Respond with EXACTLY this structure, keep each point to one short sentence:
LOCATION INVENTORY:
- Top venues: [list top 3 with visit counts]
- Price level: [budget/mid-range/high-end mix]
- Neighborhood: [residential/commercial/urban/suburban]
TEMPORAL PATTERNS:
- Active hours: [time range]
- Weekday/Weekend: [ratio]
- Routine: [consistent/variable]
SEQUENCE:
- Typical chain: [e.g. Home → Work → Home]
- Notable pattern: [one observation]
Do NOT interpret or infer demographics. Be concise."""

# Stage 2: behavioral abstraction over the Stage 1 features.
STEP2_SYSTEM = """You are an expert mobility analyst. Based on the extracted features, analyze behavioral patterns.
Respond with EXACTLY this structure, one short sentence per point:
SCHEDULE: [fixed/flexible/shift — one sentence]
ECONOMIC: [budget/mid-range/premium spending — one sentence]
SOCIAL: [family/individual/community focus — one sentence]
LIFESTYLE: [urban professional/suburban/student/other — one sentence]
STABILITY: [routine consistency — one sentence]
Do NOT make income predictions yet. Be concise."""

# Stage 3: final income inference with confidence and alternatives.
STEP3_SYSTEM = """You are an expert mobility analyst performing final income inference.
Based on the trajectory features and behavioral analysis, output EXACTLY:
INCOME_PREDICTION: [Very Low (<$15k) | Low ($15k-$35k) | Middle ($35k-$75k) | Upper-Middle ($75k-$125k) | High ($125k-$200k) | Very High (>$200k)]
INCOME_CONFIDENCE: [1-5]
INCOME_REASONING: [2-3 sentences linking specific mobility evidence to the prediction]
ALTERNATIVES: [2nd most likely] | [3rd most likely]"""
def call_llm(client, system_prompt, user_content, max_tokens=400):
    """Send one system+user exchange to the chat model and return its reply.

    Args:
        client: Object exposing ``chat.completions.create`` (here a
            ``huggingface_hub.InferenceClient``).
        system_prompt: Text for the system message.
        user_content: Text for the user message.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The first choice's message text with surrounding whitespace removed,
        or an empty string when the model returned no text.
    """
    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        max_tokens=max_tokens,
        temperature=0.3,
    )
    # ``content`` can be None (e.g. an empty completion); avoid calling
    # .strip() on None.
    content = response.choices[0].message.content
    return (content or "").strip()
# ── HTML rendering ────────────────────────────────────────────────────────────
# Stylesheet that is prepended to the reasoning-chain HTML on every render.
# It covers the three stage cards (.s1/.s2/.s3, dimmed or active), the tag and
# key/value layouts inside them, the prediction block with its confidence bar,
# the arrow connectors between cards and the pulsing "waiting" dots.
CHAIN_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=IBM+Plex+Sans:wght@300;400;600&display=swap');
.hicotraj-chain {
font-family: 'IBM Plex Sans', sans-serif;
padding: 12px 4px;
max-width: 100%;
}
/* Stage cards */
.stage-card {
border-radius: 10px;
padding: 16px 18px;
margin-bottom: 0;
position: relative;
transition: box-shadow 0.3s;
}
.stage-card.dim { opacity: 0.35; filter: grayscale(0.4); }
.stage-card.active { box-shadow: 0 4px 20px rgba(0,0,0,0.12); opacity: 1; filter: none; }
.stage-card.s1 { background: #f8f9fc; border: 1.5px solid #c8d0e0; }
.stage-card.s2 { background: #fdf6f0; border: 1.5px solid #e8c9a8; }
.stage-card.s3 { background: #fff8f8; border: 2px solid #c0392b; }
.stage-header {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 10px;
}
.stage-badge {
font-family: 'IBM Plex Mono', monospace;
font-size: 10px;
font-weight: 600;
letter-spacing: 0.08em;
padding: 3px 8px;
border-radius: 4px;
text-transform: uppercase;
}
.s1 .stage-badge { background: #dde3f0; color: #3a4a6b; }
.s2 .stage-badge { background: #f0dcc8; color: #7a4010; }
.s3 .stage-badge { background: #c0392b; color: #fff; }
.stage-title {
font-size: 13px;
font-weight: 600;
color: #1a1a2e;
}
/* Content inside cards */
.tag-row { display: flex; flex-wrap: wrap; gap: 6px; margin-top: 4px; }
.tag {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
background: #e8ecf5;
color: #2c3e60;
padding: 3px 8px;
border-radius: 4px;
white-space: nowrap;
}
.s2 .tag { background: #f5e8d8; color: #6b3a10; }
.behavior-row {
display: grid;
grid-template-columns: 100px 1fr;
gap: 4px 10px;
margin-top: 2px;
font-size: 12px;
line-height: 1.5;
}
.bkey {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
font-weight: 600;
color: #9b6a3a;
padding-top: 1px;
}
.bval { color: #3a2a1a; }
/* Prediction block */
.pred-block { margin-top: 8px; }
.pred-label {
font-size: 11px;
font-family: 'IBM Plex Mono', monospace;
color: #888;
text-transform: uppercase;
letter-spacing: 0.06em;
margin-bottom: 4px;
}
.pred-value {
font-size: 22px;
font-weight: 600;
color: #c0392b;
letter-spacing: -0.01em;
margin-bottom: 8px;
}
.confidence-bar-wrap {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 10px;
}
.confidence-bar-bg {
flex: 1;
height: 6px;
background: #f0d0cf;
border-radius: 3px;
overflow: hidden;
}
.confidence-bar-fill {
height: 100%;
background: linear-gradient(90deg, #e74c3c, #8b0000);
border-radius: 3px;
transition: width 0.8s ease;
}
.confidence-label {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
color: #c0392b;
font-weight: 600;
white-space: nowrap;
}
.reasoning-text {
font-size: 12px;
color: #4a2a2a;
line-height: 1.6;
border-left: 3px solid #e8c0be;
padding-left: 10px;
margin-top: 6px;
}
.alternatives {
margin-top: 10px;
font-size: 11px;
font-family: 'IBM Plex Mono', monospace;
color: #999;
}
.alternatives span { color: #c0392b; opacity: 0.7; }
/* Arrow connector */
.chain-arrow {
display: flex;
flex-direction: column;
align-items: center;
margin: 0;
padding: 4px 0;
gap: 0;
}
.arrow-line {
width: 2px;
height: 18px;
background: linear-gradient(180deg, #c8d0e0, #e8c9a8);
}
.arrow-label {
font-family: 'IBM Plex Mono', monospace;
font-size: 10px;
color: #aaa;
letter-spacing: 0.06em;
text-transform: uppercase;
background: white;
padding: 2px 8px;
border: 1px solid #e0e0e0;
border-radius: 10px;
margin: 2px 0;
}
.arrow-tip {
width: 0; height: 0;
border-left: 5px solid transparent;
border-right: 5px solid transparent;
border-top: 7px solid #e8c9a8;
}
/* Waiting state */
.waiting-dot {
display: inline-block;
width: 7px; height: 7px;
border-radius: 50%;
background: #ccc;
margin: 0 2px;
animation: pulse 1.2s ease-in-out infinite;
}
.waiting-dot:nth-child(2) { animation-delay: 0.2s; }
.waiting-dot:nth-child(3) { animation-delay: 0.4s; }
@keyframes pulse {
0%, 100% { opacity: 0.3; transform: scale(0.8); }
50% { opacity: 1; transform: scale(1.1); }
}
</style>
"""
def _waiting_dots():
return '<span class="waiting-dot"></span><span class="waiting-dot"></span><span class="waiting-dot"></span>'
def render_chain(s1_text="", s2_text="", s3_text="", status="idle"):
    """Render the three-stage reasoning chain as one HTML fragment.

    Args:
        s1_text: Raw Stage 1 reply; up to 8 short lines are shown as tags.
        s2_text: Raw Stage 2 reply; the SCHEDULE/ECONOMIC/SOCIAL/LIFESTYLE/
            STABILITY lines are shown as a key/value grid.
        s3_text: Raw Stage 3 reply (or a notice/error message); parsed for the
            INCOME_PREDICTION / INCOME_CONFIDENCE / INCOME_REASONING /
            ALTERNATIVES lines.
        status: One of "idle", "running1", "running2", "running3", "done".
            A "runningN" status shows a waiting indicator in stage N; stages
            not yet reached are dimmed.

    Returns:
        ``CHAIN_CSS`` followed by the chain markup, as a single string.
    """
    # Model replies and exception messages are arbitrary text: escape them
    # before interpolating into HTML so they cannot break or inject markup.
    from html import escape

    placeholder = '<div style="font-size:12px;color:#bbb;padding:6px 0;">Run inference to see results</div>'

    s1_active = status in ("running1", "running2", "running3", "done")
    s2_active = status in ("running2", "running3", "done")
    s3_active = status in ("running3", "done")

    # ── Stage 1 content ──────────────────────────────────────────────────────
    if status == "running1":
        s1_content = f'<div style="padding:8px 0; color:#888; font-size:13px;">Extracting features {_waiting_dots()}</div>'
    elif s1_text:
        # Parse tags from the response — pull out short bullet points as tags.
        tags = []
        for line in s1_text.splitlines():
            line = line.strip().lstrip("-").strip()
            # Skip blanks, long sentences and section headings ("...:").
            if line and len(line) < 60 and not line.endswith(":"):
                tags.append(line)
                if len(tags) >= 8:
                    break
        tag_html = "".join(f'<span class="tag">{escape(t)}</span>' for t in tags)
        s1_content = f'<div class="tag-row">{tag_html}</div>'
    else:
        s1_content = placeholder

    # ── Stage 2 content ──────────────────────────────────────────────────────
    BEHAVIOR_KEYS = ["SCHEDULE", "ECONOMIC", "SOCIAL", "LIFESTYLE", "STABILITY"]
    if status == "running2":
        s2_content = f'<div style="padding:8px 0; color:#a06030; font-size:13px;">Analyzing behavior {_waiting_dots()}</div>'
    elif s2_text:
        rows = []
        for key in BEHAVIOR_KEYS:
            m = re.search(rf"{key}[:\s]+(.+)", s2_text, re.IGNORECASE)
            val = m.group(1).strip().rstrip(".") if m else "—"
            # Truncate before escaping so an HTML entity is never cut in half.
            if len(val) > 80:
                val = val[:77] + "..."
            rows.append(f'<div class="bkey">{key}</div><div class="bval">{escape(val)}</div>')
        s2_content = f'<div class="behavior-row">{"".join(rows)}</div>'
    else:
        s2_content = placeholder

    # ── Stage 3 content ──────────────────────────────────────────────────────
    if status == "running3":
        s3_content = f'<div style="padding:8px 0; color:#c0392b; font-size:13px;">Inferring demographics {_waiting_dots()}</div>'
    elif s3_text:
        # Parse the structured "LABEL: value" lines; a later line wins.
        fields = {
            "INCOME_PREDICTION": "",
            "INCOME_CONFIDENCE": "",
            "INCOME_REASONING": "",
            "ALTERNATIVES": "",
        }
        for line in s3_text.splitlines():
            label, sep, rest = line.strip().partition(":")
            if sep and label in fields:
                fields[label] = rest.strip()
        pred = fields["INCOME_PREDICTION"]
        reasoning = fields["INCOME_REASONING"]
        alts = fields["ALTERNATIVES"]

        # Confidence bar: first digit found, default 3, kept inside the 1-5
        # scale so the bar width never exceeds 100%.
        digit = re.search(r"\d", fields["INCOME_CONFIDENCE"])
        conf_int = int(digit.group()) if digit else 3
        conf_int = min(max(conf_int, 1), 5)
        bar_pct = conf_int * 20

        alts_html = ""
        if alts:
            alts_html = f'<div class="alternatives">Also possible: <span>{escape(alts)}</span></div>'
        s3_content = f"""
<div class="pred-block">
<div class="pred-label">Income Prediction</div>
<div class="pred-value">{escape(pred) if pred else "—"}</div>
<div class="confidence-bar-wrap">
<div class="confidence-bar-bg">
<div class="confidence-bar-fill" style="width:{bar_pct}%"></div>
</div>
<div class="confidence-label">Confidence {conf_int}/5</div>
</div>
<div class="reasoning-text">{escape(reasoning or s3_text[:200])}</div>
{alts_html}
</div>"""
    else:
        s3_content = placeholder

    def card(cls, badge, title, content, active):
        """Wrap already-rendered stage content in its card frame."""
        dim_cls = "active" if active else "dim"
        return f"""
<div class="stage-card {cls} {dim_cls}">
<div class="stage-header">
<span class="stage-badge">{badge}</span>
<span class="stage-title">{title}</span>
</div>
{content}
</div>"""

    def arrow(label, active):
        """Render the labelled connector between two cards."""
        opacity = "1" if active else "0.3"
        return f"""
<div class="chain-arrow" style="opacity:{opacity}">
<div class="arrow-line"></div>
<div class="arrow-label">{label}</div>
<div class="arrow-line"></div>
<div class="arrow-tip"></div>
</div>"""

    pieces = [
        CHAIN_CSS,
        '<div class="hicotraj-chain">',
        card("s1", "Stage 1", "Factual Feature Extraction", s1_content, s1_active),
        arrow("behavioral abstraction", s2_active),
        card("s2", "Stage 2", "Behavioral Pattern Analysis", s2_content, s2_active),
        arrow("demographic inference", s3_active),
        card("s3", "Stage 3", "Demographic Inference", s3_content, s3_active),
        "</div>",
    ]
    return "".join(pieces)
# ── Map & demo ────────────────────────────────────────────────────────────────
def build_map(agent_sp):
    """Draw an agent's stay points on a folium map and return it as HTML.

    Args:
        agent_sp: DataFrame of stay points with ``latitude``, ``longitude``,
            ``name``, ``start_datetime``, ``duration_min`` and ``act_label``.
            Row order is the drawing order.

    Returns:
        The map's HTML representation (``_repr_html_``), sized 100% x 420px.
    """
    pts = agent_sp.reset_index(drop=True).copy()
    total = len(pts)

    # Small random offsets keep markers at the same place from fully overlapping.
    pts["latitude"] += np.random.uniform(-0.0003, 0.0003, total)
    pts["longitude"] += np.random.uniform(-0.0003, 0.0003, total)

    center = [pts["latitude"].mean(), pts["longitude"].mean()]
    m = folium.Map(location=center, zoom_start=12, tiles="CartoDB positron")

    path = list(zip(pts["latitude"], pts["longitude"]))
    if len(path) > 1:
        folium.PolyLine(path, color="#cc000055", weight=1.5, opacity=0.4).add_to(m)

    def marker_color(position):
        """Hex color for the marker at ``position``; later stays are darker red."""
        ratio = position / max(total - 1, 1)
        red = int(255 - ratio * (255 - 139))
        green = int(204 * (1 - ratio) ** 2)
        blue = 0
        return f"#{red:02x}{green:02x}{blue:02x}"

    for i, row in pts.iterrows():
        shade = marker_color(i)
        popup_html = (
            f"<b>#{i+1} {row['name']}</b><br>"
            f"{row['start_datetime'].strftime('%a %m/%d %H:%M')}<br>"
            f"{int(row['duration_min'])} min<br>{row['act_label']}"
        )
        marker = folium.CircleMarker(
            location=[row["latitude"], row["longitude"]],
            radius=7,
            color=shade,
            fill=True,
            fill_color=shade,
            fill_opacity=0.9,
            popup=folium.Popup(popup_html, max_width=220),
        )
        marker.add_to(m)

    m.get_root().width = "100%"
    m.get_root().height = "420px"
    return m._repr_html_()
def build_demo_text(row):
    """Format one demographics record as a single " | "-separated line.

    Args:
        row: Mapping/Series with ``age``, ``sex``, ``race``, ``education`` and
            ``hh_income``; the last four are integer survey codes.

    Returns:
        A string such as ``"Age: 34 | Sex: Female | Race: ... | Income: ..."``.
        Codes missing from the lookup tables are shown as-is; values that are
        missing or not numeric (e.g. NaN from an empty CSV cell) are shown as
        "Unknown" instead of raising.
    """
    def label(mapping, field):
        """Translate ``row[field]`` through ``mapping`` without raising on bad values."""
        raw = row[field]
        try:
            code = int(raw)
        except (TypeError, ValueError, OverflowError):
            return "Unknown"
        return mapping.get(code, raw)

    age_raw = row["age"]
    try:
        # Non-positive ages are the survey's "no answer" codes.
        age = int(age_raw) if age_raw > 0 else "Unknown"
    except (TypeError, ValueError, OverflowError):
        age = "Unknown"

    return (
        f"Age: {age} | "
        f"Sex: {label(SEX_MAP, 'sex')} | "
        f"Race: {label(RACE_MAP, 'race')} | "
        f"Education: {label(EDU_MAP, 'education')} | "
        f"Income: {label(INC_MAP, 'hh_income')}"
    )
# ── Callbacks ─────────────────────────────────────────────────────────────────
def on_select(agent_id):
    """Refresh every panel for the agent chosen in the dropdown.

    Args:
        agent_id: Agent identifier as delivered by the dropdown (converted
            with ``int``).

    Returns:
        Tuple ``(map_html, raw_text, demo_text, chain_html)``: the trajectory
        map, the mobility summary plus weekly check-in text, the ground-truth
        demographics line and an idle reasoning chain.
    """
    agent_key = int(agent_id)
    stays = sp[sp["agent_id"] == agent_key].sort_values("start_datetime")
    truth = demo[demo["agent_id"] == agent_key].iloc[0]

    map_html = build_map(stays)
    demo_text = build_demo_text(truth)
    raw_text = "\n\n".join([build_mobility_summary(stays), build_weekly_checkin(stays)])
    chain_html = render_chain(status="idle")
    return map_html, raw_text, demo_text, chain_html
def run_inference(agent_id, hf_token):
    """Run the three reasoning stages for one agent, yielding UI updates.

    This is a generator: it yields the reasoning-chain HTML after every state
    change so the page shows progress while the model is working.

    Args:
        agent_id: Agent identifier from the dropdown (converted with ``int``).
        hf_token: Hugging Face API token typed by the user.

    Yields:
        HTML from ``render_chain``: a notice when the token is blank, the
        running/finished states of each stage, or an error message if any
        model call fails.
    """
    token = (hf_token or "").strip()
    if not token:
        yield render_chain(s3_text="⚠️ Please enter your Hugging Face token first.", status="done")
        return

    agent_key = int(agent_id)
    stays = sp[sp["agent_id"] == agent_key].sort_values("start_datetime")
    traj_text = "\n\n".join([build_mobility_summary(stays), build_weekly_checkin(stays)])

    try:
        client = InferenceClient(token=token)

        # Stage 1: factual features from the raw trajectory text.
        yield render_chain(status="running1")
        s1 = call_llm(client, STEP1_SYSTEM, traj_text, max_tokens=400)

        # Stage 2: behavioral analysis built on the Stage 1 features.
        yield render_chain(s1_text=s1, status="running2")
        s2_input = f"Features:\n{s1}\n\nNow analyze behavioral patterns."
        s2 = call_llm(client, STEP2_SYSTEM, s2_input, max_tokens=300)

        # Stage 3: income inference from both earlier outputs.
        yield render_chain(s1_text=s1, s2_text=s2, status="running3")
        s3_input = f"Features:\n{s1}\n\nBehavioral analysis:\n{s2}\n\nNow infer income."
        s3 = call_llm(client, STEP3_SYSTEM, s3_input, max_tokens=300)

        yield render_chain(s1_text=s1, s2_text=s2, s3_text=s3, status="done")
    except Exception as e:
        # UI boundary: surface any failure in the chain panel instead of crashing.
        yield render_chain(s3_text=f"❌ Error: {str(e)}", status="done")
# NOTE: a second, verbatim copy of call_llm() used to be defined here. It only
# shadowed the definition next to the prompt constants earlier in this file,
# so it was removed; run_inference() resolves call_llm at call time and keeps
# using that earlier definition.
# ── UI ────────────────────────────────────────────────────────────────────────
# Page layout and event wiring. The heading dash and the button's play symbol
# were mis-encoded ("β€”", "β–Ά") and are restored to the intended characters.
with gr.Blocks(title="HiCoTraj Demo", theme=gr.themes.Soft()) as app:
    gr.Markdown("## HiCoTraj — Trajectory Visualization & Hierarchical CoT Demo")
    gr.Markdown("*Zero-Shot Demographic Reasoning via Hierarchical Chain-of-Thought Prompting from Trajectory*")

    with gr.Row():
        hf_token_box = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_...",
            type="password",
            scale=2
        )

    with gr.Row():
        agent_dd = gr.Dropdown(
            choices=[str(a) for a in sample_agents],
            label="Select Agent",
            value=str(sample_agents[0]),
            scale=1
        )
        demo_label = gr.Textbox(
            label="Ground Truth Demographics",
            interactive=False,
            scale=4
        )

    with gr.Row():
        # LEFT: map + NUMOSIM data
        with gr.Column(scale=1):
            gr.Markdown("### Trajectory Map")
            map_out = gr.HTML()
            gr.Markdown("### NUMOSIM Raw Data")
            raw_out = gr.Textbox(
                lines=25, interactive=False,
                label="Mobility Summary + Weekly Check-in"
            )
        # RIGHT: reasoning chain
        with gr.Column(scale=1):
            gr.Markdown("### Hierarchical Chain-of-Thought Reasoning")
            run_btn = gr.Button("▶ Run HiCoTraj Inference", variant="primary")
            chain_out = gr.HTML(value=render_chain(status="idle"))

    # Changing the agent (and the initial page load) refreshes every panel.
    agent_dd.change(
        fn=on_select, inputs=agent_dd,
        outputs=[map_out, raw_out, demo_label, chain_out]
    )
    app.load(
        fn=on_select, inputs=agent_dd,
        outputs=[map_out, raw_out, demo_label, chain_out]
    )
    # run_inference is a generator, so the chain panel updates after each stage.
    run_btn.click(
        fn=run_inference,
        inputs=[agent_dd, hf_token_box],
        outputs=[chain_out]
    )

if __name__ == "__main__":
    app.launch()