"""
SAMHSA Treatment Locator – Gradio app for HuggingFace Spaces.
Two-pane layout: map (left) + chat (right). When GOOGLE_MAPS_API_KEY is set in .env,
the map uses Google Maps (JavaScript API); otherwise Folium/OpenStreetMap.
Facility markers and search results are shown on the map.
"""
import base64
import html as html_module
import json
import logging
import os
import time

import folium
import gradio as gr
import requests
# Anchor the .env lookup to this file's directory so GOOGLE_MAPS_API_KEY is
# picked up no matter which working directory the app is launched from.
_APP_DIR = os.path.split(os.path.abspath(__file__))[0]
_ENV_PATH = os.path.join(_APP_DIR, ".env")
try:
    from dotenv import load_dotenv
    load_dotenv(_ENV_PATH)
except ImportError:
    # python-dotenv is optional; skip loading when it is not installed.
    pass
from src.chat import DEFAULT_STATE, Chatbot
def _read_env_file_value(path, name):
    """Return the first non-empty value assigned to *name* in a dotenv-style file.

    Accepts ``NAME=value``, ``NAME = value`` and ``export NAME=value`` lines and
    strips one layer of surrounding double/single quotes. Blank lines and
    ``#`` comment lines are skipped. Returns None when the variable is absent,
    empty, or the file cannot be read or decoded.
    """
    try:
        with open(path, encoding="utf-8") as env_file:
            for raw_line in env_file:
                entry = raw_line.strip()
                if not entry or entry.startswith("#"):
                    continue
                if entry.startswith("export "):
                    entry = entry[len("export "):].lstrip()
                variable, separator, value = entry.partition("=")
                if not separator or variable.strip() != name:
                    continue
                value = value.strip().strip('"').strip("'")
                if value:
                    return value
    except (OSError, UnicodeDecodeError):
        # Best-effort fallback: an unreadable .env must not stop the app from
        # starting; the map simply falls back to Folium/OSM.
        return None
    return None


# Use Google Maps when key is set; otherwise Folium/OSM. Enable "Maps JavaScript API" in Google Cloud.
GOOGLE_MAPS_API_KEY = (os.environ.get("GOOGLE_MAPS_API_KEY") or "").strip() or None
if not GOOGLE_MAPS_API_KEY and os.path.isfile(_ENV_PATH):
    # Manual fallback for when the environment (or python-dotenv) did not
    # provide the key: read it straight from the .env file.
    GOOGLE_MAPS_API_KEY = _read_env_file_value(_ENV_PATH, "GOOGLE_MAPS_API_KEY")
# City/place -> (lat, lon) for map pins (fast lookup). Add more as needed.
# Keys are lower-case place names (lookups lower-case their input before
# consulting this table); values are (latitude, longitude) in decimal degrees.
CITY_COORDS = {
    "boston": (42.36, -71.06),
    "belmont": (42.40, -71.18),
    "roxbury": (42.33, -71.08),
    "allston": (42.35, -71.13),
    "amesbury": (42.86, -70.93),
    "athol": (42.59, -72.23),
    "abilene": (32.45, -99.73),
    "addison": (32.96, -96.83),
    "austin": (30.27, -97.74),
    "san antonio": (29.42, -98.49),
    "san francisco": (37.77, -122.42),
    "los angeles": (34.05, -118.25),
    "lake view terrace": (34.27, -118.37),
    "chicago": (41.88, -87.63),
}
def _normalize_facility(f):
    """Normalize a facility dict from state or search into a canonical shape.

    Keys are matched case-insensitively. Values are converted to stripped
    strings; ``None``, empty strings and NaN markers (any casing, with or
    without surrounding whitespace) are treated as missing. When several keys
    differ only by case, the first one carrying a usable value wins.

    Args:
        f: Facility mapping, or anything else (which yields None).

    Returns:
        A dict that always has ``city``, ``state`` (possibly ``""``) and
        ``facility_name`` (defaulting to ``"Facility"``), plus ``address``,
        ``phone``, ``treatment_type`` and ``services`` when present. Returns
        None when *f* is not a non-empty dict or has neither city nor state.
    """
    if not f or not isinstance(f, dict):
        return None

    # lower-cased key -> first usable string value found under that key
    usable = {}
    for raw_key, raw_value in f.items():
        if not isinstance(raw_key, str) or raw_value is None:
            continue
        text = str(raw_value).strip()
        # Strip before the NaN check so values like " NaN " are rejected too.
        if not text or text.lower() == "nan":
            continue
        usable.setdefault(raw_key.lower(), text)

    city = usable.get("city", "")
    state = usable.get("state", "")
    if not city and not state:
        return None

    out = {
        "city": city,
        "state": state,
        # An explicit facility_name takes precedence over the generic "name"
        # alias, matching how the popup builder resolves the display name.
        "facility_name": usable.get("facility_name") or usable.get("name") or "Facility",
    }
    for optional_key in ("address", "phone", "treatment_type", "services"):
        if optional_key in usable:
            out[optional_key] = usable[optional_key]
    return out
# Geocode cache so any city/state from search results can be shown on the map.
# Maps the stripped, lower-cased location string to its (lat, lon) tuple; only
# successful geocoder results are stored (see _geocode).
_GEOCODE_CACHE = {}
# Show all proposed facilities (chat returns up to 5; allow more for geocoding).
# Upper bound on new (not yet cached) geocoding lookups started while
# collecting coordinates for one list of facilities.
_MAX_GEOCODE_PER_MAP = 20
# NOTE(review): presumably the rendered map height in pixels; its use is outside this section.
MAP_HEIGHT_PX = 420
# Lazily-built, process-wide rate-limited geocode callable. One shared instance
# is required for the minimum delay to apply *between* calls; a limiter created
# per call has no memory of the previous request.
_RATE_LIMITED_GEOCODE = None


def _rate_limited_geocode():
    """Return the shared rate-limited Nominatim geocode callable, building it on first use."""
    global _RATE_LIMITED_GEOCODE
    if _RATE_LIMITED_GEOCODE is None:
        # Imported lazily so the app still starts when geopy is not installed.
        from geopy.extra.rate_limiter import RateLimiter
        from geopy.geocoders import Nominatim

        nominatim = Nominatim(user_agent="samhsa-treatment-locator")
        _RATE_LIMITED_GEOCODE = RateLimiter(nominatim.geocode, min_delay_seconds=1.0)
    return _RATE_LIMITED_GEOCODE


def _geocode(location_str):
    """Resolve an address/city string to ``(lat, lon)`` using Nominatim.

    Successful lookups are cached in ``_GEOCODE_CACHE`` under the stripped,
    lower-cased input. Lookups are restricted to the US and spaced at least
    one second apart across calls.

    Args:
        location_str: Free-form location text; None/blank yields None.

    Returns:
        ``(latitude, longitude)`` on success. If the geocoder has no match or
        fails for any reason (geopy missing, network error, ...), falls back
        to ``CITY_COORDS`` keyed by the normalized input, which may be None.
    """
    if not location_str or not location_str.strip():
        return None
    query = location_str.strip()
    key = query.lower()
    if key in _GEOCODE_CACHE:
        return _GEOCODE_CACHE[key]
    try:
        result = _rate_limited_geocode()(query, country_codes="us")
        if result:
            coord = (result.latitude, result.longitude)
            _GEOCODE_CACHE[key] = coord
            return coord
    except Exception as exc:
        # Geocoding is best-effort: record the failure and fall through to
        # the static table instead of breaking the map.
        logging.getLogger(__name__).warning("Geocoding failed for %r: %s", query, exc)
    return CITY_COORDS.get(key)
def _read_polyline_delta(encoded, index):
    """Decode one signed varint delta starting at *index*; return ``(delta, next_index)``."""
    result = 0
    shift = 0
    while True:
        chunk = ord(encoded[index]) - 63
        index += 1
        result |= (chunk & 0x1F) << shift
        shift += 5
        if chunk < 0x20:  # continuation bit clear: last 5-bit group of this value
            break
    # The lowest bit carries the sign (zig-zag style encoding).
    delta = ~(result >> 1) if result & 1 else result >> 1
    return delta, index


def _decode_google_polyline(encoded: str):
    """Decode Google's encoded polyline string to a list of ``(lat, lon)`` floats.

    Algorithm: https://developers.google.com/maps/documentation/utilities/polylinealgorithm

    Args:
        encoded: Encoded polyline text; None or ``""`` yields ``[]``.

    Returns:
        List of ``(latitude, longitude)`` tuples in decimal degrees.

    Raises:
        IndexError: If *encoded* ends in the middle of a coordinate.
    """
    if not encoded:
        return []
    coords = []
    index = 0
    lat = 0
    lon = 0
    length = len(encoded)
    while index < length:
        dlat, index = _read_polyline_delta(encoded, index)
        dlon, index = _read_polyline_delta(encoded, index)
        # Each point is stored as a delta from the previous one, in 1e-5 degree units.
        lat += dlat
        lon += dlon
        # Divide (rather than multiply by 1e-5) so the result is the correctly
        # rounded float, e.g. exactly 38.5 instead of a value one ulp off.
        coords.append((lat / 1e5, lon / 1e5))
    return coords
def _get_route(lat1, lon1, lat2, lon2):
    """Return the driving route between two points as a list of (lat, lon).

    When GOOGLE_MAPS_API_KEY is set, the Google Directions API is tried first
    and its overview polyline is decoded. Any non-200 response, non-"OK"
    status, missing route/polyline, or exception falls back to OSRM, whose
    result (possibly None) is returned as-is.
    """
    if GOOGLE_MAPS_API_KEY:
        try:
            directions_url = (
                "https://maps.googleapis.com/maps/api/directions/json"
                f"?origin={lat1},{lon1}&destination={lat2},{lon2}&key={GOOGLE_MAPS_API_KEY}"
            )
            response = requests.get(directions_url, timeout=10)
            if response.status_code == 200:
                payload = response.json()
                if payload.get("status") == "OK" and payload.get("routes"):
                    first_route = payload["routes"][0]
                    encoded = first_route.get("overview_polyline", {}).get("points")
                    if encoded:
                        return _decode_google_polyline(encoded)
        except Exception:
            # Deliberately best-effort: any Google failure drops to OSRM below.
            pass
    return _get_route_osrm(lat1, lon1, lat2, lon2)
def _get_route_osrm(lat1, lon1, lat2, lon2):
    """Fetch a driving route from the public OSRM server.

    Returns a list of (lat, lon) tuples, or None when the request fails,
    the response is not HTTP 200, no route is returned, or anything raises.
    """
    try:
        # OSRM expects "lon,lat" ordering in the path.
        endpoint = (
            "https://router.project-osrm.org/route/v1/driving/"
            f"{lon1},{lat1};{lon2},{lat2}?overview=full&geometries=geojson"
        )
        response = requests.get(endpoint, timeout=5)
        if response.status_code == 200:
            payload = response.json()
            if payload.get("routes"):
                geometry = payload["routes"][0]["geometry"]
                # GeoJSON points are [lon, lat]; flip them to (lat, lon).
                return [(point[1], point[0]) for point in geometry["coordinates"]]
        return None
    except Exception:
        return None
def _facility_coord(f, geocode_count=None):
    """Return (lat, lon) for a facility dict, or None if it cannot be placed.

    Lookup order: CITY_COORDS by city, then CITY_COORDS by state, then
    geocoding of "city, state". f uses CSV keys ("city", "state").

    geocode_count, when given as a one-element list, is a mutable counter of
    new (uncached) geocode lookups; once it reaches _MAX_GEOCODE_PER_MAP no
    further uncached lookups are made and None is returned instead.
    """
    city = (f.get("city") or "").strip()
    state = (f.get("state") or "").strip()
    if not (city or state):
        return None

    # Fast path: static table, city first and state as a second chance.
    known = CITY_COORDS.get(city.lower())
    if not known and state:
        known = CITY_COORDS.get(state.lower())
    if known:
        return known

    query = f"{city}, {state}".strip(", ")
    if not query:
        return None

    # Only lookups that would actually hit the geocoder count against the cap.
    has_counter = geocode_count is not None and len(geocode_count) == 1
    if has_counter and query.lower() not in _GEOCODE_CACHE:
        if geocode_count[0] >= _MAX_GEOCODE_PER_MAP:
            return None
        geocode_count[0] += 1
    return _geocode(query)
def _popup_html(f):
    """Build the short HTML snippet shown in a map popup for one facility.

    f uses CSV keys (facility_name or name, address, city, state, phone,
    treatment_type). Every field value is HTML-escaped, so facility data
    cannot inject markup; lines are separated with ``<br>``. The location
    line is omitted when address, city and state are all missing.

    Returns:
        HTML string with the name first (``"Facility"`` when unnamed),
        followed by the location, phone and treatment-type lines that exist.
    """
    def field(*keys):
        # First truthy value among *keys*, stringified and escaped; "" if none.
        for key in keys:
            value = f.get(key)
            if value:
                return html_module.escape(str(value))
        return ""

    # NOTE(review): the name may originally have been wrapped in emphasis
    # markup that was lost; confirm against the upstream file.
    parts = [field("facility_name", "name") or "Facility"]

    location = f"{field('address')}, {field('city')} {field('state')}".strip(", ")
    if location:
        parts.append(location)

    phone = field("phone")
    if phone:
        parts.append(f"📞 {phone}")

    treatment = field("treatment_type")
    if treatment:
        parts.append(f"Type: {treatment}")

    return "<br>".join(parts)
def _get_facility_coords(facilities):
    """Collect (lat, lon, facility_dict) for every facility that can be placed.

    Each entry is normalized first; entries that fail normalization or have no
    resolvable coordinates are dropped. The returned dict is the normalized one.
    """
    located = []
    # One-element list so _facility_coord can bump the shared lookup counter in place.
    lookups_used = [0]
    for raw in facilities:
        facility = _normalize_facility(raw)
        position = _facility_coord(facility, lookups_used) if facility else None
        if position:
            lat, lon = position[0], position[1]
            located.append((lat, lon, facility))
    return located
def _build_google_map_html(facilities, force_update_id=None, selected_facility_name=None):
"""Build Google Maps in an iframe via srcdoc so scripts run (interactive map). Requires GOOGLE_MAPS_API_KEY."""
if not GOOGLE_MAPS_API_KEY:
return _build_folium_map_html(facilities, None, force_update_id, selected_facility_name)
try:
facility_coords = _get_facility_coords(facilities)
center_lat, center_lon, zoom = 39.5, -98.5, 3
if facility_coords:
lats = [c[0] for c in facility_coords]
lons = [c[1] for c in facility_coords]
center_lat = sum(lats) / len(lats)
center_lon = sum(lons) / len(lons)
zoom = 10
markers_data = []
for i, (lat, lon, f) in enumerate(facility_coords):
name = (f.get("facility_name") or f.get("name") or "Facility").replace("<", "<").replace(">", ">")
info = _popup_html(f)
sel = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
markers_data.append({"lat": lat, "lng": lon, "name": name, "info": info, "selected": sel, "label": str(i + 1)})
# Base64-encode markers so srcdoc HTML escaping cannot break the JSON
markers_json = json.dumps(markers_data)
markers_b64 = base64.b64encode(markers_json.encode("utf-8")).decode("ascii")
script_url = f"https://maps.googleapis.com/maps/api/js?key={GOOGLE_MAPS_API_KEY}&callback=init"
doc = f"""
Disclaimer
' 'Information is from SAMHSA data. Always verify with the facility or ' 'findtreatment.gov ' 'before making decisions. This tool does not endorse any facility.
' '