phanny commited on
Commit
9b7170c
·
1 Parent(s): 727ab75
Files changed (2) hide show
  1. app.py +97 -50
  2. src/chat.py +15 -4
app.py CHANGED
@@ -46,6 +46,9 @@ CITY_COORDS = {
46
  "roxbury": (42.33, -71.08),
47
  "allston": (42.35, -71.13),
48
  "amesbury": (42.86, -70.93),
 
 
 
49
  "austin": (30.27, -97.74),
50
  "san antonio": (29.42, -98.49),
51
  "san francisco": (37.77, -122.42),
@@ -53,6 +56,38 @@ CITY_COORDS = {
53
  "lake view terrace": (34.27, -118.37),
54
  "chicago": (41.88, -87.63),
55
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Geocode cache so any city/state from search results can be shown on the map.
57
  _GEOCODE_CACHE = {}
58
  # Show all proposed facilities (chat returns up to 5; allow more for geocoding).
@@ -202,7 +237,8 @@ def _get_facility_coords(facilities):
202
  result = []
203
  geocode_count = [0]
204
  for f in facilities:
205
- if not isinstance(f, dict):
 
206
  continue
207
  coord = _facility_coord(f, geocode_count)
208
  if coord:
@@ -224,11 +260,11 @@ def _build_google_map_html(facilities, force_update_id=None, selected_facility_n
224
  center_lon = sum(lons) / len(lons)
225
  zoom = 10
226
  markers_data = []
227
- for lat, lon, f in facility_coords:
228
  name = (f.get("facility_name") or f.get("name") or "Facility").replace("<", "&lt;").replace(">", "&gt;")
229
  info = _popup_html(f)
230
  sel = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
231
- markers_data.append({"lat": lat, "lng": lon, "name": name, "info": info, "selected": sel})
232
  # Base64-encode markers so srcdoc HTML escaping cannot break the JSON
233
  markers_json = json.dumps(markers_data)
234
  markers_b64 = base64.b64encode(markers_json.encode("utf-8")).decode("ascii")
@@ -246,10 +282,8 @@ function init() {{
246
  if (markersData && markersData.length) {{
247
  markersData.forEach(function(m) {{
248
  var pos = {{ lat: m.lat, lng: m.lng }};
249
- var opts = {{ position: pos, map: map, title: m.name }};
250
- if (m.selected) {{
251
- opts.icon = {{ path: google.maps.SymbolPath.CIRCLE, scale: 14, fillColor: "#c62828", fillOpacity: 1, strokeColor: "#fff", strokeWeight: 3 }};
252
- }}
253
  var marker = new google.maps.Marker(opts);
254
  marker.addListener("click", function() {{ infowindow.setContent(m.info); infowindow.open(map, marker); }});
255
  if (!bounds) bounds = new google.maps.LatLngBounds(pos, pos);
@@ -333,13 +367,22 @@ def _build_folium_map_html(facilities, user_location_str=None, force_update_id=N
333
  if route:
334
  folium.PolyLine(route, color="teal", weight=4, opacity=0.8).add_to(m)
335
 
336
- for (lat, lon), f in facility_coords:
337
  is_selected = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
338
- folium.Marker(
339
- [lat, lon],
 
 
 
 
 
340
  popup=folium.Popup(_popup_html(f), max_width=280),
341
- tooltip=f.get("facility_name") or f.get("name") or "Facility",
342
- icon=folium.Icon(color="red" if is_selected else "green", icon="star" if is_selected else "plus-sign"),
 
 
 
 
343
  ).add_to(m)
344
 
345
  # Center and zoom to show all proposed locations
@@ -371,9 +414,13 @@ def _build_folium_map_html(facilities, user_location_str=None, force_update_id=N
371
  f'Map could not be loaded. ({str(e)[:80]})</div>'
372
  )
373
 
374
- DISCLAIMER = (
375
- "**Disclaimer:** Information is from SAMHSA data. Always verify with the facility or "
376
- "[findtreatment.gov](https://findtreatment.gov) before making decisions. This tool does not endorse any facility."
 
 
 
 
377
  )
378
 
379
  DESCRIPTION = (
@@ -388,7 +435,35 @@ EXAMPLES = [
388
  ]
389
 
390
  CSS = """
391
- .disclaimer { font-size: 0.85em; color: #555; padding: 0.5rem 0.75rem; background: #f8f9fa; border-radius: 8px; margin-bottom: 0.75rem; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  .map-pane { padding: 0.25rem 0 0 0; }
393
  .map-pane .map-html { border-radius: 12px; overflow: hidden; box-shadow: 0 2px 12px rgba(0,0,0,0.08); }
394
  .map-pane iframe { border-radius: 12px; }
@@ -435,7 +510,7 @@ def create_demo():
435
  with gr.Blocks(title="SAMHSA Treatment Locator") as demo:
436
  gr.Markdown("# SAMHSA Treatment Locator")
437
  gr.Markdown(DESCRIPTION)
438
- gr.Markdown(f"<div class='disclaimer'>{DISCLAIMER}</div>", elem_classes=["disclaimer"])
439
 
440
  state = gr.State(DEFAULT_STATE)
441
 
@@ -459,12 +534,6 @@ def create_demo():
459
  if "type" in __import__("inspect").signature(gr.Chatbot).parameters:
460
  _chat_kw["type"] = "messages"
461
  chat = gr.Chatbot(**_chat_kw)
462
- facility_dropdown = gr.Dropdown(
463
- choices=[],
464
- value=None,
465
- label="Choose a treatment center (pin updates on map)",
466
- allow_custom_value=False,
467
- )
468
  with gr.Row():
469
  msg = gr.Textbox(
470
  placeholder="Type a message…",
@@ -481,16 +550,13 @@ def create_demo():
481
  examples_per_page=6,
482
  )
483
 
484
- def _facility_names(facilities):
485
- return [f.get("facility_name") or f.get("name") or "Facility" for f in facilities]
486
-
487
  def user_submit(message, history, state):
488
  update_id = str(time.time())
489
  if not message or not message.strip():
490
  facilities = list(state.get("last_results") or [])
491
  sel = state.get("selected_facility_name")
492
  map_html_out = _build_map_html(facilities, None, update_id, sel)
493
- return history, state, "", map_html_out, gr.update(choices=_facility_names(facilities))
494
  try:
495
  history_tuples = _messages_to_tuples(history)
496
  reply, new_state = chatbot.get_response(message, history_tuples, state)
@@ -500,8 +566,7 @@ def create_demo():
500
  facilities = list(new_state.get("last_results") or [])
501
  sel = new_state.get("selected_facility_name")
502
  map_html_out = _build_map_html(facilities, None, update_id, sel)
503
- dropdown_value = sel if sel and any(f.get("facility_name") == sel or f.get("name") == sel for f in facilities) else None
504
- return new_history_messages, new_state, "", map_html_out, gr.update(choices=_facility_names(facilities), value=dropdown_value)
505
  except Exception as e:
506
  err_msg = str(e)[:200]
507
  reply = f"Sorry, something went wrong: {err_msg}"
@@ -513,35 +578,17 @@ def create_demo():
513
  facilities = list(state.get("last_results") or [])
514
  sel = state.get("selected_facility_name")
515
  map_html_out = _build_map_html(facilities, None, update_id, sel)
516
- return new_history_messages, state, "", map_html_out, gr.update()
517
-
518
- def on_facility_select(choice, state):
519
- if not choice:
520
- state = dict(state or {})
521
- state["selected_facility_name"] = None
522
- facilities = list(state.get("last_results") or [])
523
- map_html_out = _build_map_html(facilities, None, str(time.time()), None)
524
- return map_html_out, state
525
- state = dict(state or {})
526
- state["selected_facility_name"] = choice
527
- facilities = list(state.get("last_results") or [])
528
- map_html_out = _build_map_html(facilities, None, str(time.time()), choice)
529
- return map_html_out, state
530
 
531
  submit_btn.click(
532
  user_submit,
533
  inputs=[msg, chat, state],
534
- outputs=[chat, state, msg, map_html, facility_dropdown],
535
  )
536
  msg.submit(
537
  user_submit,
538
  inputs=[msg, chat, state],
539
- outputs=[chat, state, msg, map_html, facility_dropdown],
540
- )
541
- facility_dropdown.change(
542
- on_facility_select,
543
- inputs=[facility_dropdown, state],
544
- outputs=[map_html, state],
545
  )
546
 
547
  return demo
 
46
  "roxbury": (42.33, -71.08),
47
  "allston": (42.35, -71.13),
48
  "amesbury": (42.86, -70.93),
49
+ "athol": (42.59, -72.23),
50
+ "abilene": (32.45, -99.73),
51
+ "addison": (32.96, -96.83),
52
  "austin": (30.27, -97.74),
53
  "san antonio": (29.42, -98.49),
54
  "san francisco": (37.77, -122.42),
 
56
  "lake view terrace": (34.27, -118.37),
57
  "chicago": (41.88, -87.63),
58
  }
59
+
60
+
61
+ def _normalize_facility(f):
62
+ """Ensure facility dict from state or search has string values and expected keys. Accepts any key casing."""
63
+ if not f or not isinstance(f, dict):
64
+ return None
65
+ # Case-insensitive key lookup (state/JSON may use different casing)
66
+ key_map = {k.lower(): k for k in f.keys() if isinstance(k, str)}
67
+ def get_any(*names):
68
+ for n in names:
69
+ k = n.lower()
70
+ if k in key_map:
71
+ val = f.get(key_map[k])
72
+ if val is not None and str(val).strip() and str(val).lower() != "nan":
73
+ return str(val).strip()
74
+ return None
75
+ city = get_any("city", "City")
76
+ state = get_any("state", "State")
77
+ if not city and not state:
78
+ return None
79
+ out = {"city": city or "", "state": state or ""}
80
+ for name, out_key in (
81
+ ("facility_name", "facility_name"), ("name", "facility_name"),
82
+ ("address", "address"), ("phone", "phone"),
83
+ ("treatment_type", "treatment_type"), ("services", "services"),
84
+ ):
85
+ val = get_any(name)
86
+ if val:
87
+ out[out_key] = val
88
+ if "facility_name" not in out:
89
+ out["facility_name"] = get_any("facility_name", "name") or "Facility"
90
+ return out
91
  # Geocode cache so any city/state from search results can be shown on the map.
92
  _GEOCODE_CACHE = {}
93
  # Show all proposed facilities (chat returns up to 5; allow more for geocoding).
 
237
  result = []
238
  geocode_count = [0]
239
  for f in facilities:
240
+ f = _normalize_facility(f)
241
+ if not f:
242
  continue
243
  coord = _facility_coord(f, geocode_count)
244
  if coord:
 
260
  center_lon = sum(lons) / len(lons)
261
  zoom = 10
262
  markers_data = []
263
+ for i, (lat, lon, f) in enumerate(facility_coords):
264
  name = (f.get("facility_name") or f.get("name") or "Facility").replace("<", "&lt;").replace(">", "&gt;")
265
  info = _popup_html(f)
266
  sel = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
267
+ markers_data.append({"lat": lat, "lng": lon, "name": name, "info": info, "selected": sel, "label": str(i + 1)})
268
  # Base64-encode markers so srcdoc HTML escaping cannot break the JSON
269
  markers_json = json.dumps(markers_data)
270
  markers_b64 = base64.b64encode(markers_json.encode("utf-8")).decode("ascii")
 
282
  if (markersData && markersData.length) {{
283
  markersData.forEach(function(m) {{
284
  var pos = {{ lat: m.lat, lng: m.lng }};
285
+ var opts = {{ position: pos, map: map, title: m.name, label: {{ text: m.label || "", color: "white", fontWeight: "bold" }} }};
286
+ if (m.selected) {{ opts.animation = google.maps.Animation.BOUNCE; opts.label = {{ text: "★", color: "white", fontWeight: "bold" }}; }}
 
 
287
  var marker = new google.maps.Marker(opts);
288
  marker.addListener("click", function() {{ infowindow.setContent(m.info); infowindow.open(map, marker); }});
289
  if (!bounds) bounds = new google.maps.LatLngBounds(pos, pos);
 
367
  if route:
368
  folium.PolyLine(route, color="teal", weight=4, opacity=0.8).add_to(m)
369
 
370
+ for i, ((lat, lon), f) in enumerate(facility_coords):
371
  is_selected = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
372
+ name = f.get("facility_name") or f.get("name") or "Facility"
373
+ tooltip = f"{i + 1}. {name}"
374
+ color = "#c62828" if is_selected else "#319795"
375
+ fill_color = "#e53935" if is_selected else "#26a69a"
376
+ folium.CircleMarker(
377
+ location=[lat, lon],
378
+ radius=12,
379
  popup=folium.Popup(_popup_html(f), max_width=280),
380
+ tooltip=tooltip,
381
+ color=color,
382
+ fill=True,
383
+ fill_color=fill_color,
384
+ fill_opacity=0.9,
385
+ weight=2,
386
  ).add_to(m)
387
 
388
  # Center and zoom to show all proposed locations
 
414
  f'Map could not be loaded. ({str(e)[:80]})</div>'
415
  )
416
 
417
+ DISCLAIMER_HTML = (
418
+ '<div class="disclaimer">'
419
+ '<p class="disclaimer-title">Disclaimer</p>'
420
+ '<p class="disclaimer-text">Information is from SAMHSA data. Always verify with the facility or '
421
+ '<a href="https://findtreatment.gov" target="_blank" rel="noopener">findtreatment.gov</a> '
422
+ 'before making decisions. This tool does not endorse any facility.</p>'
423
+ '</div>'
424
  )
425
 
426
  DESCRIPTION = (
 
435
  ]
436
 
437
  CSS = """
438
+ .disclaimer {
439
+ font-size: 0.9rem;
440
+ color: #2d3748;
441
+ padding: 0.875rem 1rem;
442
+ background: #e2e8f0;
443
+ border-radius: 10px;
444
+ margin-bottom: 0.75rem;
445
+ border-left: 4px solid #319795;
446
+ box-shadow: 0 1px 3px rgba(0,0,0,0.08);
447
+ }
448
+ .disclaimer-title {
449
+ font-weight: 600;
450
+ color: #1a202c;
451
+ margin: 0 0 0.35em 0;
452
+ font-size: 0.95em;
453
+ }
454
+ .disclaimer-text {
455
+ margin: 0;
456
+ line-height: 1.5;
457
+ color: #2d3748;
458
+ }
459
+ .disclaimer a {
460
+ color: #2c7a7b;
461
+ text-decoration: none;
462
+ font-weight: 600;
463
+ }
464
+ .disclaimer a:hover {
465
+ text-decoration: underline;
466
+ }
467
  .map-pane { padding: 0.25rem 0 0 0; }
468
  .map-pane .map-html { border-radius: 12px; overflow: hidden; box-shadow: 0 2px 12px rgba(0,0,0,0.08); }
469
  .map-pane iframe { border-radius: 12px; }
 
510
  with gr.Blocks(title="SAMHSA Treatment Locator") as demo:
511
  gr.Markdown("# SAMHSA Treatment Locator")
512
  gr.Markdown(DESCRIPTION)
513
+ gr.HTML(DISCLAIMER_HTML)
514
 
515
  state = gr.State(DEFAULT_STATE)
516
 
 
534
  if "type" in __import__("inspect").signature(gr.Chatbot).parameters:
535
  _chat_kw["type"] = "messages"
536
  chat = gr.Chatbot(**_chat_kw)
 
 
 
 
 
 
537
  with gr.Row():
538
  msg = gr.Textbox(
539
  placeholder="Type a message…",
 
550
  examples_per_page=6,
551
  )
552
 
 
 
 
553
  def user_submit(message, history, state):
554
  update_id = str(time.time())
555
  if not message or not message.strip():
556
  facilities = list(state.get("last_results") or [])
557
  sel = state.get("selected_facility_name")
558
  map_html_out = _build_map_html(facilities, None, update_id, sel)
559
+ return history, state, "", map_html_out
560
  try:
561
  history_tuples = _messages_to_tuples(history)
562
  reply, new_state = chatbot.get_response(message, history_tuples, state)
 
566
  facilities = list(new_state.get("last_results") or [])
567
  sel = new_state.get("selected_facility_name")
568
  map_html_out = _build_map_html(facilities, None, update_id, sel)
569
+ return new_history_messages, new_state, "", map_html_out
 
570
  except Exception as e:
571
  err_msg = str(e)[:200]
572
  reply = f"Sorry, something went wrong: {err_msg}"
 
578
  facilities = list(state.get("last_results") or [])
579
  sel = state.get("selected_facility_name")
580
  map_html_out = _build_map_html(facilities, None, update_id, sel)
581
+ return new_history_messages, state, "", map_html_out
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
  submit_btn.click(
584
  user_submit,
585
  inputs=[msg, chat, state],
586
+ outputs=[chat, state, msg, map_html],
587
  )
588
  msg.submit(
589
  user_submit,
590
  inputs=[msg, chat, state],
591
+ outputs=[chat, state, msg, map_html],
 
 
 
 
 
592
  )
593
 
594
  return demo
src/chat.py CHANGED
@@ -281,12 +281,23 @@ class Chatbot:
281
  max_tokens=800,
282
  temperature=0.5,
283
  )
284
- reply = (response.choices[0].message.content or "").strip()
 
 
 
 
 
 
 
 
 
 
 
285
 
286
  new_state = {
287
- "criteria": criteria,
288
- "last_results": last_results,
289
- "last_facility_detail": last_facility_detail,
290
  "selected_facility_name": selected_facility_name,
291
  }
292
  return reply, new_state
 
281
  max_tokens=800,
282
  temperature=0.5,
283
  )
284
+ raw = response.choices[0].message.content
285
+ if isinstance(raw, list):
286
+ reply = "".join(
287
+ (b.get("text", "") if isinstance(b, dict) else str(b))
288
+ for b in raw
289
+ ).strip()
290
+ else:
291
+ reply = (raw or "").strip()
292
+
293
+ # Return a copy of last_results so Gradio state updates reliably (map re-renders)
294
+ results_for_state = list(last_results) if last_results else []
295
+ detail_for_state = dict(last_facility_detail) if isinstance(last_facility_detail, dict) else last_facility_detail
296
 
297
  new_state = {
298
+ "criteria": dict(criteria),
299
+ "last_results": results_for_state,
300
+ "last_facility_detail": detail_for_state,
301
  "selected_facility_name": selected_facility_name,
302
  }
303
  return reply, new_state