"""
SAMHSA Treatment Locator – Gradio app for HuggingFace Spaces.

Two-pane layout: map (left) + chat (right). When GOOGLE_MAPS_API_KEY is set in .env,
the map uses Google Maps (JavaScript API); otherwise Folium/OpenStreetMap.
Facility markers and search results are shown on the map.
"""

import base64
import html as html_module
import inspect
import json
import os
import time

import folium
import gradio as gr
import requests

# Load .env from project root so GOOGLE_MAPS_API_KEY is set regardless of launch cwd
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
_ENV_PATH = os.path.join(_APP_DIR, ".env")
try:
    from dotenv import load_dotenv

    load_dotenv(_ENV_PATH)
except ImportError:
    pass

# Imported after .env loading so the chat module sees any variables defined there.
from src.chat import DEFAULT_STATE, Chatbot  # noqa: E402

# Use Google Maps when key is set; otherwise Folium/OSM. Enable "Maps JavaScript API" in Google Cloud.
GOOGLE_MAPS_API_KEY = (os.environ.get("GOOGLE_MAPS_API_KEY") or "").strip() or None
if not GOOGLE_MAPS_API_KEY and os.path.isfile(_ENV_PATH):
    # Minimal KEY=value fallback for when python-dotenv is missing or did not set the key.
    with open(_ENV_PATH, encoding="utf-8") as _env_file:
        for _env_line in _env_file:
            _env_line = _env_line.strip()
            if _env_line.startswith("export "):
                _env_line = _env_line[len("export "):].lstrip()
            if _env_line.startswith("GOOGLE_MAPS_API_KEY="):
                _env_key = _env_line.split("=", 1)[1].strip().strip('"').strip("'")
                if _env_key:
                    GOOGLE_MAPS_API_KEY = _env_key
                    break

# City/place -> (lat, lon) for map pins (fast lookup). Add more as needed.
CITY_COORDS = {
    "boston": (42.36, -71.06),
    "belmont": (42.40, -71.18),
    "roxbury": (42.33, -71.08),
    "allston": (42.35, -71.13),
    "amesbury": (42.86, -70.93),
    "athol": (42.59, -72.23),
    "abilene": (32.45, -99.73),
    "addison": (32.96, -96.83),
    "austin": (30.27, -97.74),
    "san antonio": (29.42, -98.49),
    "san francisco": (37.77, -122.42),
    "los angeles": (34.05, -118.25),
    "lake view terrace": (34.27, -118.37),
    "chicago": (41.88, -87.63),
}


def _normalize_facility(f):
    """Return a facility dict with stripped string values under the expected lower-case keys.

    Accepts any key casing. Returns None when f is not a dict or has neither a city nor
    a state (such a facility cannot be placed on the map). Values that are None, blank,
    or the string "nan" are treated as missing.
    """
    if not f or not isinstance(f, dict):
        return None
    # Case-insensitive key lookup (state/JSON may use different casing)
    key_map = {k.lower(): k for k in f.keys() if isinstance(k, str)}

    def get_any(*names):
        """First non-empty, non-"nan" value among the given key names (case-insensitive)."""
        for n in names:
            k = n.lower()
            if k in key_map:
                val = f.get(key_map[k])
                if val is not None and str(val).strip() and str(val).lower() != "nan":
                    return str(val).strip()
        return None

    city = get_any("city", "City")
    state = get_any("state", "State")
    if not city and not state:
        return None
    out = {"city": city or "", "state": state or ""}
    for name, out_key in (
        ("facility_name", "facility_name"),
        ("name", "facility_name"),
        ("address", "address"),
        ("phone", "phone"),
        ("treatment_type", "treatment_type"),
        ("services", "services"),
    ):
        if out_key in out:
            # Keep the first alias found: an explicit facility_name wins over a generic name.
            continue
        val = get_any(name)
        if val:
            out[out_key] = val
    if "facility_name" not in out:
        out["facility_name"] = "Facility"
    return out


# Geocode cache so any city/state from search results can be shown on the map.
# Maps lower-cased query -> (lat, lon), or None for a definite "no match".
_GEOCODE_CACHE = {}
# Show all proposed facilities (chat returns up to 5; allow more for geocoding).
_MAX_GEOCODE_PER_MAP = 20
MAP_HEIGHT_PX = 420
# Lazily created rate-limited geocode callable, shared so the 1 req/s delay spans calls.
_GEOCODER = None


def _get_geocoder():
    """Return the shared RateLimiter-wrapped Nominatim.geocode, creating it on first use.

    Raises ImportError if geopy is not installed.
    """
    global _GEOCODER
    if _GEOCODER is None:
        from geopy.extra.rate_limiter import RateLimiter
        from geopy.geocoders import Nominatim

        geocoder = Nominatim(user_agent="samhsa-treatment-locator")
        # swallow_exceptions=False: service errors raise, so they are not mistaken
        # for a definite "no match" and cached.
        _GEOCODER = RateLimiter(geocoder.geocode, min_delay_seconds=1.0, swallow_exceptions=False)
    return _GEOCODER


def _geocode(location_str):
    """Resolve address/city to (lat, lon) using Nominatim. Returns None on failure.

    Uses cache. Successful lookups and definite "no match" answers are cached; errors
    (network, service, geopy missing) are not, so a later call can retry. When the lookup
    fails or finds nothing, falls back to CITY_COORDS for the whole string or for the part
    before the first comma (e.g. "Boston, MA" -> "boston").
    """
    if not location_str or not location_str.strip():
        return None
    query = location_str.strip()
    key = query.lower()
    if key in _GEOCODE_CACHE:
        return _GEOCODE_CACHE[key]
    fallback = CITY_COORDS.get(key) or CITY_COORDS.get(key.split(",", 1)[0].strip())
    try:
        result = _get_geocoder()(query, country_codes="us")
    except Exception:
        return fallback
    coord = (result.latitude, result.longitude) if result else fallback
    _GEOCODE_CACHE[key] = coord
    return coord


def _decode_google_polyline(encoded: str):
    """Decode Google's encoded polyline string to list of (lat, lon)."""
    # https://developers.google.com/maps/documentation/utilities/polylinealgorithm
    coords = []
    index = 0
    lat = lon = 0

    def next_delta():
        """Read one zig-zag encoded, 5-bit-chunked signed value starting at index."""
        nonlocal index
        result = shift = 0
        while True:
            b = ord(encoded[index]) - 63
            index += 1
            result |= (b & 0x1F) << shift
            shift += 5
            if b < 0x20:
                break
        return ~(result >> 1) if result & 1 else result >> 1

    while index < len(encoded):
        lat += next_delta()
        lon += next_delta()
        coords.append((lat * 1e-5, lon * 1e-5))
    return coords


def _get_route(lat1, lon1, lat2, lon2):
    """Return list of (lat, lon) for driving route, or None.

    Uses Google Directions if GOOGLE_MAPS_API_KEY set; any Google failure (HTTP error,
    non-OK status, no polyline, exception) falls back to OSRM.
    """
    if GOOGLE_MAPS_API_KEY:
        try:
            r = requests.get(
                "https://maps.googleapis.com/maps/api/directions/json",
                params={
                    "origin": f"{lat1},{lon1}",
                    "destination": f"{lat2},{lon2}",
                    "key": GOOGLE_MAPS_API_KEY,
                },
                timeout=10,
            )
            if r.status_code == 200:
                data = r.json()
                if data.get("status") == "OK" and data.get("routes"):
                    points = data["routes"][0].get("overview_polyline", {}).get("points")
                    if points:
                        return _decode_google_polyline(points)
        except Exception:
            pass  # best effort: fall through to OSRM
    return _get_route_osrm(lat1, lon1, lat2, lon2)


def _get_route_osrm(lat1, lon1, lat2, lon2):
    """Return list of (lat, lon) for driving route from OSRM (free), or None."""
    try:
        url = (
            "https://router.project-osrm.org/route/v1/driving/"
            f"{lon1},{lat1};{lon2},{lat2}?overview=full&geometries=geojson"
        )
        r = requests.get(url, timeout=5)
        if r.status_code != 200:
            return None
        data = r.json()
        if not data.get("routes"):
            return None
        coords = data["routes"][0]["geometry"]["coordinates"]
        # GeoJSON coordinates are [lon, lat]; the map code wants (lat, lon).
        return [(c[1], c[0]) for c in coords]
    except Exception:
        return None


def _facility_coord(f, geocode_count=None):
    """(lat, lon) for a facility: CITY_COORDS first, then geocode city+state (cached).

    f uses CSV keys. geocode_count, when given, is a one-element list used as a mutable
    counter of uncached geocode requests; once it reaches _MAX_GEOCODE_PER_MAP, further
    uncached lookups are skipped (None) to bound the time spent rendering one map.
    """
    city = (f.get("city") or "").strip()
    state = (f.get("state") or "").strip()
    if not city and not state:
        return None
    city_lower = city.lower()
    state_lower = state.lower()
    coord = CITY_COORDS.get(city_lower) or (CITY_COORDS.get(state_lower) if state_lower else None)
    if coord:
        return coord
    location_str = f"{city}, {state}".strip(", ")
    if not location_str:
        return None
    if geocode_count is not None and len(geocode_count) == 1 and location_str.lower() not in _GEOCODE_CACHE:
        if geocode_count[0] >= _MAX_GEOCODE_PER_MAP:
            return None
        geocode_count[0] += 1
    return _geocode(location_str)


def _facility_display_name(f):
    """Human-readable facility name (unescaped), defaulting to "Facility"."""
    return f.get("facility_name") or f.get("name") or "Facility"


def _is_selected(f, selected_facility_name):
    """True when f is the facility the chat marked as selected (matched by name)."""
    return bool(selected_facility_name) and (
        (f.get("facility_name") or f.get("name") or "") == selected_facility_name
    )


def _popup_html(f):
    """Short HTML for Folium popup / Google info window.

    f uses CSV keys (facility_name, address, etc.). Every field is HTML-escaped because
    the values come from data/chat state and are inserted into markup.
    """

    def field(*keys):
        """Escaped value of the first truthy key, or ""."""
        for k in keys:
            val = f.get(k)
            if val:
                return html_module.escape(str(val))
        return ""

    name = field("facility_name", "name") or "Facility"
    addr = field("address")
    city = field("city")
    st = field("state")
    phone = field("phone")
    t = field("treatment_type")
    parts = [f"<b>{name}</b>", f"{addr}, {city} {st}".strip(", ")]
    if phone:
        parts.append(f"📞 {phone}")
    if t:
        parts.append(f"Type: {t}")
    return "<br>".join(parts)


def _get_facility_coords(facilities):
    """Return list of (lat, lon, facility_dict) for facilities that have coordinates.

    Facilities are normalized first; ones without city/state or without a resolvable
    location are skipped. Uncached geocoding is capped at _MAX_GEOCODE_PER_MAP per call.
    """
    result = []
    geocode_count = [0]
    for raw in facilities:
        f = _normalize_facility(raw)
        if not f:
            continue
        coord = _facility_coord(f, geocode_count)
        if coord:
            result.append((coord[0], coord[1], f))
    return result


def _map_error_html(exc):
    """Fallback HTML shown in place of the map; the exception text is escaped and truncated."""
    return (
        f'<div class="map-error" style="height:{MAP_HEIGHT_PX}px;display:flex;align-items:center;'
        f'justify-content:center;background:#f7fafc;color:#4a5568;border-radius:12px;">'
        f"Map could not be loaded. ({html_module.escape(str(exc)[:80])})</div>"
    )


def _padded_bounds(points, buf=0.01):
    """[[south, west], [north, east]] enclosing points, widened to at least buf degrees per axis.

    Avoid zero-size bounds (single marker or identical points): add a small buffer.
    """
    lats = [p[0] for p in points]
    lons = [p[1] for p in points]
    sw = [min(lats), min(lons)]
    ne = [max(lats), max(lons)]
    if ne[0] - sw[0] < buf:
        sw[0] -= buf
        ne[0] += buf
    if ne[1] - sw[1] < buf:
        sw[1] -= buf
        ne[1] += buf
    return [sw, ne]


def _build_google_map_html(facilities, force_update_id=None, selected_facility_name=None):
    """Build Google Maps in an iframe via srcdoc so scripts run (interactive map).

    Requires GOOGLE_MAPS_API_KEY; without it, delegates to the Folium builder.
    Returns safe fallback HTML on any error.
    """
    if not GOOGLE_MAPS_API_KEY:
        return _build_folium_map_html(facilities, None, force_update_id, selected_facility_name)
    try:
        facility_coords = _get_facility_coords(facilities)
        center_lat, center_lon, zoom = 39.5, -98.5, 3
        if facility_coords:
            center_lat = sum(c[0] for c in facility_coords) / len(facility_coords)
            center_lon = sum(c[1] for c in facility_coords) / len(facility_coords)
            zoom = 10
        markers_data = [
            {
                "lat": lat,
                "lng": lon,
                # Plain text: only used as the marker's title (tooltip), never as HTML.
                "name": _facility_display_name(f),
                # Already HTML-escaped field by field in _popup_html.
                "info": _popup_html(f),
                "selected": _is_selected(f, selected_facility_name),
                "label": str(i),
            }
            for i, (lat, lon, f) in enumerate(facility_coords, start=1)
        ]
        # Base64-encode markers so srcdoc HTML escaping cannot break the JSON
        markers_json = json.dumps(markers_data)
        markers_b64 = base64.b64encode(markers_json.encode("utf-8")).decode("ascii")
        script_url = f"https://maps.googleapis.com/maps/api/js?key={GOOGLE_MAPS_API_KEY}&callback=init"
        doc = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
  html, body {{ height: 100%; margin: 0; padding: 0; }}
  #map {{ width: 100%; height: 100%; }}
</style>
</head>
<body>
<div id="map"></div>
<script>
  var MARKERS_B64 = "{markers_b64}";
  function decodeMarkers() {{
    try {{
      var bin = atob(MARKERS_B64);
      var bytes = new Uint8Array(bin.length);
      for (var i = 0; i < bin.length; i++) {{ bytes[i] = bin.charCodeAt(i); }}
      return JSON.parse(new TextDecoder("utf-8").decode(bytes));
    }} catch (e) {{
      return [];
    }}
  }}
  function init() {{
    var markers = decodeMarkers();
    var map = new google.maps.Map(document.getElementById("map"), {{
      center: {{ lat: {center_lat}, lng: {center_lon} }},
      zoom: {zoom},
      gestureHandling: "greedy"
    }});
    var info = new google.maps.InfoWindow();
    var bounds = new google.maps.LatLngBounds();
    markers.forEach(function (m) {{
      var marker = new google.maps.Marker({{
        position: {{ lat: m.lat, lng: m.lng }},
        map: map,
        title: m.name,
        label: {{ text: m.label, color: "#ffffff" }},
        icon: {{
          path: google.maps.SymbolPath.CIRCLE,
          scale: 12,
          fillColor: m.selected ? "#e53935" : "#26a69a",
          fillOpacity: 0.9,
          strokeColor: m.selected ? "#c62828" : "#319795",
          strokeWeight: 2
        }}
      }});
      marker.addListener("click", function () {{
        info.setContent(m.info);
        info.open({{ anchor: marker, map: map }});
      }});
      bounds.extend(marker.getPosition());
    }});
    if (markers.length > 1) {{
      map.fitBounds(bounds);
    }}
  }}
</script>
<script src="{script_url}" async defer></script>
</body>
</html>"""
        # Embed in iframe via srcdoc so the document runs in its own context and scripts execute
        escaped = html_module.escape(doc, quote=True)
        iframe = (
            f'<iframe srcdoc="{escaped}" '
            f'style="width:100%;height:{MAP_HEIGHT_PX}px;border:0;border-radius:12px;" '
            f"allowfullscreen></iframe>"
        )
        if force_update_id is not None:
            # Force Gradio to re-render: a unique comment makes the value differ each time.
            iframe += f"<!-- map-update {force_update_id} -->"
        return iframe
    except Exception as e:
        return _map_error_html(e)


def _build_map_html(facilities, user_location_str=None, force_update_id=None, selected_facility_name=None):
    """Build map HTML: Google Maps (iframe) when key is set, else Folium/OSM.

    Note: the Google Maps variant does not use user_location_str.
    """
    if GOOGLE_MAPS_API_KEY:
        return _build_google_map_html(facilities, force_update_id, selected_facility_name)
    return _build_folium_map_html(facilities, user_location_str, force_update_id, selected_facility_name)


def _build_folium_map_html(facilities, user_location_str=None, force_update_id=None, selected_facility_name=None):
    """
    Build Folium (Leaflet) map as HTML. Scroll over map to zoom, drag to pan.

    If user_location_str is set, mark it and show a route to the first facility; the
    final view fits all facilities plus the user's location so both stay visible.
    selected_facility_name: if set, the matching facility is drawn in red.
    force_update_id: optional unique value (e.g. timestamp) so Gradio re-renders the HTML each time.
    Returns safe fallback HTML on any error.
    """
    try:
        user_lat_lon = _geocode(user_location_str) if user_location_str else None
        facility_coords = _get_facility_coords(facilities)

        center_lat, center_lon, zoom = 39.5, -98.5, 3
        if user_lat_lon:
            center_lat, center_lon = user_lat_lon
            zoom = 11
        elif facility_coords:
            center_lat = sum(c[0] for c in facility_coords) / len(facility_coords)
            center_lon = sum(c[1] for c in facility_coords) / len(facility_coords)
            zoom = 10

        m = folium.Map(
            location=[center_lat, center_lon],
            zoom_start=zoom,
            tiles="OpenStreetMap",
            control_scale=True,
            zoom_control=True,
        )
        m.options["scrollWheelZoom"] = True
        m.options["touchZoom"] = True
        m.options["dragging"] = True

        if user_lat_lon:
            folium.Marker(
                user_lat_lon,
                popup="You",
                tooltip="Your location",
                icon=folium.Icon(color="blue", icon="info-sign"),
            ).add_to(m)
            if facility_coords:
                dest_lat, dest_lon, _ = facility_coords[0]
                route = _get_route(user_lat_lon[0], user_lat_lon[1], dest_lat, dest_lon)
                if route:
                    folium.PolyLine(route, color="teal", weight=4, opacity=0.8).add_to(m)

        for i, (lat, lon, f) in enumerate(facility_coords, start=1):
            is_selected = _is_selected(f, selected_facility_name)
            folium.CircleMarker(
                location=[lat, lon],
                radius=12,
                popup=folium.Popup(_popup_html(f), max_width=280),
                tooltip=f"{i}. {_facility_display_name(f)}",
                color="#c62828" if is_selected else "#319795",
                fill=True,
                fill_color="#e53935" if is_selected else "#26a69a",
                fill_opacity=0.9,
                weight=2,
            ).add_to(m)

        # Center and zoom to show all proposed locations (and the user, if known)
        if facility_coords:
            points = [(lat, lon) for lat, lon, _ in facility_coords]
            if user_lat_lon:
                points.append(tuple(user_lat_lon))
            m.fit_bounds(_padded_bounds(points))

        # A fixed-height Figure makes the rendered iframe MAP_HEIGHT_PX tall (same as the
        # Google Maps iframe) instead of a width-proportional notebook-style box.
        folium.Figure(width="100%", height=MAP_HEIGHT_PX).add_child(m)
        wrapper = f'<div class="folium-map" style="width:100%;">{m._repr_html_()}</div>'
        # Force Gradio to re-render: append a unique comment so the value always changes
        if force_update_id is not None:
            wrapper += f"<!-- map-update {force_update_id} -->"
        return wrapper
    except Exception as e:
        return _map_error_html(e)


DISCLAIMER_HTML = (
    '<div class="disclaimer">'
    '<p class="disclaimer-title">Disclaimer</p>'
    '<p class="disclaimer-text">'
    "Information is from SAMHSA data. Always verify with the facility or "
    '<a href="https://findtreatment.gov" target="_blank" rel="noopener noreferrer">findtreatment.gov</a> '
    "before making decisions. This tool does not endorse any facility."
    "</p>"
    "</div>"
)

DESCRIPTION = (
    "Find treatment facilities by chatting: say where you are (city or state), what type of care you need, "
    "and payment (e.g. Medicaid, insurance). Results show on the map. Data from SAMHSA only."
)

EXAMPLES = [
    "I'm looking for outpatient alcohol treatment in Boston with Medicaid.",
    "Do you have options for veterans in Texas?",
    "Hi, I need help finding a residential program in California that accepts Medicaid.",
]

CSS = """
.disclaimer {
    font-size: 0.9rem;
    color: #2d3748;
    padding: 0.875rem 1rem;
    background: #e2e8f0;
    border-radius: 10px;
    margin-bottom: 0.75rem;
    border-left: 4px solid #319795;
    box-shadow: 0 1px 3px rgba(0,0,0,0.08);
}
.disclaimer-title { font-weight: 600; color: #1a202c; margin: 0 0 0.35em 0; font-size: 0.95em; }
.disclaimer-text { margin: 0; line-height: 1.5; color: #2d3748; }
.disclaimer a { color: #2c7a7b; text-decoration: none; font-weight: 600; }
.disclaimer a:hover { text-decoration: underline; }
.map-pane { padding: 0.25rem 0 0 0; }
.map-pane .map-html { border-radius: 12px; overflow: hidden; box-shadow: 0 2px 12px rgba(0,0,0,0.08); }
.map-pane iframe { border-radius: 12px; }
.try-label { font-size: 0.9em; margin-bottom: 0.25rem; }
"""


def _messages_to_tuples(history):
    """Convert Gradio 6 messages format to [(user, assistant), ...] for get_response.

    Also accepts pairs already in [user, assistant] form. A user message opens a new pair;
    an assistant message fills the last pair (or opens one with an empty user side).
    """
    if not history:
        return []
    out = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) >= 2:
            out.append([item[0], item[1]])
        elif isinstance(item, dict):
            role, content = item.get("role"), item.get("content", "")
            if role == "user":
                out.append([content, ""])
            elif role == "assistant":
                if out:
                    out[-1][1] = content
                else:
                    out.append(["", content])
        else:
            out.append(["", str(item)])
    return out


def _tuples_to_messages(history):
    """Convert [(user, assistant), ...] to Gradio 6 messages format (empty sides are dropped)."""
    out = []
    for user_msg, assistant_msg in history or []:
        if user_msg:
            out.append({"role": "user", "content": user_msg})
        if assistant_msg:
            out.append({"role": "assistant", "content": assistant_msg})
    return out


def _style_kwargs(target):
    """css/theme keyword arguments accepted by target (gr.Blocks or a launch method)."""
    params = inspect.signature(target).parameters
    kwargs = {}
    if "css" in params:
        kwargs["css"] = CSS
    if "theme" in params and hasattr(gr, "themes"):
        kwargs["theme"] = gr.themes.Soft(primary_hue="teal", secondary_hue="slate")
    return kwargs


def _blocks_style_kwargs():
    """css/theme for gr.Blocks(...) only where launch() does not accept them.

    Gradio 6 takes css/theme in launch(); older versions only in the Blocks constructor.
    Without this they would be silently dropped on older Gradio.
    """
    launch_params = inspect.signature(gr.Blocks.launch).parameters
    return {k: v for k, v in _style_kwargs(gr.Blocks).items() if k not in launch_params}


def create_demo():
    """Build the two-pane Gradio Blocks app (map left, chat right) and wire the chat handlers."""
    chatbot = Chatbot()

    with gr.Blocks(title="SAMHSA Treatment Locator", **_blocks_style_kwargs()) as demo:
        gr.Markdown("# SAMHSA Treatment Locator")
        gr.Markdown(DESCRIPTION)
        gr.HTML(DISCLAIMER_HTML)

        state = gr.State(DEFAULT_STATE)

        with gr.Row():
            # Left: map — scroll over map to zoom, drag to pan
            with gr.Column(scale=5, min_width=320, elem_classes=["map-pane"]):
                gr.Markdown("**Map** — scroll over map to zoom, drag to pan. Search in chat to see facilities.")
                map_html = gr.HTML(
                    value=_build_map_html([], None),
                    elem_classes=["map-html"],
                )

            # Right: chat
            with gr.Column(scale=5, min_width=320):
                gr.Markdown("**Chat** — tell me location, treatment type, and payment.")
                _chat_kw = {
                    "label": "Conversation",
                    "placeholder": "E.g. I'm in Boston, need outpatient treatment with Medicaid.",
                    "height": 420,
                    "show_label": False,
                }
                # Older Gradio needs type="messages"; Gradio 6 has no `type` (messages only).
                if "type" in inspect.signature(gr.Chatbot).parameters:
                    _chat_kw["type"] = "messages"
                chat = gr.Chatbot(**_chat_kw)
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Type a message…",
                        show_label=False,
                        container=False,
                        scale=8,
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)
                gr.Markdown("**Try:**", elem_classes=["try-label"])
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=msg,
                    label=None,
                    examples_per_page=6,
                )

        def _map_for_state(st, update_id):
            """Render the map for the facilities/selection stored in a chat state dict."""
            facilities = list(st.get("last_results") or [])
            return _build_map_html(facilities, None, update_id, st.get("selected_facility_name"))

        def user_submit(message, history, state):
            """Handle one chat turn; returns (history, state, cleared textbox, map HTML)."""
            update_id = str(time.time())
            if not message or not message.strip():
                return history, state, "", _map_for_state(state, update_id)
            history_tuples = _messages_to_tuples(history)
            try:
                reply, new_state = chatbot.get_response(message, history_tuples, state)
                new_state = dict(new_state)
            except Exception as e:
                err_msg = str(e)[:200]
                reply = f"Sorry, something went wrong: {err_msg}"
                if "token" in err_msg.lower() or "auth" in err_msg.lower():
                    reply += " Check that HF_TOKEN is set in .env for the chat model."
                # Keep the previous state (and its map) when the chat model fails.
                new_state = state
            new_history_messages = _tuples_to_messages(history_tuples + [[message, reply]])
            return new_history_messages, new_state, "", _map_for_state(new_state, update_id)

        submit_btn.click(
            user_submit,
            inputs=[msg, chat, state],
            outputs=[chat, state, msg, map_html],
        )
        msg.submit(
            user_submit,
            inputs=[msg, chat, state],
            outputs=[chat, state, msg, map_html],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(**_style_kwargs(demo.launch))