Spaces:
Running
Running
import html
import os
import re

import folium
import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import InferenceClient
# Paths to the bundled sample data, resolved relative to this file so the app
# works regardless of the current working directory.
BASE = os.path.dirname(os.path.abspath(__file__))
STAY_POINTS = os.path.join(BASE, "data", "stay_points_sampled.csv")
POI_PATH = os.path.join(BASE, "data", "poi_sampled.csv")
DEMO_PATH = os.path.join(BASE, "data", "demographics_sampled.csv")

# Chat model queried through the Hugging Face InferenceClient.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"

# Code -> display-label lookups for the ground-truth demographics columns.
# Negative codes are non-responses (N/A, refused, unknown, not ascertained).
SEX_MAP = {1: "Male", 2: "Female", -8: "Unknown", -7: "Prefer not to answer"}
EDU_MAP = {1: "Less than HS", 2: "HS Graduate/GED", 3: "Some College/Associate",
           4: "Bachelor's Degree", 5: "Graduate/Professional Degree",
           -1: "N/A", -7: "Prefer not to answer", -8: "Unknown"}
# Household-income brackets; ranges use an en dash (U+2013).
INC_MAP = {1: "<$10,000", 2: "$10,000–$14,999", 3: "$15,000–$24,999",
           4: "$25,000–$34,999", 5: "$35,000–$49,999", 6: "$50,000–$74,999",
           7: "$75,000–$99,999", 8: "$100,000–$124,999", 9: "$125,000–$149,999",
           10: "$150,000–$199,999", 11: "$200,000+",
           -7: "Prefer not to answer", -8: "Unknown", -9: "Not ascertained"}
RACE_MAP = {1: "White", 2: "Black or African American", 3: "Asian",
            4: "American Indian or Alaska Native",
            5: "Native Hawaiian or Other Pacific Islander",
            6: "Multiple races", 97: "Other",
            -7: "Prefer not to answer", -8: "Unknown"}
# Activity-type codes that appear in the stay points' ``act_types`` column.
ACT_MAP = {0: "Transportation", 1: "Home", 2: "Work", 3: "School", 4: "ChildCare",
           5: "BuyGoods", 6: "Services", 7: "EatOut", 8: "Errands", 9: "Recreation",
           10: "Exercise", 11: "Visit", 12: "HealthCare", 13: "Religious",
           14: "SomethingElse", 15: "DropOff"}
print("Loading data...")
sp = pd.read_csv(STAY_POINTS)
poi = pd.read_csv(POI_PATH)
demo = pd.read_csv(DEMO_PATH)
# Attach POI attributes to every stay point. The left join keeps stays whose
# poi_id has no match in the POI table; their POI columns become NaN.
sp = sp.merge(poi, on="poi_id", how="left")
# Parse both timestamps as timezone-aware UTC, then derive stay length in minutes.
for _ts_col in ("start_datetime", "end_datetime"):
    sp[_ts_col] = pd.to_datetime(sp[_ts_col], utc=True)
del _ts_col
sp["duration_min"] = (
    sp["end_datetime"].sub(sp["start_datetime"]).dt.total_seconds().div(60)
).round(1)
def parse_act_types(x, act_map=None):
    """Turn a stringified list of activity codes into readable labels.

    ``act_types`` values are stored as text such as ``"[1 2]"`` (space
    separated); comma-separated text such as ``"[1, 2]"`` is accepted as well.

    Args:
        x: the raw cell value (usually a string; may be NaN for missing data).
        act_map: optional code -> label mapping; defaults to ``ACT_MAP``.

    Returns:
        The labels joined by ", " (codes missing from the mapping are kept as
        their number), "" for an empty list, or ``str(x)`` unchanged when any
        token is not an integer (e.g. a missing value read as NaN).
    """
    mapping = ACT_MAP if act_map is None else act_map
    body = str(x).strip().strip("[]").strip()
    tokens = [tok for tok in re.split(r"[\s,]+", body) if tok]
    try:
        codes = [int(tok) for tok in tokens]
    except ValueError:
        # Not a list of integer codes: show the raw value rather than failing.
        return str(x)
    return ", ".join(mapping.get(code, str(code)) for code in codes)
# Readable activity labels per stay point, then the sorted list of agent ids
# offered in the agent dropdown.
sp["act_label"] = sp["act_types"].map(parse_act_types)
sample_agents = sorted(pd.unique(sp["agent_id"]).tolist())
print(f"Ready. {len(sample_agents)} agents loaded.")
# ── Mobility text builders ──────────────────────────────────────────────────
def build_mobility_summary(agent_sp):
    """Summarise one agent's stay points as the plain-text mobility report.

    Sections: observation window and totals; the five most-visited place
    names with visit count and mean stay length; the share of stays by time of
    day and by weekday vs. weekend.

    Args:
        agent_sp: DataFrame of a single agent's stay points with datetime
            ``start_datetime``/``end_datetime`` columns plus ``name`` and
            ``duration_min``. It is not modified.

    Returns:
        The report lines joined with newlines.
    """
    # Visit count and mean duration per place name, most-visited first.
    place_stats = (
        agent_sp.groupby("name")["duration_min"]
        .agg(visits="count", avg_dur="mean")
        .sort_values("visits", ascending=False)
        .head(5)
    )
    first_seen = agent_sp["start_datetime"].min()
    last_seen = agent_sp["end_datetime"].max()
    window_from = first_seen.strftime("%Y-%m-%d")
    window_to = last_seen.strftime("%Y-%m-%d")
    span_days = (last_seen - first_seen).days

    report = [
        "MOBILITY TRAJECTORY DATA",
        "===========================",
        f"Observation Period: {window_from} to {window_to} ({span_days} days)",
        f"Total Stay Points: {len(agent_sp)}",
        f"Unique Locations: {agent_sp['name'].nunique()}",
        "",
        "LOCATION PATTERNS",
        "----------------",
    ]
    for rank, (place, stats) in enumerate(place_stats.iterrows(), start=1):
        report.extend((
            f"{rank}. {place}",
            f" Visits: {int(stats['visits'])} times",
            f" Average Duration: {int(stats['avg_dur'])} minutes",
            "",
        ))

    def bucket(hour):
        """Map an hour of day (0-23) to its time-of-day label."""
        for lo, hi, label in ((5, 12, "morning"), (12, 17, "afternoon"), (17, 21, "evening")):
            if lo <= hour < hi:
                return label
        return "night"

    # NOTE(review): hours are read from the timestamps as stored; if they are
    # UTC-aware this buckets by UTC rather than local time — confirm intent.
    starts = agent_sp["start_datetime"]
    share_by_period = (
        (starts.dt.hour.apply(bucket).value_counts(normalize=True) * 100)
        .round(0)
        .astype(int)
    )
    weekday_share = int((starts.dt.dayofweek < 5).mean() * 100)

    report += ["TEMPORAL PATTERNS", "----------------", "Activity by Time of Day:"]
    report.extend(f"- {period}: {pct}%" for period, pct in share_by_period.items())
    report += [
        "",
        "Weekday vs Weekend:",
        f"- weekday: {weekday_share}%",
        f"- weekend: {100 - weekday_share}%",
    ]
    return "\n".join(report)
def build_weekly_checkin(agent_sp):
    """Render one agent's stay points as a day-by-day check-in log.

    Days are grouped by the calendar date of ``start_datetime`` and listed in
    sorted order. Within a day, stays keep the row order of *agent_sp* and are
    written as "- HH:MM-HH:MM (N mins): <name> - <act_label>".

    Args:
        agent_sp: DataFrame with datetime ``start_datetime``/``end_datetime``
            plus ``duration_min``, ``name`` and ``act_label`` columns. It is
            not modified.

    Returns:
        The log lines joined with newlines.
    """
    out = ["WEEKLY CHECK-IN SUMMARY", "======================="]
    by_day = agent_sp.assign(date=agent_sp["start_datetime"].dt.date).groupby("date")
    for day, stays in by_day:
        first_start = stays["start_datetime"].iloc[0]
        day_kind = "Weekday" if first_start.dayofweek < 5 else "Weekend"
        out.append(f"\n--- {first_start.strftime('%A')}, {day} ({day_kind}) ---")
        out.append(f"Total activities: {len(stays)}")
        out.extend(
            f"- {stay['start_datetime'].strftime('%H:%M')}-"
            f"{stay['end_datetime'].strftime('%H:%M')} "
            f"({int(stay['duration_min'])} mins): "
            f"{stay['name']} - {stay['act_label']}"
            for _, stay in stays.iterrows()
        )
    return "\n".join(out)
# ── Prompts ─────────────────────────────────────────────────────────────────
# System prompts for the three-stage hierarchical chain. Each stage receives
# the previous stages' outputs as its user message (see run_inference).
# Stage 1: factual feature extraction from the trajectory text.
STEP1_SYSTEM = """You are an expert mobility analyst. Extract objective features from the trajectory data.
Respond with EXACTLY this structure, keep each point to one short sentence:
LOCATION INVENTORY:
- Top venues: [list top 3 with visit counts]
- Price level: [budget/mid-range/high-end mix]
- Neighborhood: [residential/commercial/urban/suburban]
TEMPORAL PATTERNS:
- Active hours: [time range]
- Weekday/Weekend: [ratio]
- Routine: [consistent/variable]
SEQUENCE:
- Typical chain: [e.g. Home → Work → Home]
- Notable pattern: [one observation]
Do NOT interpret or infer demographics. Be concise."""
# Stage 2: behavioral abstraction. The KEY: lines are parsed by render_chain.
STEP2_SYSTEM = """You are an expert mobility analyst. Based on the extracted features, analyze behavioral patterns.
Respond with EXACTLY this structure, one short sentence per point:
SCHEDULE: [fixed/flexible/shift — one sentence]
ECONOMIC: [budget/mid-range/premium spending — one sentence]
SOCIAL: [family/individual/community focus — one sentence]
LIFESTYLE: [urban professional/suburban/student/other — one sentence]
STABILITY: [routine consistency — one sentence]
Do NOT make income predictions yet. Be concise."""
# Stage 3: income inference. The INCOME_*/ALTERNATIVES lines are parsed by render_chain.
STEP3_SYSTEM = """You are an expert mobility analyst performing final income inference.
Based on the trajectory features and behavioral analysis, output EXACTLY:
INCOME_PREDICTION: [Very Low (<$15k) | Low ($15k-$35k) | Middle ($35k-$75k) | Upper-Middle ($75k-$125k) | High ($125k-$200k) | Very High (>$200k)]
INCOME_CONFIDENCE: [1-5]
INCOME_REASONING: [2-3 sentences linking specific mobility evidence to the prediction]
ALTERNATIVES: [2nd most likely] | [3rd most likely]"""
# call_llm(), the chat-completion helper used by run_inference(), is defined
# in the callbacks section below. A second definition here would be silently
# rebound (shadowed) by that later one, so it is not repeated.
# ── HTML rendering ──────────────────────────────────────────────────────────
# Stylesheet for the reasoning-chain panel. render_chain() prepends it to every
# HTML fragment it returns, and the class names defined here (stage-card, tag,
# behavior-row, bkey/bval, pred-*, confidence-*, chain-arrow, waiting-dot, ...)
# are the ones render_chain() emits. Fonts are loaded from Google Fonts via
# @import, so the intended typography needs network access in the browser.
CHAIN_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=IBM+Plex+Sans:wght@300;400;600&display=swap');
.hicotraj-chain {
font-family: 'IBM Plex Sans', sans-serif;
padding: 12px 4px;
max-width: 100%;
}
/* Stage cards */
.stage-card {
border-radius: 10px;
padding: 16px 18px;
margin-bottom: 0;
position: relative;
transition: box-shadow 0.3s;
}
.stage-card.dim { opacity: 0.35; filter: grayscale(0.4); }
.stage-card.active { box-shadow: 0 4px 20px rgba(0,0,0,0.12); opacity: 1; filter: none; }
.stage-card.s1 { background: #f8f9fc; border: 1.5px solid #c8d0e0; }
.stage-card.s2 { background: #fdf6f0; border: 1.5px solid #e8c9a8; }
.stage-card.s3 { background: #fff8f8; border: 2px solid #c0392b; }
.stage-header {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 10px;
}
.stage-badge {
font-family: 'IBM Plex Mono', monospace;
font-size: 10px;
font-weight: 600;
letter-spacing: 0.08em;
padding: 3px 8px;
border-radius: 4px;
text-transform: uppercase;
}
.s1 .stage-badge { background: #dde3f0; color: #3a4a6b; }
.s2 .stage-badge { background: #f0dcc8; color: #7a4010; }
.s3 .stage-badge { background: #c0392b; color: #fff; }
.stage-title {
font-size: 13px;
font-weight: 600;
color: #1a1a2e;
}
/* Content inside cards */
.tag-row { display: flex; flex-wrap: wrap; gap: 6px; margin-top: 4px; }
.tag {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
background: #e8ecf5;
color: #2c3e60;
padding: 3px 8px;
border-radius: 4px;
white-space: nowrap;
}
.s2 .tag { background: #f5e8d8; color: #6b3a10; }
.behavior-row {
display: grid;
grid-template-columns: 100px 1fr;
gap: 4px 10px;
margin-top: 2px;
font-size: 12px;
line-height: 1.5;
}
.bkey {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
font-weight: 600;
color: #9b6a3a;
padding-top: 1px;
}
.bval { color: #3a2a1a; }
/* Prediction block */
.pred-block { margin-top: 8px; }
.pred-label {
font-size: 11px;
font-family: 'IBM Plex Mono', monospace;
color: #888;
text-transform: uppercase;
letter-spacing: 0.06em;
margin-bottom: 4px;
}
.pred-value {
font-size: 22px;
font-weight: 600;
color: #c0392b;
letter-spacing: -0.01em;
margin-bottom: 8px;
}
.confidence-bar-wrap {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 10px;
}
.confidence-bar-bg {
flex: 1;
height: 6px;
background: #f0d0cf;
border-radius: 3px;
overflow: hidden;
}
.confidence-bar-fill {
height: 100%;
background: linear-gradient(90deg, #e74c3c, #8b0000);
border-radius: 3px;
transition: width 0.8s ease;
}
.confidence-label {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
color: #c0392b;
font-weight: 600;
white-space: nowrap;
}
.reasoning-text {
font-size: 12px;
color: #4a2a2a;
line-height: 1.6;
border-left: 3px solid #e8c0be;
padding-left: 10px;
margin-top: 6px;
}
.alternatives {
margin-top: 10px;
font-size: 11px;
font-family: 'IBM Plex Mono', monospace;
color: #999;
}
.alternatives span { color: #c0392b; opacity: 0.7; }
/* Arrow connector */
.chain-arrow {
display: flex;
flex-direction: column;
align-items: center;
margin: 0;
padding: 4px 0;
gap: 0;
}
.arrow-line {
width: 2px;
height: 18px;
background: linear-gradient(180deg, #c8d0e0, #e8c9a8);
}
.arrow-label {
font-family: 'IBM Plex Mono', monospace;
font-size: 10px;
color: #aaa;
letter-spacing: 0.06em;
text-transform: uppercase;
background: white;
padding: 2px 8px;
border: 1px solid #e0e0e0;
border-radius: 10px;
margin: 2px 0;
}
.arrow-tip {
width: 0; height: 0;
border-left: 5px solid transparent;
border-right: 5px solid transparent;
border-top: 7px solid #e8c9a8;
}
/* Waiting state */
.waiting-dot {
display: inline-block;
width: 7px; height: 7px;
border-radius: 50%;
background: #ccc;
margin: 0 2px;
animation: pulse 1.2s ease-in-out infinite;
}
.waiting-dot:nth-child(2) { animation-delay: 0.2s; }
.waiting-dot:nth-child(3) { animation-delay: 0.4s; }
@keyframes pulse {
0%, 100% { opacity: 0.3; transform: scale(0.8); }
50% { opacity: 1; transform: scale(1.1); }
}
</style>
"""
| def _waiting_dots(): | |
| return '<span class="waiting-dot"></span><span class="waiting-dot"></span><span class="waiting-dot"></span>' | |
# Shown inside a stage card before that stage has produced any output.
_PLACEHOLDER_HTML = '<div style="font-size:12px;color:#bbb;padding:6px 0;">Run inference to see results</div>'

# Stage 2 response keys, in display order. Each pattern only matches the key at
# the start of a line (optionally after list/markdown decoration such as "- ",
# "1. " or "**"), so a key word used mid-sentence is not picked up.
_BEHAVIOR_KEYS = ("SCHEDULE", "ECONOMIC", "SOCIAL", "LIFESTYLE", "STABILITY")
_BEHAVIOR_PATTERNS = {
    key: re.compile(rf"^[^A-Za-z\n]*{key}[*:\s]+(.+)", re.IGNORECASE | re.MULTILINE)
    for key in _BEHAVIOR_KEYS
}

# One "KEY: value" line of the Stage 3 response; tolerates markdown such as
# "**INCOME_PREDICTION:** Middle" as well as the plain form.
_S3_LINE_RE = re.compile(
    r"[*#\-\s]*(INCOME_PREDICTION|INCOME_CONFIDENCE|INCOME_REASONING|ALTERNATIVES)"
    r"\**\s*:\**\s*(.*)",
    re.IGNORECASE,
)


def _render_stage1(s1_text, status):
    """Stage 1 body: a loader while running, else up to 8 short feature lines as tags."""
    if status == "running1":
        return f'<div style="padding:8px 0; color:#888; font-size:13px;">Extracting features {_waiting_dots()}</div>'
    if not s1_text:
        return _PLACEHOLDER_HTML
    tags = []
    for line in s1_text.splitlines():
        line = line.strip().lstrip("-").strip()
        # Short lines that are not section headings ("...:") are the feature bullets.
        if line and len(line) < 60 and not line.endswith(":"):
            tags.append(line)
            if len(tags) >= 8:
                break
    tag_html = "".join(f'<span class="tag">{html.escape(t)}</span>' for t in tags)
    return f'<div class="tag-row">{tag_html}</div>'


def _render_stage2(s2_text, status):
    """Stage 2 body: a loader while running, else a KEY/value grid of behavior traits."""
    if status == "running2":
        return f'<div style="padding:8px 0; color:#a06030; font-size:13px;">Analyzing behavior {_waiting_dots()}</div>'
    if not s2_text:
        return _PLACEHOLDER_HTML
    rows = []
    for key in _BEHAVIOR_KEYS:
        m = _BEHAVIOR_PATTERNS[key].search(s2_text)
        val = m.group(1).strip().rstrip(".") if m else "—"
        if len(val) > 80:
            val = val[:77] + "..."
        rows.append(f'<div class="bkey">{key}</div><div class="bval">{html.escape(val)}</div>')
    return f'<div class="behavior-row">{"".join(rows)}</div>'


def _render_stage3(s3_text, status):
    """Stage 3 body: a loader while running, else the parsed income prediction."""
    if status == "running3":
        return f'<div style="padding:8px 0; color:#c0392b; font-size:13px;">Inferring demographics {_waiting_dots()}</div>'
    if not s3_text:
        return _PLACEHOLDER_HTML

    fields = {}
    for line in s3_text.splitlines():
        m = _S3_LINE_RE.match(line)
        if m:
            # A repeated key overwrites the earlier value (last one wins).
            fields[m.group(1).upper()] = m.group(2).strip()
    if not fields:
        # Unstructured text, e.g. the status/error messages run_inference passes
        # in: show it as-is instead of an empty prediction with a made-up confidence.
        return f'<div class="reasoning-text">{html.escape(s3_text)}</div>'

    pred = fields.get("INCOME_PREDICTION", "")
    reasoning = fields.get("INCOME_REASONING", "")
    alts = fields.get("ALTERNATIVES", "")
    conf_match = re.search(r"\d+", fields.get("INCOME_CONFIDENCE", ""))
    # The prompt asks for a 1-5 score: use the midpoint when it is missing and
    # clamp out-of-range numbers so the bar never exceeds 100%.
    conf_int = min(max(int(conf_match.group()), 1), 5) if conf_match else 3
    bar_pct = conf_int * 20

    alts_html = ""
    if alts:
        alts_html = f'<div class="alternatives">Also possible: <span>{html.escape(alts)}</span></div>'
    return f"""
<div class="pred-block">
  <div class="pred-label">Income Prediction</div>
  <div class="pred-value">{html.escape(pred) or "—"}</div>
  <div class="confidence-bar-wrap">
    <div class="confidence-bar-bg">
      <div class="confidence-bar-fill" style="width:{bar_pct}%"></div>
    </div>
    <div class="confidence-label">Confidence {conf_int}/5</div>
  </div>
  <div class="reasoning-text">{html.escape(reasoning or s3_text[:200])}</div>
  {alts_html}
</div>"""


def _card(cls, badge, title, content, active):
    """Wrap pre-rendered stage *content* in a stage card; inactive cards are dimmed."""
    state = "active" if active else "dim"
    return f"""
<div class="stage-card {cls} {state}">
  <div class="stage-header">
    <span class="stage-badge">{badge}</span>
    <span class="stage-title">{title}</span>
  </div>
  {content}
</div>"""


def _arrow(label, active):
    """Labelled connector between two stage cards; faded until the next stage is active."""
    opacity = "1" if active else "0.3"
    return f"""
<div class="chain-arrow" style="opacity:{opacity}">
  <div class="arrow-line"></div>
  <div class="arrow-label">{label}</div>
  <div class="arrow-line"></div>
  <div class="arrow-tip"></div>
</div>"""


def render_chain(s1_text="", s2_text="", s3_text="", status="idle"):
    """Render the three-stage reasoning chain as an HTML fragment (CSS included).

    Args:
        s1_text / s2_text / s3_text: raw model output of each stage ("" if the
            stage has not produced output yet). Stage 3 may also be a plain
            status or error message, which is shown verbatim.
        status: one of idle | running1 | running2 | running3 | done. Stage N is
            highlighted once status reaches runningN and stays highlighted
            through "done"; the stage currently running shows a loader.

    Returns:
        CHAIN_CSS followed by the card markup. All model-supplied text is
        HTML-escaped before being inserted.
    """
    s1_active = status in ("running1", "running2", "running3", "done")
    s2_active = status in ("running2", "running3", "done")
    s3_active = status in ("running3", "done")
    parts = [
        CHAIN_CSS,
        '<div class="hicotraj-chain">',
        _card("s1", "Stage 1", "Factual Feature Extraction", _render_stage1(s1_text, status), s1_active),
        _arrow("behavioral abstraction", s2_active),
        _card("s2", "Stage 2", "Behavioral Pattern Analysis", _render_stage2(s2_text, status), s2_active),
        _arrow("demographic inference", s3_active),
        _card("s3", "Stage 3", "Demographic Inference", _render_stage3(s3_text, status), s3_active),
        "</div>",
    ]
    return "".join(parts)
# ── Map & demo ──────────────────────────────────────────────────────────────
def build_map(agent_sp):
    """Render an agent's stay points as an interactive folium map (HTML string).

    Stays are drawn in row order. A faint polyline joins consecutive stays,
    and each stay gets a numbered circle marker whose colour ramps from
    #ffcc00 (first) to #8b0000 (last), with a popup showing place, start time,
    duration and activity. Coordinates get a random offset of up to ±0.0003°
    so repeated visits to the same POI do not stack exactly.

    Rows without coordinates are skipped, because folium rejects NaN
    locations. Such rows come from stay points whose ``poi_id`` had no match
    in the left join with the POI table. If nothing is left to plot, a short
    HTML notice is returned instead of a map.

    Args:
        agent_sp: DataFrame with ``latitude``, ``longitude``, ``name``,
            ``start_datetime``, ``duration_min`` and ``act_label`` columns.
            It is not modified.

    Returns:
        HTML for embedding in a ``gr.HTML`` component.
    """
    agent_sp = (
        agent_sp.dropna(subset=["latitude", "longitude"])
        .reset_index(drop=True)
        .copy()
    )
    if agent_sp.empty:
        return '<div style="padding:12px;color:#888;font-size:13px;">No stay points with coordinates to display.</div>'

    agent_sp["latitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
    agent_sp["longitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
    m = folium.Map(
        location=[agent_sp["latitude"].mean(), agent_sp["longitude"].mean()],
        zoom_start=12,
        tiles="CartoDB positron",
    )

    coords = list(zip(agent_sp["latitude"], agent_sp["longitude"]))
    if len(coords) > 1:
        folium.PolyLine(coords, color="#cc000055", weight=1.5, opacity=0.4).add_to(m)

    n = len(agent_sp)
    for i, row in agent_sp.iterrows():
        ratio = i / max(n - 1, 1)
        r = int(255 - ratio * (255 - 139))
        g = int(204 * (1 - ratio) ** 2)
        color = f"#{r:02x}{g:02x}00"
        # Popup content is HTML; escape data-derived text (e.g. "&" in names).
        popup_html = (
            f"<b>#{i+1} {html.escape(str(row['name']))}</b><br>"
            f"{row['start_datetime'].strftime('%a %m/%d %H:%M')}<br>"
            f"{int(row['duration_min'])} min<br>{html.escape(str(row['act_label']))}"
        )
        folium.CircleMarker(
            location=[row["latitude"], row["longitude"]],
            radius=7, color=color, fill=True, fill_color=color, fill_opacity=0.9,
            popup=folium.Popup(popup_html, max_width=220),
        ).add_to(m)

    root = m.get_root()
    root.width = "100%"
    root.height = "420px"
    return m._repr_html_()
def build_demo_text(row):
    """Format one demographics record as a " | "-separated one-line summary.

    ``row`` must provide ``age``, ``sex``, ``race``, ``education`` and
    ``hh_income``. Coded fields are translated through their lookup table; a
    code that is not in the table is shown raw. An age that is not positive
    (including a missing value) is shown as "Unknown".
    """
    age = int(row["age"]) if row["age"] > 0 else "Unknown"
    coded_fields = (
        ("Sex", SEX_MAP, "sex"),
        ("Race", RACE_MAP, "race"),
        ("Education", EDU_MAP, "education"),
        ("Income", INC_MAP, "hh_income"),
    )
    segments = [f"Age: {age}"]
    for label, lookup, column in coded_fields:
        raw = row[column]
        segments.append(f"{label}: {lookup.get(int(raw), raw)}")
    return " | ".join(segments)
# ── Callbacks ───────────────────────────────────────────────────────────────
def on_select(agent_id):
    """Refresh every panel for the agent chosen in the dropdown.

    Args:
        agent_id: the dropdown value (a stringified agent id), or None/"" when
            the selection is cleared.

    Returns:
        (map_html, raw_text, demo_text, chain_html), in the order wired to the
        Gradio outputs. A cleared selection gives empty panels. An agent with
        no demographics row shows a notice instead of raising IndexError.
    """
    chain_html = render_chain(status="idle")
    if agent_id is None or not str(agent_id).strip():
        return "", "", "", chain_html

    agent_id = int(agent_id)
    agent_sp = sp[sp["agent_id"] == agent_id].sort_values("start_datetime")
    demo_rows = demo[demo["agent_id"] == agent_id]
    if demo_rows.empty:
        demo_text = "No demographic record for this agent."
    else:
        demo_text = build_demo_text(demo_rows.iloc[0])

    map_html = build_map(agent_sp)
    raw_text = build_mobility_summary(agent_sp) + "\n\n" + build_weekly_checkin(agent_sp)
    return map_html, raw_text, demo_text, chain_html
def run_inference(agent_id, hf_token):
    """Stream the three-stage HiCoTraj chain for one agent as rendered HTML.

    This is a generator wired to the "Run" button. It yields the chain after
    each stage (features -> behavior -> income), and each stage's prompt
    includes the earlier stages' outputs.

    Input problems and API failures are reported inside the Stage 3 card.
    Stages that already finished stay visible when a later one fails.

    Args:
        agent_id: dropdown value (stringified agent id), possibly None.
        hf_token: Hugging Face access token typed by the user.

    Yields:
        HTML strings produced by render_chain().
    """
    if not hf_token or not hf_token.strip():
        yield render_chain(s3_text="⚠️ Please enter your Hugging Face token first.", status="done")
        return
    if agent_id is None or not str(agent_id).strip():
        yield render_chain(s3_text="⚠️ Please select an agent first.", status="done")
        return

    s1 = s2 = ""
    try:
        agent_id = int(agent_id)
        agent_sp = sp[sp["agent_id"] == agent_id].sort_values("start_datetime")
        traj_text = build_mobility_summary(agent_sp) + "\n\n" + build_weekly_checkin(agent_sp)
        client = InferenceClient(token=hf_token.strip())

        yield render_chain(status="running1")
        s1 = call_llm(client, STEP1_SYSTEM, traj_text, max_tokens=400)

        yield render_chain(s1_text=s1, status="running2")
        s2_input = f"Features:\n{s1}\n\nNow analyze behavioral patterns."
        s2 = call_llm(client, STEP2_SYSTEM, s2_input, max_tokens=300)

        yield render_chain(s1_text=s1, s2_text=s2, status="running3")
        s3_input = f"Features:\n{s1}\n\nBehavioral analysis:\n{s2}\n\nNow infer income."
        s3 = call_llm(client, STEP3_SYSTEM, s3_input, max_tokens=300)

        yield render_chain(s1_text=s1, s2_text=s2, s3_text=s3, status="done")
    except Exception as e:
        # UI boundary: show the failure (auth, network, data) in the panel,
        # keeping any stage output already produced.
        yield render_chain(s1_text=s1, s2_text=s2, s3_text=f"❌ Error: {e}", status="done")
def call_llm(client, system_prompt, user_content, max_tokens=400, temperature=0.3):
    """Send one system + user chat turn to MODEL_ID and return the reply text.

    Args:
        client: an InferenceClient (anything exposing
            ``chat.completions.create``).
        system_prompt: instructions for the model (one of the STEP*_SYSTEM prompts).
        user_content: the user message for this stage.
        max_tokens: generation cap forwarded to the API.
        temperature: sampling temperature (default 0.3, the previous fixed value).

    Returns:
        The first choice's message text, stripped. If the message carries no
        text (``content`` is None), an empty string is returned instead of
        raising AttributeError.
    """
    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    return (content or "").strip()
# ── UI ──────────────────────────────────────────────────────────────────────
# Gradio layout and event wiring; `app` is built at import time.
with gr.Blocks(title="HiCoTraj Demo", theme=gr.themes.Soft()) as app:
    gr.Markdown("## HiCoTraj — Trajectory Visualization & Hierarchical CoT Demo")
    gr.Markdown("*Zero-Shot Demographic Reasoning via Hierarchical Chain-of-Thought Prompting from Trajectory*")

    with gr.Row():
        hf_token_box = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_...",
            type="password",
            scale=2,
        )

    with gr.Row():
        agent_dd = gr.Dropdown(
            choices=[str(a) for a in sample_agents],
            label="Select Agent",
            # No preselection on an empty dataset instead of IndexError at import.
            value=str(sample_agents[0]) if sample_agents else None,
            scale=1,
        )
        demo_label = gr.Textbox(
            label="Ground Truth Demographics",
            interactive=False,
            scale=4,
        )

    with gr.Row():
        # LEFT: map + NUMOSIM data
        with gr.Column(scale=1):
            gr.Markdown("### Trajectory Map")
            map_out = gr.HTML()
            gr.Markdown("### NUMOSIM Raw Data")
            raw_out = gr.Textbox(
                lines=25, interactive=False,
                label="Mobility Summary + Weekly Check-in",
            )
        # RIGHT: reasoning chain
        with gr.Column(scale=1):
            gr.Markdown("### Hierarchical Chain-of-Thought Reasoning")
            run_btn = gr.Button("▶ Run HiCoTraj Inference", variant="primary")
            chain_out = gr.HTML(value=render_chain(status="idle"))

    # Changing the agent, and the initial page load, refresh all panels.
    agent_dd.change(
        fn=on_select, inputs=agent_dd,
        outputs=[map_out, raw_out, demo_label, chain_out],
    )
    app.load(
        fn=on_select, inputs=agent_dd,
        outputs=[map_out, raw_out, demo_label, chain_out],
    )
    # run_inference is a generator, so the chain panel updates after each stage.
    run_btn.click(
        fn=run_inference,
        inputs=[agent_dd, hf_token_box],
        outputs=[chain_out],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()