import gradio as gr
import pandas as pd
import folium
import numpy as np
import os
import re
import json
# Resolve the project root: next to this file when run as a script, the
# current working directory when run interactively (no __file__).
BASE = os.path.dirname(os.path.abspath(__file__)) if "__file__" in dir() else os.getcwd()
# Input data files: stay points, POI metadata, demographics, and the LLM
# chain-of-thought inference results.
STAY_POINTS = os.path.join(BASE, "data", "stay_points_inference_sample.csv")
POI_PATH = os.path.join(BASE, "data", "poi_inference_sample.csv")
DEMO_PATH = os.path.join(BASE, "data", "demographics_inference_sample.csv")
COT_PATH = os.path.join(BASE, "data", "inference_results_sample.json")
# Survey code -> human-readable label lookups. Negative codes are
# non-response sentinels (presumably NHTS-style survey coding — TODO
# confirm against the data dictionary).
SEX_MAP = {1:"Male", 2:"Female", -8:"Unknown", -7:"Prefer not to answer"}
EDU_MAP = {1:"Less than HS", 2:"HS Graduate/GED", 3:"Some College/Associate",
4:"Bachelor's Degree", 5:"Graduate/Professional Degree",
-1:"N/A", -7:"Prefer not to answer", -8:"Unknown"}
INC_MAP = {1:"<$10,000", 2:"$10,000–$14,999", 3:"$15,000–$24,999",
4:"$25,000–$34,999", 5:"$35,000–$49,999", 6:"$50,000–$74,999",
7:"$75,000–$99,999", 8:"$100,000–$124,999", 9:"$125,000–$149,999",
10:"$150,000–$199,999", 11:"$200,000+",
-7:"Prefer not to answer", -8:"Unknown", -9:"Not ascertained"}
RACE_MAP = {1:"White", 2:"Black or African American", 3:"Asian",
4:"American Indian or Alaska Native",
5:"Native Hawaiian or Other Pacific Islander",
6:"Multiple races", 97:"Other",
-7:"Prefer not to answer", -8:"Unknown"}
# Stay-point activity-type codes used by parse_act_types().
ACT_MAP = {0:"Transportation", 1:"Home", 2:"Work", 3:"School", 4:"ChildCare",
5:"BuyGoods", 6:"Services", 7:"EatOut", 8:"Errands", 9:"Recreation",
10:"Exercise", 11:"Visit", 12:"HealthCare", 13:"Religious",
14:"SomethingElse", 15:"DropOff"}
print("Loading data...")
# Load stay points, POI metadata, and demographics, then denormalize POI
# attributes (e.g. name) onto each stay point via its poi_id.
sp = pd.read_csv(STAY_POINTS)
poi = pd.read_csv(POI_PATH)
demo = pd.read_csv(DEMO_PATH)
sp = sp.merge(poi, on="poi_id", how="left")
# Parse timestamps as timezone-aware UTC and derive each stay's duration
# in minutes (rounded to 0.1).
sp["start_datetime"] = pd.to_datetime(sp["start_datetime"], utc=True)
sp["end_datetime"] = pd.to_datetime(sp["end_datetime"], utc=True)
sp["duration_min"] = ((sp["end_datetime"] - sp["start_datetime"]).dt.total_seconds() / 60).round(1)
def parse_act_types(x, act_map=None):
    """Translate an activity-code cell like "[1 2]" into readable labels.

    Args:
        x: raw cell value — typically a string of space-separated integer
            codes wrapped in square brackets (a numpy-array repr).
        act_map: optional code->label mapping; defaults to the module-level
            ACT_MAP. Unknown codes fall back to their numeric string.

    Returns:
        Comma-separated label string, or ``str(x)`` unchanged when the value
        cannot be parsed as integer codes (e.g. NaN or free text).
    """
    mapping = ACT_MAP if act_map is None else act_map
    try:
        codes = [int(tok) for tok in str(x).strip("[]").split()]
    except ValueError:
        # Non-numeric content: pass the raw value through unchanged rather
        # than crashing the DataFrame .apply() (was a bare `except:`).
        return str(x)
    return ", ".join(mapping.get(code, str(code)) for code in codes)
# Human-readable activity labels for each stay point.
sp["act_label"] = sp["act_types"].apply(parse_act_types)
# Load CoT JSON (optional)
# Chain-of-thought inference results keyed by integer agent_id; empty when
# the file is absent.
cot_by_agent = {}
if os.path.exists(COT_PATH):
    with open(COT_PATH, "r") as f:
        cot_raw = json.load(f)
    # Support both list and {"inference_results": [...]} formats
    records = cot_raw if isinstance(cot_raw, list) else cot_raw.get("inference_results", [])
    for result in records:
        cot_by_agent[int(result["agent_id"])] = result
# NOTE(review): original indentation was lost in this copy; this print may
# have been inside the `if` above — confirm against the source repo.
print(f"Loaded CoT for {len(cot_by_agent)} agents.")
sample_agents = sorted(sp["agent_id"].unique().tolist())
print(f"Ready. {len(sample_agents)} agents loaded.")
def get_cot(agent_id):
    """Look up the cached chain-of-thought responses for one agent.

    Args:
        agent_id: agent identifier (anything int()-coercible).

    Returns:
        Tuple of (step1, step2, step3) response strings; empty strings when
        the agent has no cached inference result.
    """
    record = cot_by_agent.get(int(agent_id), {})
    return (
        record.get("step1_response", ""),
        record.get("step2_response", ""),
        record.get("step3_response", ""),
    )
# ── Mobility text builders ────────────────────────────────────────────────────
def build_mobility_summary(agent_sp):
    """Build a plain-text mobility profile for one agent's stay points.

    Summarizes the observation window, stay/location counts, the
    weekday/weekend split, peak time of day, top activities, and the five
    most-visited locations.

    Args:
        agent_sp: stay-point DataFrame for a single agent (requires
            start_datetime, end_datetime, duration_min, name, act_label).

    Returns:
        Multi-line summary string.
    """
    # Observation window.
    first_seen = agent_sp["start_datetime"].min()
    last_seen = agent_sp["end_datetime"].max()
    span_days = (last_seen - first_seen).days

    # Five most-visited places with visit count and mean stay length.
    visit_stats = (
        agent_sp.groupby("name")["duration_min"]
        .agg(visits="count", avg_dur="mean")
        .sort_values("visits", ascending=False)
        .head(5)
    )

    # Three most frequent activity labels, rendered "Label (count)".
    freq = agent_sp["act_label"].value_counts().head(3)
    top_acts = ", ".join(f"{label} ({count})" for label, count in freq.items())

    # Bucket start hours into coarse time-of-day bands.
    def bucket(hour):
        if 5 <= hour < 12:
            return "Morning"
        if 12 <= hour < 17:
            return "Afternoon"
        if 17 <= hour < 21:
            return "Evening"
        return "Night"

    work = agent_sp.copy()
    work["hour"] = work["start_datetime"].dt.hour
    work["tod"] = work["hour"].apply(bucket)
    peak_tod = work["tod"].value_counts().idxmax()

    # Share of stays beginning on a weekday (Mon-Fri).
    work["is_weekend"] = work["start_datetime"].dt.dayofweek >= 5
    wd_pct = int((~work["is_weekend"]).mean() * 100)

    out = [
        f"Period: {first_seen.strftime('%Y-%m-%d')} ~ {last_seen.strftime('%Y-%m-%d')} ({span_days} days)",
        f"Stay points: {len(agent_sp)} | Unique locations: {agent_sp['name'].nunique()}",
        f"Weekday/Weekend: {wd_pct}% / {100-wd_pct}% | Peak time: {peak_tod}",
        f"Top activities: {top_acts}",
        "",
        "Top Locations:",
    ]
    for rank, (place, row) in enumerate(visit_stats.iterrows(), 1):
        out.append(f" {rank}. {place} — {int(row['visits'])} visits, avg {int(row['avg_dur'])} min")
    return "\n".join(out)
def build_weekly_checkin(agent_sp, max_days=None):
    """Render a day-by-day check-in log of an agent's stay points.

    Args:
        agent_sp: stay-point DataFrame for one agent (requires
            start_datetime, end_datetime, duration_min, name, act_label).
        max_days: if set, only the first `max_days` dates are rendered and a
            trailing "... (N more days)" line notes the truncation.

    Returns:
        Multi-line text report with one section per calendar date.
    """
    frame = agent_sp.copy()
    frame["date"] = frame["start_datetime"].dt.date
    ordered_dates = sorted(frame["date"].unique())
    total_days = len(ordered_dates)
    shown = ordered_dates[:max_days] if max_days else ordered_dates

    report = ["WEEKLY CHECK-IN SUMMARY", "======================="]
    for day in shown:
        day_rows = frame[frame["date"] == day]
        first_ts = day_rows["start_datetime"].iloc[0]
        day_kind = "Weekend" if first_ts.dayofweek >= 5 else "Weekday"
        report.append(f"\n--- {first_ts.strftime('%A')}, {day} ({day_kind}) ---")
        report.append(f"Total activities: {len(day_rows)}")
        for _, visit in day_rows.iterrows():
            report.append(
                f"- {visit['start_datetime'].strftime('%H:%M')}-"
                f"{visit['end_datetime'].strftime('%H:%M')} "
                f"({int(visit['duration_min'])} mins): "
                f"{visit['name']} - {visit['act_label']}"
            )
    if max_days and total_days > max_days:
        report.append(f"\n... ({total_days - max_days} more days)")
    return "\n".join(report)
# ── HTML reasoning chain ──────────────────────────────────────────────────────
# ── Paste this entire block into app.py, replacing the existing CHAIN_CSS, render_chain, and helper functions ──
# NOTE(review): `re` is already imported at the top of the file; this repeat
# import is harmless but redundant (left in place per review policy).
import re
# Stylesheet prepended to the reasoning-chain HTML by render_chain().
# NOTE(review): the CSS content appears to have been stripped from this copy
# of the file — the literal is empty; restore it from the source repo.
CHAIN_CSS = """
"""
def _loading(msg):
# Returns the "in progress" placeholder HTML shown while a stage is running.
# NOTE(review): the HTML tags of this f-string appear to have been stripped
# from this copy — a single-quoted f-string cannot span lines, so as shown
# this is a syntax error. Restore the original markup from the source repo.
return f'
{msg}
'
def _parse_s1(text):
"""Returns (locations, tod, wk, dist)"""
locations = []
dur_map = {}
tod = {}
wk = {}
dist = None
for line in text.splitlines():
s = line.strip()
# Location inventory: "- Name: N visits, ..."
m = re.match(r'-\s+(.+?):\s+(\d+)\s+visit', s, re.IGNORECASE)
if m:
locations.append((m.group(1).strip(), int(m.group(2))))
# Duration: "- LocationName: Average duration of X minutes"
m2 = re.match(r'-?\s*(.+?):\s+Average duration of ([\d.]+)\s+min', s, re.IGNORECASE)
if m2:
dur_map[m2.group(1).strip()] = float(m2.group(2))
# TOD: "65% morning, 23% afternoon, 6% evening, 5% night"
if not tod:
m3 = re.search(r'(\d+)%\s*morning.*?(\d+)%\s*afternoon.*?(\d+)%\s*evening.*?(\d+)%\s*night', s, re.IGNORECASE)
if m3:
tod = {'Morning': int(m3.group(1)), 'Afternoon': int(m3.group(2)),
'Evening': int(m3.group(3)), 'Night': int(m3.group(4))}
# Weekday/weekend
if not wk:
m4 = re.search(r'(\d+)%\s*weekday.*?(\d+)%\s*weekend', s, re.IGNORECASE)
if m4:
wk = {'Weekday': int(m4.group(1)), 'Weekend': int(m4.group(2))}
# Distance
if not dist:
m5 = re.search(r'average distance of approximately ([\d.]+)\s*miles', s, re.IGNORECASE)
if m5:
dist = float(m5.group(1))
result_locs = [(n, v, dur_map.get(n)) for n, v in locations[:7]]
return result_locs, tod, wk, dist
def _parse_s2(text):
"""Returns dict: ROUTINE, ECONOMIC, SOCIAL, URBAN, STABILITY → short summary string"""
DIMS = {
'ROUTINE': ['ROUTINE', 'SCHEDULE'],
'ECONOMIC': ['ECONOMIC', 'SPENDING'],
'SOCIAL': ['SOCIAL', 'LIFESTYLE'],
'URBAN': ['URBAN', 'COMMUNITY'],
'STABILITY': ['STABILITY', 'REGULARITY', 'CONSISTENCY'],
}
sections = {}
current_key = None
current_lines = []
for line in text.splitlines():
s = line.strip()
# Format A: "1. TITLE ANALYSIS:" or "2. ECONOMIC BEHAVIOR PATTERNS:"
mA = re.match(r'^\d+\.\s+([A-Z][A-Z\s&]+?)(?:\s+ANALYSIS|\s+PATTERNS|\s+INDICATORS|\s+CHARACTERISTICS|\s+STABILITY)?:\s*$', s, re.IGNORECASE)
# Format B: "STEP 1: ROUTINE & SCHEDULE ANALYSIS"
mB = re.match(r'^STEP\s+\d+:\s+([A-Z][A-Z\s&]+?)(?:\s+ANALYSIS|\s+PATTERNS|\s+INDICATORS|\s+CHARACTERISTICS|\s+STABILITY)?\s*$', s, re.IGNORECASE)
mm = mA or mB
if mm:
if current_key and current_lines:
sections[current_key] = ' '.join(current_lines)
current_key = mm.group(1).upper().strip()
current_lines = []
elif current_key and s:
if re.match(r'^\d+\.\d+', s):
sub = re.sub(r'^\d+\.\d+[^:]*:\s*', '', s)
if sub:
current_lines.append(sub)
elif s.startswith('-'):
current_lines.append(s.lstrip('-').strip())
elif not re.match(r'^\d+\.', s):
current_lines.append(s)
if current_key and current_lines:
sections[current_key] = ' '.join(current_lines)
result = {}
for dim, keywords in DIMS.items():
for k, txt in sections.items():
if any(kw in k for kw in keywords) and txt:
sents = re.split(r'(?<=[.!?])\s+', txt.strip())
summary = ' '.join(sents[:2])
if len(summary) > 160:
summary = summary[:157] + '…'
result[dim] = summary
break
return result
def _parse_s3(text):
pred, conf, reasoning = '', 0, ''
in_r = False
r_lines = []
for line in text.splitlines():
s = line.strip()
if s.startswith('INCOME_PREDICTION:'):
pred = s.replace('INCOME_PREDICTION:', '').strip()
elif s.startswith('INCOME_CONFIDENCE:'):
try:
conf = int(re.search(r'\d+', s).group())
except:
conf = 0
elif s.startswith('INCOME_REASONING:'):
in_r = True
r_lines.append(s.replace('INCOME_REASONING:', '').strip())
elif in_r:
if re.match(r'^2\.', s) or s.startswith('INCOME_'):
break
if s:
r_lines.append(s)
reasoning = ' '.join(r_lines).strip()
sents = re.split(r'(?<=[.!?])\s+', reasoning)
reasoning = ' '.join(sents[:3])
if len(reasoning) > 280:
reasoning = reasoning[:277] + '…'
return pred, conf, reasoning
def _s1_body(text, active):
# Renders the Stage-1 card body: a location table with visit bars, the
# time-of-day and weekday/weekend split bars, and an average-distance line,
# all built from the parsed step-1 response via _parse_s1().
# NOTE(review): the HTML tags inside every f-string below appear to have
# been stripped from this copy — each quoted fragment opens on one line and
# closes several lines later, which is a syntax error. Restore the original
# markup from the source repo; the control flow is left byte-identical.
if not active:
# Stage not yet reached: static placeholder.
return '
Press ▶ to start
'
if not text:
# Stage running but no response yet: animated loading placeholder.
return _loading('Extracting features')
locs, tod, wk, dist = _parse_s1(text)
# Location table
# Bar widths are scaled to the most-visited location (max 60 units).
max_v = max((v for _, v, _ in locs), default=1)
rows = ''
for name, visits, dur in locs:
bar_w = int(60 * visits / max_v)
dur_str = f'{int(dur)}m' if dur else '—'
rows += (
f'
'
f'
{name}
'
f'
'
f''
f'{visits}
'
f'
{dur_str}
'
f'
'
)
table = (
f'
'
f'
Location
Visits
Avg Stay
'
f'{rows}'
f'
'
) if rows else ''
# Temporal panels
# seg_bar: stacked percentage bar plus legend for a {label: pct} dict.
def seg_bar(data, seg_classes):
total = sum(data.values()) or 1
segs = ''.join(
f''
for (label, v), cls in zip(data.items(), seg_classes)
)
legend = ''.join(
f'
{label} {v}%
'
for (label, v), cls in zip(data.items(), seg_classes)
)
return f'
{segs}
{legend}
'
tod_panel = ''
if tod:
tod_panel = (
f'
'
f'
Time of Day
'
f'{seg_bar(tod, ["seg-morning","seg-afternoon","seg-evening","seg-night"])}'
f'
'
)
wk_panel = ''
if wk:
wk_panel = (
f'
'
f'
Weekday / Weekend
'
f'{seg_bar(wk, ["seg-weekday","seg-weekend"])}'
f'
'
)
temporal = f'
{tod_panel}{wk_panel}
' if (tod_panel or wk_panel) else ''
dist_line = ''
if dist:
dist_line = f'
📍 Avg trip distance {dist} mi
'
return table + temporal + dist_line
def _s2_body(text, active):
# Renders the Stage-2 card body: one small card per behavioral dimension
# (schedule / economic / social / stability) from _parse_s2()'s summaries.
# NOTE(review): the HTML tags inside the f-strings below appear to have
# been stripped from this copy (unterminated quoted fragments spanning
# lines). Restore the original markup; control flow left byte-identical.
if not active:
# Stage not yet reached: static placeholder.
return '
Waiting…
'
if not text:
# Stage running but no response yet: animated loading placeholder.
return _loading('Analyzing behavior')
dims = _parse_s2(text)
# (dimension key, icon, display label) for the four rendered cards.
DIM_META = [
('ROUTINE', '🕐', 'Schedule'),
('ECONOMIC', '💰', 'Economic'),
('SOCIAL', '👥', 'Social'),
('STABILITY', '🔄', 'Stability'),
]
# fallback to URBAN if STABILITY missing
if 'STABILITY' not in dims and 'URBAN' in dims:
dims['STABILITY'] = dims['URBAN']
cards = ''
for key, icon, label in DIM_META:
txt = dims.get(key, '')
# Em-dash placeholder when the dimension was not found in the response.
content = f'
{txt}
' if txt else '
—
'
cards += (
f'
'
f'
'
f'{icon}'
f'{label}'
f'
'
f'{content}'
f'
'
)
return f'
{cards}
'
def _s3_body(text, active):
# Renders the Stage-3 card body: predicted income bracket, a confidence
# gauge (conf out of 5 shown as a percentage), and the reasoning excerpt,
# all from _parse_s3().
# NOTE(review): the HTML tags inside the f-strings below appear to have
# been stripped from this copy (unterminated quoted fragments spanning
# lines). Restore the original markup; control flow left byte-identical.
if not active:
# Stage not yet reached: static placeholder.
return '
Waiting…
'
if not text:
# Stage running but no response yet: animated loading placeholder.
return _loading('Inferring demographics')
pred, conf, reasoning = _parse_s3(text)
# Confidence is 0-5; convert to a 0-100% bar width.
conf_pct = int(conf / 5 * 100)
return (
f'
'
f'
'
f'
{pred or "—"}
'
f'
Income
'
f'
'
f'
'
f'
Confidence {conf}/5
'
f'
'
f'
'
f'
'
f'
{reasoning}
'
)
def render_chain(s1_text, s2_text, s3_text, status="idle"):
# Assembles the full three-stage reasoning-chain HTML (CHAIN_CSS + stage
# cards + connecting arrows). `status` gates which stages are lit:
# "idle" -> none, "running1"/"running2"/"running3" -> stages up to that
# point, "done" -> all three.
# NOTE(review): the HTML tags inside the stage()/arrow() f-strings and the
# wrapper literals appear to have been stripped from this copy. Restore the
# original markup; the logic is left byte-identical.
s1_on = status in ("running1", "running2", "running3", "done")
s2_on = status in ("running2", "running3", "done")
s3_on = status in ("running3", "done")
# For "running" states the text may be empty → show loading dots
s1_body = _s1_body(s1_text if s1_on else '', s1_on)
s2_body = _s2_body(s2_text if s2_on else '', s2_on)
s3_body = _s3_body(s3_text if s3_on else '', s3_on)
# stage: wraps one stage card (number badge, title, rendered body).
def stage(cls, num, title, body, on):
dim_cls = 'active' if on else 'dim'
return (
f'
'
f'
'
f'{num}'
f'{title}'
f'
'
f'
{body}
'
f'
'
)
# arrow: the labeled connector between consecutive stages; faded when the
# downstream stage is not yet active.
def arrow(label, on):
op = '1' if on else '0.2'
return (
f'
'
f''
f'
{label}
'
f''
f'
'
)
html = CHAIN_CSS + '
'
html += stage('s1', 'Stage 01', 'Feature Extraction', s1_body, s1_on)
html += arrow('behavioral abstraction', s2_on)
html += stage('s2', 'Stage 02', 'Behavioral Analysis', s2_body, s2_on)
html += arrow('demographic inference', s3_on)
html += stage('s3', 'Stage 03', 'Demographic Inference', s3_body, s3_on)
html += '
'
return html
def build_map(agent_sp):
agent_sp = agent_sp.reset_index(drop=True).copy()
agent_sp["latitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
agent_sp["longitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
lat = agent_sp["latitude"].mean()
lon = agent_sp["longitude"].mean()
m = folium.Map(location=[lat, lon], zoom_start=12, tiles="CartoDB positron")
coords = list(zip(agent_sp["latitude"], agent_sp["longitude"]))
if len(coords) > 1:
folium.PolyLine(coords, color="#cc000055", weight=1.5, opacity=0.4).add_to(m)
n = len(agent_sp)
for i, row in agent_sp.iterrows():
ratio = i / max(n - 1, 1)
r = int(255 - ratio * (255 - 139))
g = int(204 * (1 - ratio) ** 2)
b = 0
color = f"#{r:02x}{g:02x}{b:02x}"
folium.CircleMarker(
location=[row["latitude"], row["longitude"]],
radius=7, color=color, fill=True, fill_color=color, fill_opacity=0.9,
popup=folium.Popup(
f"#{i+1} {row['name']} "
f"{row['start_datetime'].strftime('%a %m/%d %H:%M')} "
f"{int(row['duration_min'])} min {row['act_label']}",
max_width=220
)
).add_to(m)
legend_html = """