Spaces:
Sleeping
Sleeping
phanny commited on
Commit ·
9b7170c
1
Parent(s): 727ab75
fixed ui
Browse files- app.py +97 -50
- src/chat.py +15 -4
app.py
CHANGED
|
@@ -46,6 +46,9 @@ CITY_COORDS = {
|
|
| 46 |
"roxbury": (42.33, -71.08),
|
| 47 |
"allston": (42.35, -71.13),
|
| 48 |
"amesbury": (42.86, -70.93),
|
|
|
|
|
|
|
|
|
|
| 49 |
"austin": (30.27, -97.74),
|
| 50 |
"san antonio": (29.42, -98.49),
|
| 51 |
"san francisco": (37.77, -122.42),
|
|
@@ -53,6 +56,38 @@ CITY_COORDS = {
|
|
| 53 |
"lake view terrace": (34.27, -118.37),
|
| 54 |
"chicago": (41.88, -87.63),
|
| 55 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
# Geocode cache so any city/state from search results can be shown on the map.
|
| 57 |
_GEOCODE_CACHE = {}
|
| 58 |
# Show all proposed facilities (chat returns up to 5; allow more for geocoding).
|
|
@@ -202,7 +237,8 @@ def _get_facility_coords(facilities):
|
|
| 202 |
result = []
|
| 203 |
geocode_count = [0]
|
| 204 |
for f in facilities:
|
| 205 |
-
|
|
|
|
| 206 |
continue
|
| 207 |
coord = _facility_coord(f, geocode_count)
|
| 208 |
if coord:
|
|
@@ -224,11 +260,11 @@ def _build_google_map_html(facilities, force_update_id=None, selected_facility_n
|
|
| 224 |
center_lon = sum(lons) / len(lons)
|
| 225 |
zoom = 10
|
| 226 |
markers_data = []
|
| 227 |
-
for lat, lon, f in facility_coords:
|
| 228 |
name = (f.get("facility_name") or f.get("name") or "Facility").replace("<", "<").replace(">", ">")
|
| 229 |
info = _popup_html(f)
|
| 230 |
sel = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
|
| 231 |
-
markers_data.append({"lat": lat, "lng": lon, "name": name, "info": info, "selected": sel})
|
| 232 |
# Base64-encode markers so srcdoc HTML escaping cannot break the JSON
|
| 233 |
markers_json = json.dumps(markers_data)
|
| 234 |
markers_b64 = base64.b64encode(markers_json.encode("utf-8")).decode("ascii")
|
|
@@ -246,10 +282,8 @@ function init() {{
|
|
| 246 |
if (markersData && markersData.length) {{
|
| 247 |
markersData.forEach(function(m) {{
|
| 248 |
var pos = {{ lat: m.lat, lng: m.lng }};
|
| 249 |
-
var opts = {{ position: pos, map: map, title: m.name }};
|
| 250 |
-
if (m.selected) {{
|
| 251 |
-
opts.icon = {{ path: google.maps.SymbolPath.CIRCLE, scale: 14, fillColor: "#c62828", fillOpacity: 1, strokeColor: "#fff", strokeWeight: 3 }};
|
| 252 |
-
}}
|
| 253 |
var marker = new google.maps.Marker(opts);
|
| 254 |
marker.addListener("click", function() {{ infowindow.setContent(m.info); infowindow.open(map, marker); }});
|
| 255 |
if (!bounds) bounds = new google.maps.LatLngBounds(pos, pos);
|
|
@@ -333,13 +367,22 @@ def _build_folium_map_html(facilities, user_location_str=None, force_update_id=N
|
|
| 333 |
if route:
|
| 334 |
folium.PolyLine(route, color="teal", weight=4, opacity=0.8).add_to(m)
|
| 335 |
|
| 336 |
-
for (lat, lon), f in facility_coords:
|
| 337 |
is_selected = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
popup=folium.Popup(_popup_html(f), max_width=280),
|
| 341 |
-
tooltip=
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
).add_to(m)
|
| 344 |
|
| 345 |
# Center and zoom to show all proposed locations
|
|
@@ -371,9 +414,13 @@ def _build_folium_map_html(facilities, user_location_str=None, force_update_id=N
|
|
| 371 |
f'Map could not be loaded. ({str(e)[:80]})</div>'
|
| 372 |
)
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
)
|
| 378 |
|
| 379 |
DESCRIPTION = (
|
|
@@ -388,7 +435,35 @@ EXAMPLES = [
|
|
| 388 |
]
|
| 389 |
|
| 390 |
CSS = """
|
| 391 |
-
.disclaimer {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
.map-pane { padding: 0.25rem 0 0 0; }
|
| 393 |
.map-pane .map-html { border-radius: 12px; overflow: hidden; box-shadow: 0 2px 12px rgba(0,0,0,0.08); }
|
| 394 |
.map-pane iframe { border-radius: 12px; }
|
|
@@ -435,7 +510,7 @@ def create_demo():
|
|
| 435 |
with gr.Blocks(title="SAMHSA Treatment Locator") as demo:
|
| 436 |
gr.Markdown("# SAMHSA Treatment Locator")
|
| 437 |
gr.Markdown(DESCRIPTION)
|
| 438 |
-
gr.
|
| 439 |
|
| 440 |
state = gr.State(DEFAULT_STATE)
|
| 441 |
|
|
@@ -459,12 +534,6 @@ def create_demo():
|
|
| 459 |
if "type" in __import__("inspect").signature(gr.Chatbot).parameters:
|
| 460 |
_chat_kw["type"] = "messages"
|
| 461 |
chat = gr.Chatbot(**_chat_kw)
|
| 462 |
-
facility_dropdown = gr.Dropdown(
|
| 463 |
-
choices=[],
|
| 464 |
-
value=None,
|
| 465 |
-
label="Choose a treatment center (pin updates on map)",
|
| 466 |
-
allow_custom_value=False,
|
| 467 |
-
)
|
| 468 |
with gr.Row():
|
| 469 |
msg = gr.Textbox(
|
| 470 |
placeholder="Type a message…",
|
|
@@ -481,16 +550,13 @@ def create_demo():
|
|
| 481 |
examples_per_page=6,
|
| 482 |
)
|
| 483 |
|
| 484 |
-
def _facility_names(facilities):
|
| 485 |
-
return [f.get("facility_name") or f.get("name") or "Facility" for f in facilities]
|
| 486 |
-
|
| 487 |
def user_submit(message, history, state):
|
| 488 |
update_id = str(time.time())
|
| 489 |
if not message or not message.strip():
|
| 490 |
facilities = list(state.get("last_results") or [])
|
| 491 |
sel = state.get("selected_facility_name")
|
| 492 |
map_html_out = _build_map_html(facilities, None, update_id, sel)
|
| 493 |
-
return history, state, "", map_html_out
|
| 494 |
try:
|
| 495 |
history_tuples = _messages_to_tuples(history)
|
| 496 |
reply, new_state = chatbot.get_response(message, history_tuples, state)
|
|
@@ -500,8 +566,7 @@ def create_demo():
|
|
| 500 |
facilities = list(new_state.get("last_results") or [])
|
| 501 |
sel = new_state.get("selected_facility_name")
|
| 502 |
map_html_out = _build_map_html(facilities, None, update_id, sel)
|
| 503 |
-
|
| 504 |
-
return new_history_messages, new_state, "", map_html_out, gr.update(choices=_facility_names(facilities), value=dropdown_value)
|
| 505 |
except Exception as e:
|
| 506 |
err_msg = str(e)[:200]
|
| 507 |
reply = f"Sorry, something went wrong: {err_msg}"
|
|
@@ -513,35 +578,17 @@ def create_demo():
|
|
| 513 |
facilities = list(state.get("last_results") or [])
|
| 514 |
sel = state.get("selected_facility_name")
|
| 515 |
map_html_out = _build_map_html(facilities, None, update_id, sel)
|
| 516 |
-
return new_history_messages, state, "", map_html_out
|
| 517 |
-
|
| 518 |
-
def on_facility_select(choice, state):
|
| 519 |
-
if not choice:
|
| 520 |
-
state = dict(state or {})
|
| 521 |
-
state["selected_facility_name"] = None
|
| 522 |
-
facilities = list(state.get("last_results") or [])
|
| 523 |
-
map_html_out = _build_map_html(facilities, None, str(time.time()), None)
|
| 524 |
-
return map_html_out, state
|
| 525 |
-
state = dict(state or {})
|
| 526 |
-
state["selected_facility_name"] = choice
|
| 527 |
-
facilities = list(state.get("last_results") or [])
|
| 528 |
-
map_html_out = _build_map_html(facilities, None, str(time.time()), choice)
|
| 529 |
-
return map_html_out, state
|
| 530 |
|
| 531 |
submit_btn.click(
|
| 532 |
user_submit,
|
| 533 |
inputs=[msg, chat, state],
|
| 534 |
-
outputs=[chat, state, msg, map_html
|
| 535 |
)
|
| 536 |
msg.submit(
|
| 537 |
user_submit,
|
| 538 |
inputs=[msg, chat, state],
|
| 539 |
-
outputs=[chat, state, msg, map_html
|
| 540 |
-
)
|
| 541 |
-
facility_dropdown.change(
|
| 542 |
-
on_facility_select,
|
| 543 |
-
inputs=[facility_dropdown, state],
|
| 544 |
-
outputs=[map_html, state],
|
| 545 |
)
|
| 546 |
|
| 547 |
return demo
|
|
|
|
| 46 |
"roxbury": (42.33, -71.08),
|
| 47 |
"allston": (42.35, -71.13),
|
| 48 |
"amesbury": (42.86, -70.93),
|
| 49 |
+
"athol": (42.59, -72.23),
|
| 50 |
+
"abilene": (32.45, -99.73),
|
| 51 |
+
"addison": (32.96, -96.83),
|
| 52 |
"austin": (30.27, -97.74),
|
| 53 |
"san antonio": (29.42, -98.49),
|
| 54 |
"san francisco": (37.77, -122.42),
|
|
|
|
| 56 |
"lake view terrace": (34.27, -118.37),
|
| 57 |
"chicago": (41.88, -87.63),
|
| 58 |
}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _normalize_facility(f):
|
| 62 |
+
"""Ensure facility dict from state or search has string values and expected keys. Accepts any key casing."""
|
| 63 |
+
if not f or not isinstance(f, dict):
|
| 64 |
+
return None
|
| 65 |
+
# Case-insensitive key lookup (state/JSON may use different casing)
|
| 66 |
+
key_map = {k.lower(): k for k in f.keys() if isinstance(k, str)}
|
| 67 |
+
def get_any(*names):
|
| 68 |
+
for n in names:
|
| 69 |
+
k = n.lower()
|
| 70 |
+
if k in key_map:
|
| 71 |
+
val = f.get(key_map[k])
|
| 72 |
+
if val is not None and str(val).strip() and str(val).lower() != "nan":
|
| 73 |
+
return str(val).strip()
|
| 74 |
+
return None
|
| 75 |
+
city = get_any("city", "City")
|
| 76 |
+
state = get_any("state", "State")
|
| 77 |
+
if not city and not state:
|
| 78 |
+
return None
|
| 79 |
+
out = {"city": city or "", "state": state or ""}
|
| 80 |
+
for name, out_key in (
|
| 81 |
+
("facility_name", "facility_name"), ("name", "facility_name"),
|
| 82 |
+
("address", "address"), ("phone", "phone"),
|
| 83 |
+
("treatment_type", "treatment_type"), ("services", "services"),
|
| 84 |
+
):
|
| 85 |
+
val = get_any(name)
|
| 86 |
+
if val:
|
| 87 |
+
out[out_key] = val
|
| 88 |
+
if "facility_name" not in out:
|
| 89 |
+
out["facility_name"] = get_any("facility_name", "name") or "Facility"
|
| 90 |
+
return out
|
| 91 |
# Geocode cache so any city/state from search results can be shown on the map.
|
| 92 |
_GEOCODE_CACHE = {}
|
| 93 |
# Show all proposed facilities (chat returns up to 5; allow more for geocoding).
|
|
|
|
| 237 |
result = []
|
| 238 |
geocode_count = [0]
|
| 239 |
for f in facilities:
|
| 240 |
+
f = _normalize_facility(f)
|
| 241 |
+
if not f:
|
| 242 |
continue
|
| 243 |
coord = _facility_coord(f, geocode_count)
|
| 244 |
if coord:
|
|
|
|
| 260 |
center_lon = sum(lons) / len(lons)
|
| 261 |
zoom = 10
|
| 262 |
markers_data = []
|
| 263 |
+
for i, (lat, lon, f) in enumerate(facility_coords):
|
| 264 |
name = (f.get("facility_name") or f.get("name") or "Facility").replace("<", "<").replace(">", ">")
|
| 265 |
info = _popup_html(f)
|
| 266 |
sel = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
|
| 267 |
+
markers_data.append({"lat": lat, "lng": lon, "name": name, "info": info, "selected": sel, "label": str(i + 1)})
|
| 268 |
# Base64-encode markers so srcdoc HTML escaping cannot break the JSON
|
| 269 |
markers_json = json.dumps(markers_data)
|
| 270 |
markers_b64 = base64.b64encode(markers_json.encode("utf-8")).decode("ascii")
|
|
|
|
| 282 |
if (markersData && markersData.length) {{
|
| 283 |
markersData.forEach(function(m) {{
|
| 284 |
var pos = {{ lat: m.lat, lng: m.lng }};
|
| 285 |
+
var opts = {{ position: pos, map: map, title: m.name, label: {{ text: m.label || "", color: "white", fontWeight: "bold" }} }};
|
| 286 |
+
if (m.selected) {{ opts.animation = google.maps.Animation.BOUNCE; opts.label = {{ text: "★", color: "white", fontWeight: "bold" }}; }}
|
|
|
|
|
|
|
| 287 |
var marker = new google.maps.Marker(opts);
|
| 288 |
marker.addListener("click", function() {{ infowindow.setContent(m.info); infowindow.open(map, marker); }});
|
| 289 |
if (!bounds) bounds = new google.maps.LatLngBounds(pos, pos);
|
|
|
|
| 367 |
if route:
|
| 368 |
folium.PolyLine(route, color="teal", weight=4, opacity=0.8).add_to(m)
|
| 369 |
|
| 370 |
+
for i, ((lat, lon), f) in enumerate(facility_coords):
|
| 371 |
is_selected = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
|
| 372 |
+
name = f.get("facility_name") or f.get("name") or "Facility"
|
| 373 |
+
tooltip = f"{i + 1}. {name}"
|
| 374 |
+
color = "#c62828" if is_selected else "#319795"
|
| 375 |
+
fill_color = "#e53935" if is_selected else "#26a69a"
|
| 376 |
+
folium.CircleMarker(
|
| 377 |
+
location=[lat, lon],
|
| 378 |
+
radius=12,
|
| 379 |
popup=folium.Popup(_popup_html(f), max_width=280),
|
| 380 |
+
tooltip=tooltip,
|
| 381 |
+
color=color,
|
| 382 |
+
fill=True,
|
| 383 |
+
fill_color=fill_color,
|
| 384 |
+
fill_opacity=0.9,
|
| 385 |
+
weight=2,
|
| 386 |
).add_to(m)
|
| 387 |
|
| 388 |
# Center and zoom to show all proposed locations
|
|
|
|
| 414 |
f'Map could not be loaded. ({str(e)[:80]})</div>'
|
| 415 |
)
|
| 416 |
|
| 417 |
+
DISCLAIMER_HTML = (
|
| 418 |
+
'<div class="disclaimer">'
|
| 419 |
+
'<p class="disclaimer-title">Disclaimer</p>'
|
| 420 |
+
'<p class="disclaimer-text">Information is from SAMHSA data. Always verify with the facility or '
|
| 421 |
+
'<a href="https://findtreatment.gov" target="_blank" rel="noopener">findtreatment.gov</a> '
|
| 422 |
+
'before making decisions. This tool does not endorse any facility.</p>'
|
| 423 |
+
'</div>'
|
| 424 |
)
|
| 425 |
|
| 426 |
DESCRIPTION = (
|
|
|
|
| 435 |
]
|
| 436 |
|
| 437 |
CSS = """
|
| 438 |
+
.disclaimer {
|
| 439 |
+
font-size: 0.9rem;
|
| 440 |
+
color: #2d3748;
|
| 441 |
+
padding: 0.875rem 1rem;
|
| 442 |
+
background: #e2e8f0;
|
| 443 |
+
border-radius: 10px;
|
| 444 |
+
margin-bottom: 0.75rem;
|
| 445 |
+
border-left: 4px solid #319795;
|
| 446 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.08);
|
| 447 |
+
}
|
| 448 |
+
.disclaimer-title {
|
| 449 |
+
font-weight: 600;
|
| 450 |
+
color: #1a202c;
|
| 451 |
+
margin: 0 0 0.35em 0;
|
| 452 |
+
font-size: 0.95em;
|
| 453 |
+
}
|
| 454 |
+
.disclaimer-text {
|
| 455 |
+
margin: 0;
|
| 456 |
+
line-height: 1.5;
|
| 457 |
+
color: #2d3748;
|
| 458 |
+
}
|
| 459 |
+
.disclaimer a {
|
| 460 |
+
color: #2c7a7b;
|
| 461 |
+
text-decoration: none;
|
| 462 |
+
font-weight: 600;
|
| 463 |
+
}
|
| 464 |
+
.disclaimer a:hover {
|
| 465 |
+
text-decoration: underline;
|
| 466 |
+
}
|
| 467 |
.map-pane { padding: 0.25rem 0 0 0; }
|
| 468 |
.map-pane .map-html { border-radius: 12px; overflow: hidden; box-shadow: 0 2px 12px rgba(0,0,0,0.08); }
|
| 469 |
.map-pane iframe { border-radius: 12px; }
|
|
|
|
| 510 |
with gr.Blocks(title="SAMHSA Treatment Locator") as demo:
|
| 511 |
gr.Markdown("# SAMHSA Treatment Locator")
|
| 512 |
gr.Markdown(DESCRIPTION)
|
| 513 |
+
gr.HTML(DISCLAIMER_HTML)
|
| 514 |
|
| 515 |
state = gr.State(DEFAULT_STATE)
|
| 516 |
|
|
|
|
| 534 |
if "type" in __import__("inspect").signature(gr.Chatbot).parameters:
|
| 535 |
_chat_kw["type"] = "messages"
|
| 536 |
chat = gr.Chatbot(**_chat_kw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
with gr.Row():
|
| 538 |
msg = gr.Textbox(
|
| 539 |
placeholder="Type a message…",
|
|
|
|
| 550 |
examples_per_page=6,
|
| 551 |
)
|
| 552 |
|
|
|
|
|
|
|
|
|
|
| 553 |
def user_submit(message, history, state):
|
| 554 |
update_id = str(time.time())
|
| 555 |
if not message or not message.strip():
|
| 556 |
facilities = list(state.get("last_results") or [])
|
| 557 |
sel = state.get("selected_facility_name")
|
| 558 |
map_html_out = _build_map_html(facilities, None, update_id, sel)
|
| 559 |
+
return history, state, "", map_html_out
|
| 560 |
try:
|
| 561 |
history_tuples = _messages_to_tuples(history)
|
| 562 |
reply, new_state = chatbot.get_response(message, history_tuples, state)
|
|
|
|
| 566 |
facilities = list(new_state.get("last_results") or [])
|
| 567 |
sel = new_state.get("selected_facility_name")
|
| 568 |
map_html_out = _build_map_html(facilities, None, update_id, sel)
|
| 569 |
+
return new_history_messages, new_state, "", map_html_out
|
|
|
|
| 570 |
except Exception as e:
|
| 571 |
err_msg = str(e)[:200]
|
| 572 |
reply = f"Sorry, something went wrong: {err_msg}"
|
|
|
|
| 578 |
facilities = list(state.get("last_results") or [])
|
| 579 |
sel = state.get("selected_facility_name")
|
| 580 |
map_html_out = _build_map_html(facilities, None, update_id, sel)
|
| 581 |
+
return new_history_messages, state, "", map_html_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
|
| 583 |
submit_btn.click(
|
| 584 |
user_submit,
|
| 585 |
inputs=[msg, chat, state],
|
| 586 |
+
outputs=[chat, state, msg, map_html],
|
| 587 |
)
|
| 588 |
msg.submit(
|
| 589 |
user_submit,
|
| 590 |
inputs=[msg, chat, state],
|
| 591 |
+
outputs=[chat, state, msg, map_html],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
)
|
| 593 |
|
| 594 |
return demo
|
src/chat.py
CHANGED
|
@@ -281,12 +281,23 @@ class Chatbot:
|
|
| 281 |
max_tokens=800,
|
| 282 |
temperature=0.5,
|
| 283 |
)
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
new_state = {
|
| 287 |
-
"criteria": criteria,
|
| 288 |
-
"last_results":
|
| 289 |
-
"last_facility_detail":
|
| 290 |
"selected_facility_name": selected_facility_name,
|
| 291 |
}
|
| 292 |
return reply, new_state
|
|
|
|
| 281 |
max_tokens=800,
|
| 282 |
temperature=0.5,
|
| 283 |
)
|
| 284 |
+
raw = response.choices[0].message.content
|
| 285 |
+
if isinstance(raw, list):
|
| 286 |
+
reply = "".join(
|
| 287 |
+
(b.get("text", "") if isinstance(b, dict) else str(b))
|
| 288 |
+
for b in raw
|
| 289 |
+
).strip()
|
| 290 |
+
else:
|
| 291 |
+
reply = (raw or "").strip()
|
| 292 |
+
|
| 293 |
+
# Return a copy of last_results so Gradio state updates reliably (map re-renders)
|
| 294 |
+
results_for_state = list(last_results) if last_results else []
|
| 295 |
+
detail_for_state = dict(last_facility_detail) if isinstance(last_facility_detail, dict) else last_facility_detail
|
| 296 |
|
| 297 |
new_state = {
|
| 298 |
+
"criteria": dict(criteria),
|
| 299 |
+
"last_results": results_for_state,
|
| 300 |
+
"last_facility_detail": detail_for_state,
|
| 301 |
"selected_facility_name": selected_facility_name,
|
| 302 |
}
|
| 303 |
return reply, new_state
|