ginnyxxxxxxx committed on
Commit
07bd8a6
·
1 Parent(s): 6e0e306

label fix

Browse files
Files changed (1) hide show
  1. app.py +46 -26
app.py CHANGED
@@ -1,66 +1,84 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import folium
 
4
  import os
5
 
6
- # ── Data paths ───────────────────────────────────────────────────────────────
7
- BASE = os.path.dirname(os.path.abspath(__file__))
8
  STAY_POINTS = os.path.join(BASE, "data", "stay_points_sampled.csv")
9
  POI_PATH = os.path.join(BASE, "data", "poi_sampled.csv")
10
  DEMO_PATH = os.path.join(BASE, "data", "demographics_sampled.csv")
11
 
12
- AGE_MAP = {1:"<18", 2:"18-24", 3:"25-34", 4:"35-44", 5:"45-54", 6:"55-64", 7:"65+"}
13
- SEX_MAP = {1:"Male", 2:"Female"}
14
- EDU_MAP = {1:"No HS", 2:"HS Grad", 3:"Some College", 4:"Bachelor's", 5:"Graduate"}
15
- INC_MAP = {1:"<$10k", 2:"$10-15k", 3:"$15-25k", 4:"$25-35k", 5:"$35-50k",
16
- 6:"$50-75k", 7:"$75-100k", 8:"$100-125k", 9:"$125-150k", 10:">$150k"}
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # ── Load ─────────────────────────────────────────────────────────────────────
19
  print("Loading data...")
20
  sp = pd.read_csv(STAY_POINTS)
21
  poi = pd.read_csv(POI_PATH)
22
  demo = pd.read_csv(DEMO_PATH)
23
 
24
- sp = sp.merge(poi, on="poi_id", how="left")
25
  sp["start_datetime"] = pd.to_datetime(sp["start_datetime"], utc=True)
26
  sp["end_datetime"] = pd.to_datetime(sp["end_datetime"], utc=True)
27
  sp["duration_min"] = ((sp["end_datetime"] - sp["start_datetime"]).dt.total_seconds() / 60).round(1)
28
 
 
 
 
 
 
 
 
 
 
29
  sample_agents = sorted(sp["agent_id"].unique().tolist())
30
  print(f"Ready. {len(sample_agents)} agents loaded.")
31
 
32
- # ── Helpers ───────────────────────────────────────────────────────────────────
33
  def build_map(agent_sp):
34
- import numpy as np
35
- agent_sp = agent_sp.reset_index(drop=True)
36
  agent_sp["latitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
37
  agent_sp["longitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
38
-
39
  lat = agent_sp["latitude"].mean()
40
  lon = agent_sp["longitude"].mean()
41
  m = folium.Map(location=[lat, lon], zoom_start=12, tiles="CartoDB positron")
42
 
43
- n = len(agent_sp)
44
  coords = list(zip(agent_sp["latitude"], agent_sp["longitude"]))
45
-
46
  if len(coords) > 1:
47
  folium.PolyLine(coords, color="#aaaaaa", weight=1.5, opacity=0.4).add_to(m)
48
 
 
49
  for i, row in agent_sp.iterrows():
50
- # green → red gradient
51
  ratio = i / max(n - 1, 1)
52
  r = int(255 * ratio)
53
  g = int(255 * (1 - ratio))
54
  color = f"#{r:02x}{g:02x}33"
55
-
56
  folium.CircleMarker(
57
  location=[row["latitude"], row["longitude"]],
58
  radius=7, color=color, fill=True, fill_color=color, fill_opacity=0.9,
59
  popup=folium.Popup(
60
  f"<b>#{i+1} {row['name']}</b><br>"
61
  f"{row['start_datetime'].strftime('%a %m/%d %H:%M')}<br>"
62
- f"{int(row['duration_min'])} min",
63
- max_width=200
 
64
  )
65
  ).add_to(m)
66
 
@@ -68,23 +86,26 @@ def build_map(agent_sp):
68
  m.get_root().height = "500px"
69
  return m._repr_html_()
70
 
 
71
  def build_poi_sequence(agent_sp):
72
  lines = []
73
  for _, row in agent_sp.iterrows():
74
  lines.append(
75
  f"{row['start_datetime'].strftime('%a %m/%d')} "
76
  f"{row['start_datetime'].strftime('%H:%M')}–{row['end_datetime'].strftime('%H:%M')} "
77
- f"({int(row['duration_min'])} min) | {row['name']} | {row['act_types']}"
78
  )
79
  return "\n".join(lines)
80
 
81
 
82
  def build_demo_text(row):
 
83
  return (
84
- f"Age: {AGE_MAP.get(row['age'], row['age'])} | "
85
- f"Sex: {SEX_MAP.get(row['sex'], row['sex'])} | "
86
- f"Education: {EDU_MAP.get(row['education'], row['education'])} | "
87
- f"Income: {INC_MAP.get(row['hh_income'], row['hh_income'])}"
 
88
  )
89
 
90
 
@@ -95,7 +116,6 @@ def on_select(agent_id):
95
  return build_map(agent_sp), build_poi_sequence(agent_sp), build_demo_text(agent_demo)
96
 
97
 
98
- # ── UI ────────────────────────────────────────────────────────────────────────
99
  with gr.Blocks(title="HiCoTraj Demo", theme=gr.themes.Soft()) as app:
100
  gr.Markdown("## HiCoTraj: Trajectory Visualization")
101
 
@@ -104,7 +124,7 @@ with gr.Blocks(title="HiCoTraj Demo", theme=gr.themes.Soft()) as app:
104
  label="Select Agent", value=str(sample_agents[0]))
105
  demo_label = gr.Textbox(label="Ground Truth Demographics", interactive=False)
106
 
107
- map_out = gr.HTML(label="Trajectory Map", value="<div style='height:500px'></div>")
108
  poi_out = gr.Textbox(label="POI Sequence", lines=20, interactive=False)
109
 
110
  agent_dd.change(fn=on_select, inputs=agent_dd, outputs=[map_out, poi_out, demo_label])
 
1
  import gradio as gr
2
  import pandas as pd
3
  import folium
4
+ import numpy as np
5
  import os
6
 
7
+ BASE = os.path.dirname(os.path.abspath(__file__))
 
8
  STAY_POINTS = os.path.join(BASE, "data", "stay_points_sampled.csv")
9
  POI_PATH = os.path.join(BASE, "data", "poi_sampled.csv")
10
  DEMO_PATH = os.path.join(BASE, "data", "demographics_sampled.csv")
11
 
12
+ SEX_MAP = {1:"Male", 2:"Female", -8:"Unknown", -7:"Prefer not to answer"}
13
+ EDU_MAP = {1:"Less than HS", 2:"HS Graduate/GED", 3:"Some College/Associate",
14
+ 4:"Bachelor's Degree", 5:"Graduate/Professional Degree",
15
+ -1:"N/A", -7:"Prefer not to answer", -8:"Unknown"}
16
+ INC_MAP = {1:"<$10,000", 2:"$10,000–$14,999", 3:"$15,000–$24,999",
17
+ 4:"$25,000–$34,999", 5:"$35,000–$49,999", 6:"$50,000–$74,999",
18
+ 7:"$75,000–$99,999", 8:"$100,000–$124,999", 9:"$125,000–$149,999",
19
+ 10:"$150,000–$199,999", 11:"$200,000+",
20
+ -7:"Prefer not to answer", -8:"Unknown", -9:"Not ascertained"}
21
+ RACE_MAP = {1:"White", 2:"Black or African American", 3:"Asian",
22
+ 4:"American Indian or Alaska Native",
23
+ 5:"Native Hawaiian or Other Pacific Islander",
24
+ 6:"Multiple races", 97:"Other",
25
+ -7:"Prefer not to answer", -8:"Unknown"}
26
+ ACT_MAP = {0:"Transportation", 1:"Home", 2:"Work", 3:"School", 4:"ChildCare",
27
+ 5:"BuyGoods", 6:"Services", 7:"EatOut", 8:"Errands", 9:"Recreation",
28
+ 10:"Exercise", 11:"Visit", 12:"HealthCare", 13:"Religious",
29
+ 14:"SomethingElse", 15:"DropOff"}
30
 
 
31
  print("Loading data...")
32
  sp = pd.read_csv(STAY_POINTS)
33
  poi = pd.read_csv(POI_PATH)
34
  demo = pd.read_csv(DEMO_PATH)
35
 
36
+ sp = sp.merge(poi, on="poi_id", how="left")
37
  sp["start_datetime"] = pd.to_datetime(sp["start_datetime"], utc=True)
38
  sp["end_datetime"] = pd.to_datetime(sp["end_datetime"], utc=True)
39
  sp["duration_min"] = ((sp["end_datetime"] - sp["start_datetime"]).dt.total_seconds() / 60).round(1)
40
 
41
def parse_act_types(x):
    """Translate a stringified list of activity codes into readable labels.

    Args:
        x: Raw ``act_types`` cell. It is converted with ``str`` and is expected
            to look like ``"[1 2]"`` or ``"[1, 2]"``; surrounding brackets are
            optional.

    Returns:
        The labels joined with ", ". A code missing from ACT_MAP is shown as
        its number. If any token is not an integer (e.g. a NaN cell or free
        text) the input is returned unchanged as a string.
    """
    # Treat commas as separators too, so "[1, 2]" parses the same as "[1 2]".
    tokens = str(x).strip("[]").replace(",", " ").split()
    try:
        codes = [int(tok) for tok in tokens]
    except ValueError:
        # Only a failed int() conversion means "not a code list"; anything
        # else (KeyboardInterrupt, programming errors) should propagate.
        return str(x)
    return ", ".join(ACT_MAP.get(code, str(code)) for code in codes)
47
+
48
+ sp["act_label"] = sp["act_types"].apply(parse_act_types)
49
+
50
  sample_agents = sorted(sp["agent_id"].unique().tolist())
51
  print(f"Ready. {len(sample_agents)} agents loaded.")
52
 
53
+
54
  def build_map(agent_sp):
55
+ agent_sp = agent_sp.reset_index(drop=True).copy()
 
56
  agent_sp["latitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
57
  agent_sp["longitude"] += np.random.uniform(-0.0003, 0.0003, len(agent_sp))
58
+
59
  lat = agent_sp["latitude"].mean()
60
  lon = agent_sp["longitude"].mean()
61
  m = folium.Map(location=[lat, lon], zoom_start=12, tiles="CartoDB positron")
62
 
 
63
  coords = list(zip(agent_sp["latitude"], agent_sp["longitude"]))
 
64
  if len(coords) > 1:
65
  folium.PolyLine(coords, color="#aaaaaa", weight=1.5, opacity=0.4).add_to(m)
66
 
67
+ n = len(agent_sp)
68
  for i, row in agent_sp.iterrows():
 
69
  ratio = i / max(n - 1, 1)
70
  r = int(255 * ratio)
71
  g = int(255 * (1 - ratio))
72
  color = f"#{r:02x}{g:02x}33"
 
73
  folium.CircleMarker(
74
  location=[row["latitude"], row["longitude"]],
75
  radius=7, color=color, fill=True, fill_color=color, fill_opacity=0.9,
76
  popup=folium.Popup(
77
  f"<b>#{i+1} {row['name']}</b><br>"
78
  f"{row['start_datetime'].strftime('%a %m/%d %H:%M')}<br>"
79
+ f"{int(row['duration_min'])} min<br>"
80
+ f"{row['act_label']}",
81
+ max_width=220
82
  )
83
  ).add_to(m)
84
 
 
86
  m.get_root().height = "500px"
87
  return m._repr_html_()
88
 
89
+
90
def build_poi_sequence(agent_sp):
    """Render an agent's stay points as text, one stay per line.

    Args:
        agent_sp: DataFrame with "start_datetime", "end_datetime",
            "duration_min", "name" and "act_label" columns; rows are emitted
            in the order they appear.

    Returns:
        Newline-joined lines of the form
        "<day date> <start>–<end> (<minutes> min) | <name> | <activity>",
        or an empty string when there are no rows.
    """
    wanted = ("start_datetime", "end_datetime", "duration_min", "name", "act_label")
    stays = zip(*(agent_sp[col] for col in wanted))
    return "\n".join(
        f"{begin.strftime('%a %m/%d')} "
        f"{begin.strftime('%H:%M')}–{finish.strftime('%H:%M')} "
        # int() truncates the fractional minutes rather than rounding.
        f"({int(minutes)} min) | {place} | {activity}"
        for begin, finish, minutes, place, activity in stays
    )
99
 
100
 
101
def build_demo_text(row):
    """Format one agent's demographic record as a single display line.

    Args:
        row: Mapping-like record (e.g. a pandas Series) providing the "age",
            "sex", "race", "education" and "hh_income" fields.

    Returns:
        A " | "-separated string with age, sex, race, education and income.
        Coded fields are translated through SEX_MAP, RACE_MAP, EDU_MAP and
        INC_MAP; a code missing from its map is shown as the raw value. Any
        field that cannot be read as a number (e.g. NaN from a missing CSV
        cell) is shown as "Unknown" instead of raising.
    """
    unreadable = (TypeError, ValueError, OverflowError)

    def _label(mapping, field):
        # int(NaN) raises ValueError; treat unreadable codes like unknown age.
        raw = row[field]
        try:
            code = int(raw)
        except unreadable:
            return "Unknown"
        return mapping.get(code, raw)

    raw_age = row["age"]
    try:
        # Non-positive ages are sentinel codes for "no answer".
        age = int(raw_age) if raw_age > 0 else "Unknown"
    except unreadable:
        age = "Unknown"

    return (
        f"Age: {age} | "
        f"Sex: {_label(SEX_MAP, 'sex')} | "
        f"Race: {_label(RACE_MAP, 'race')} | "
        f"Education: {_label(EDU_MAP, 'education')} | "
        f"Income: {_label(INC_MAP, 'hh_income')}"
    )
110
 
111
 
 
116
  return build_map(agent_sp), build_poi_sequence(agent_sp), build_demo_text(agent_demo)
117
 
118
 
 
119
  with gr.Blocks(title="HiCoTraj Demo", theme=gr.themes.Soft()) as app:
120
  gr.Markdown("## HiCoTraj: Trajectory Visualization")
121
 
 
124
  label="Select Agent", value=str(sample_agents[0]))
125
  demo_label = gr.Textbox(label="Ground Truth Demographics", interactive=False)
126
 
127
+ map_out = gr.HTML(label="Trajectory Map")
128
  poi_out = gr.Textbox(label="POI Sequence", lines=20, interactive=False)
129
 
130
  agent_dd.change(fn=on_select, inputs=agent_dd, outputs=[map_out, poi_out, demo_label])