Spaces:
Sleeping
Sleeping
| """ | |
| SAMHSA Treatment Locator – Gradio app for HuggingFace Spaces. | |
| Two-pane layout: map (left) + chat (right). When GOOGLE_MAPS_API_KEY is set in .env, | |
| the map uses Google Maps (JavaScript API); otherwise Folium/OpenStreetMap. | |
| Facility markers and search results are shown on the map. | |
| """ | |
| import base64 | |
| import html as html_module | |
| import json | |
| import os | |
| import time | |
| import folium | |
| import gradio as gr | |
| import requests | |
| # Load .env from project root so GOOGLE_MAPS_API_KEY is set regardless of launch cwd | |
| _APP_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| _ENV_PATH = os.path.join(_APP_DIR, ".env") | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv(_ENV_PATH) | |
| except ImportError: | |
| pass | |
| from src.chat import DEFAULT_STATE, Chatbot | |
# Use Google Maps when key is set; otherwise Folium/OSM. Enable "Maps JavaScript API" in Google Cloud.
def _read_env_file_value(path, name):
    """Return the non-empty value assigned to ``name`` in a simple KEY=VALUE env file, or None.

    Fallback for when python-dotenv is unavailable. Skips blank lines and ``#`` comments,
    tolerates an optional leading ``export `` and whitespace around ``=``, and strips
    surrounding double/single quotes from the value. The first non-empty match wins.
    Unreadable or undecodable files yield None instead of crashing app start-up.
    """
    try:
        with open(path, encoding="utf-8") as env_file:
            for raw_line in env_file:
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                if line.startswith("export "):
                    line = line[len("export "):].lstrip()
                var, sep, value = line.partition("=")
                if not sep or var.strip() != name:
                    continue
                value = value.strip().strip('"').strip("'")
                if value:
                    return value
    except (OSError, UnicodeDecodeError):
        return None
    return None


GOOGLE_MAPS_API_KEY = (os.environ.get("GOOGLE_MAPS_API_KEY") or "").strip() or None
if not GOOGLE_MAPS_API_KEY and os.path.isfile(_ENV_PATH):
    # The environment may not have been populated from .env (e.g. python-dotenv missing).
    GOOGLE_MAPS_API_KEY = _read_env_file_value(_ENV_PATH, "GOOGLE_MAPS_API_KEY")
# Lower-cased city/place name -> (lat, lon) for map pins, used as a fast lookup before any
# network geocoding. Keys are city names only (no state). Extend the rows as needed.
CITY_COORDS = {
    city: (lat, lon)
    for city, lat, lon in (
        # Massachusetts
        ("boston", 42.36, -71.06),
        ("belmont", 42.40, -71.18),
        ("roxbury", 42.33, -71.08),
        ("allston", 42.35, -71.13),
        ("amesbury", 42.86, -70.93),
        ("athol", 42.59, -72.23),
        # Texas
        ("abilene", 32.45, -99.73),
        ("addison", 32.96, -96.83),
        ("austin", 30.27, -97.74),
        ("san antonio", 29.42, -98.49),
        # California
        ("san francisco", 37.77, -122.42),
        ("los angeles", 34.05, -118.25),
        ("lake view terrace", 34.27, -118.37),
        # Illinois
        ("chicago", 41.88, -87.63),
    )
}
| def _normalize_facility(f): | |
| """Ensure facility dict from state or search has string values and expected keys. Accepts any key casing.""" | |
| if not f or not isinstance(f, dict): | |
| return None | |
| # Case-insensitive key lookup (state/JSON may use different casing) | |
| key_map = {k.lower(): k for k in f.keys() if isinstance(k, str)} | |
| def get_any(*names): | |
| for n in names: | |
| k = n.lower() | |
| if k in key_map: | |
| val = f.get(key_map[k]) | |
| if val is not None and str(val).strip() and str(val).lower() != "nan": | |
| return str(val).strip() | |
| return None | |
| city = get_any("city", "City") | |
| state = get_any("state", "State") | |
| if not city and not state: | |
| return None | |
| out = {"city": city or "", "state": state or ""} | |
| for name, out_key in ( | |
| ("facility_name", "facility_name"), ("name", "facility_name"), | |
| ("address", "address"), ("phone", "phone"), | |
| ("treatment_type", "treatment_type"), ("services", "services"), | |
| ): | |
| val = get_any(name) | |
| if val: | |
| out[out_key] = val | |
| if "facility_name" not in out: | |
| out["facility_name"] = get_any("facility_name", "name") or "Facility" | |
| return out | |
| # Geocode cache so any city/state from search results can be shown on the map. | |
| _GEOCODE_CACHE = {} | |
| # Show all proposed facilities (chat returns up to 5; allow more for geocoding). | |
| _MAX_GEOCODE_PER_MAP = 20 | |
| MAP_HEIGHT_PX = 420 | |
| def _geocode(location_str): | |
| """Resolve address/city to (lat, lon) using Nominatim. Returns None on failure. Uses cache.""" | |
| if not location_str or not location_str.strip(): | |
| return None | |
| key = location_str.strip().lower() | |
| if key in _GEOCODE_CACHE: | |
| return _GEOCODE_CACHE[key] | |
| try: | |
| from geopy.geocoders import Nominatim | |
| from geopy.extra.rate_limiter import RateLimiter | |
| geocoder = Nominatim(user_agent="samhsa-treatment-locator") | |
| geocode = RateLimiter(geocoder.geocode, min_delay_seconds=1.0) | |
| result = geocode(location_str.strip(), country_codes="us") | |
| if result: | |
| coord = (result.latitude, result.longitude) | |
| _GEOCODE_CACHE[key] = coord | |
| return coord | |
| except Exception: | |
| pass | |
| return CITY_COORDS.get(key) | |
| def _decode_google_polyline(encoded: str): | |
| """Decode Google's encoded polyline string to list of (lat, lon).""" | |
| # https://developers.google.com/maps/documentation/utilities/polylinealgorithm | |
| coords = [] | |
| i = 0 | |
| lat = 0 | |
| lon = 0 | |
| while i < len(encoded): | |
| b = 0 | |
| shift = 0 | |
| result = 0 | |
| while True: | |
| b = ord(encoded[i]) - 63 | |
| i += 1 | |
| result |= (b & 0x1F) << shift | |
| shift += 5 | |
| if b < 0x20: | |
| break | |
| dlat = ~(result >> 1) if result & 1 else result >> 1 | |
| lat += dlat | |
| shift = 0 | |
| result = 0 | |
| while True: | |
| b = ord(encoded[i]) - 63 | |
| i += 1 | |
| result |= (b & 0x1F) << shift | |
| shift += 5 | |
| if b < 0x20: | |
| break | |
| dlon = ~(result >> 1) if result & 1 else result >> 1 | |
| lon += dlon | |
| coords.append((lat * 1e-5, lon * 1e-5)) | |
| return coords | |
def _get_route(lat1, lon1, lat2, lon2):
    """Return a driving route between two points as a list of (lat, lon), or None.

    Without GOOGLE_MAPS_API_KEY the public OSRM router is used directly. With a key, the
    Google Directions API is tried first and its overview polyline decoded. Any Google
    failure (HTTP error, non-OK status, no routes, missing polyline, exception) falls back
    to OSRM.
    """
    if not GOOGLE_MAPS_API_KEY:
        return _get_route_osrm(lat1, lon1, lat2, lon2)
    try:
        directions_url = (
            "https://maps.googleapis.com/maps/api/directions/json"
            f"?origin={lat1},{lon1}&destination={lat2},{lon2}&key={GOOGLE_MAPS_API_KEY}"
        )
        response = requests.get(directions_url, timeout=10)
        if response.status_code == 200:
            body = response.json()
            if body.get("status") == "OK" and body.get("routes"):
                encoded = body["routes"][0].get("overview_polyline", {}).get("points")
                if encoded:
                    return _decode_google_polyline(encoded)
    except Exception:
        pass  # best-effort: fall through to OSRM below
    return _get_route_osrm(lat1, lon1, lat2, lon2)
def _get_route_osrm(lat1, lon1, lat2, lon2):
    """Fetch a driving route from the public OSRM server (router.project-osrm.org).

    Returns a list of (lat, lon) points, or None on a non-200 response, when no route is
    returned, or on any network/parse error. OSRM takes "lon,lat" in the URL and returns
    GeoJSON [lon, lat] coordinates, hence the swaps.
    """
    endpoint = (
        "https://router.project-osrm.org/route/v1/driving/"
        f"{lon1},{lat1};{lon2},{lat2}?overview=full&geometries=geojson"
    )
    try:
        response = requests.get(endpoint, timeout=5)
        if response.status_code != 200:
            return None
        routes = response.json().get("routes")
        if not routes:
            return None
        return [(pt[1], pt[0]) for pt in routes[0]["geometry"]["coordinates"]]
    except Exception:
        return None
| def _facility_coord(f, geocode_count=None): | |
| """(lat, lon) for a facility: CITY_COORDS first, then geocode city+state (cached). f uses CSV keys.""" | |
| city = (f.get("city") or "").strip() | |
| state = (f.get("state") or "").strip() | |
| if not city and not state: | |
| return None | |
| city_lower = city.lower() | |
| state_lower = state.lower() | |
| coord = CITY_COORDS.get(city_lower) or (CITY_COORDS.get(state_lower) if state_lower else None) | |
| if coord: | |
| return coord | |
| location_str = f"{city}, {state}".strip(", ") | |
| if not location_str: | |
| return None | |
| if geocode_count is not None and len(geocode_count) == 1 and location_str.lower() not in _GEOCODE_CACHE: | |
| if geocode_count[0] >= _MAX_GEOCODE_PER_MAP: | |
| return None | |
| geocode_count[0] += 1 | |
| return _geocode(location_str) | |
| def _popup_html(f): | |
| """Short HTML for Folium popup. f uses CSV keys (facility_name, address, etc.).""" | |
| name = f.get("facility_name") or f.get("name") or "Facility" | |
| addr = f.get("address") or "" | |
| city = f.get("city") or "" | |
| st = f.get("state") or "" | |
| phone = f.get("phone") or "" | |
| t = f.get("treatment_type") or "" | |
| parts = [f"<b>{name}</b>", f"{addr}, {city} {st}".strip(", ")] | |
| if phone: | |
| parts.append(f"📞 {phone}") | |
| if t: | |
| parts.append(f"Type: {t}") | |
| return "<br>".join(parts) | |
| def _get_facility_coords(facilities): | |
| """Return list of (lat, lon, facility_dict) for facilities that have coordinates.""" | |
| result = [] | |
| geocode_count = [0] | |
| for f in facilities: | |
| f = _normalize_facility(f) | |
| if not f: | |
| continue | |
| coord = _facility_coord(f, geocode_count) | |
| if coord: | |
| result.append((coord[0], coord[1], f)) | |
| return result | |
def _build_google_map_html(facilities, force_update_id=None, selected_facility_name=None):
    """Render facilities on an interactive Google Map embedded in an iframe.

    The map document is put in the iframe's ``srcdoc`` so its scripts run in their own
    context. Falls back to the Folium/OSM map when GOOGLE_MAPS_API_KEY is unset.

    Args:
        facilities: facility dicts (any key casing); entries that cannot be located are skipped.
        force_update_id: optional unique value appended as an HTML comment so the returned
            string changes on every call (forces Gradio to re-render).
        selected_facility_name: facility_name of the marker to highlight (bounce + star label).

    Returns:
        HTML string; on any error, a placeholder div with a short, escaped error message.
    """
    if not GOOGLE_MAPS_API_KEY:
        return _build_folium_map_html(facilities, None, force_update_id, selected_facility_name)
    try:
        facility_coords = _get_facility_coords(facilities)
        # Default view: continental US. The JS fitBounds below reframes once markers exist.
        center_lat, center_lon, zoom = 39.5, -98.5, 3
        if facility_coords:
            lats = [c[0] for c in facility_coords]
            lons = [c[1] for c in facility_coords]
            center_lat = sum(lats) / len(lats)
            center_lon = sum(lons) / len(lons)
            zoom = 10
        markers_data = []
        for i, (lat, lon, f) in enumerate(facility_coords):
            # Plain text: the Maps API uses a marker's `title` as rollover text, not HTML.
            name = f.get("facility_name") or f.get("name") or "Facility"
            selected = bool(selected_facility_name) and (
                (f.get("facility_name") or f.get("name") or "") == selected_facility_name
            )
            markers_data.append({
                "lat": lat,
                "lng": lon,
                "name": name,
                "info": _popup_html(f),  # HTML for the InfoWindow
                "selected": selected,
                "label": str(i + 1),
            })
        # Base64 keeps srcdoc HTML escaping (and any "</script>" in the data) from breaking the
        # JSON. json.dumps uses ensure_ascii=True, so the payload is pure ASCII and atob()
        # round-trips it losslessly.
        markers_b64 = base64.b64encode(json.dumps(markers_data).encode("utf-8")).decode("ascii")
        script_url = f"https://maps.googleapis.com/maps/api/js?key={GOOGLE_MAPS_API_KEY}&callback=init"
        doc = f"""<!DOCTYPE html><html><head><meta charset="utf-8"></head><body style="margin:0">
<div id="map" style="width:100%;height:100%;min-height:{MAP_HEIGHT_PX}px;"></div>
<script>
var center = {{ lat: {center_lat}, lng: {center_lon} }};
var zoom = {zoom};
var markersData = JSON.parse(atob("{markers_b64}"));
function init() {{
  var map = new google.maps.Map(document.getElementById("map"), {{ center: center, zoom: zoom, mapTypeControl: true, fullscreenControl: true, zoomControl: true, scaleControl: true }});
  var infowindow = new google.maps.InfoWindow();
  var bounds = null;
  if (markersData && markersData.length) {{
    markersData.forEach(function(m) {{
      var pos = {{ lat: m.lat, lng: m.lng }};
      var opts = {{ position: pos, map: map, title: m.name, label: {{ text: m.label || "", color: "white", fontWeight: "bold" }} }};
      if (m.selected) {{ opts.animation = google.maps.Animation.BOUNCE; opts.label = {{ text: "★", color: "white", fontWeight: "bold" }}; }}
      var marker = new google.maps.Marker(opts);
      marker.addListener("click", function() {{ infowindow.setContent(m.info); infowindow.open(map, marker); }});
      if (!bounds) bounds = new google.maps.LatLngBounds(pos, pos);
      else bounds.extend(pos);
    }});
    if (bounds && markersData.length > 0) {{
      map.fitBounds(bounds, {{ top: 40, right: 40, bottom: 40, left: 40 }});
      if (markersData.length === 1) map.setZoom(12);
    }}
  }}
}}
</script>
<script src="{script_url}" async defer></script>
</body></html>"""
        # Embed via srcdoc so the document runs in its own browsing context and scripts execute.
        escaped = html_module.escape(doc, quote=True)
        html = f'<iframe srcdoc="{escaped}" style="width:100%;height:{MAP_HEIGHT_PX}px;border:0;border-radius:12px;" title="Google Map"></iframe>'
        if force_update_id is not None:
            html += f"<!-- map-update:{force_update_id} -->"
        return html
    except Exception as e:
        # Exception text is arbitrary; escape it before putting it into HTML.
        return (
            f'<div style="width:100%;height:{MAP_HEIGHT_PX}px;display:flex;align-items:center;justify-content:center;'
            f'background:#f5f5f5;border-radius:12px;color:#666;">'
            f"Map could not be loaded. ({html_module.escape(str(e)[:80])})</div>"
        )
def _build_map_html(facilities, user_location_str=None, force_update_id=None, selected_facility_name=None):
    """Return the map HTML for the UI, choosing the backend by configuration.

    Without GOOGLE_MAPS_API_KEY the Folium/OpenStreetMap map is built. With a key the
    Google Maps iframe is built; that view does not use ``user_location_str``.
    """
    if not GOOGLE_MAPS_API_KEY:
        return _build_folium_map_html(facilities, user_location_str, force_update_id, selected_facility_name)
    return _build_google_map_html(facilities, force_update_id, selected_facility_name)
def _build_folium_map_html(facilities, user_location_str=None, force_update_id=None, selected_facility_name=None):
    """
    Build a Folium (Leaflet/OpenStreetMap) map as HTML. Scroll over the map to zoom, drag to pan.

    Args:
        facilities: facility dicts (any key casing); entries that cannot be located are skipped.
        user_location_str: optional place string. When it geocodes, a blue "You" marker is
            added and, if any facility was located, a driving route to the first one is drawn.
            The view is framed to include the user as well as the facilities.
        force_update_id: optional unique value (e.g. timestamp) appended as an HTML comment so
            Gradio re-renders the HTML each time.
        selected_facility_name: if set, the matching facility's circle marker is drawn in red.

    Returns:
        Wrapped map HTML, or a safe fallback placeholder (escaped error text) on any error.
    """
    try:
        # Default view: continental US; refined below.
        center_lat, center_lon, zoom = 39.5, -98.5, 3
        user_lat_lon = _geocode(user_location_str) if user_location_str else None
        # [((lat, lon), facility_dict), ...] for facilities that could be located.
        facility_coords = [((lat, lon), fac) for lat, lon, fac in _get_facility_coords(facilities)]
        if user_lat_lon:
            center_lat, center_lon = user_lat_lon
            zoom = 11
        elif facility_coords:
            center_lat = sum(pt[0] for pt, _ in facility_coords) / len(facility_coords)
            center_lon = sum(pt[1] for pt, _ in facility_coords) / len(facility_coords)
            zoom = 10
        m = folium.Map(
            location=[center_lat, center_lon],
            zoom_start=zoom,
            tiles="OpenStreetMap",
            control_scale=True,
            zoom_control=True,
        )
        m.options["scrollWheelZoom"] = True
        m.options["touchZoom"] = True
        m.options["dragging"] = True
        if user_lat_lon:
            folium.Marker(
                user_lat_lon,
                popup="You",
                tooltip="Your location",
                icon=folium.Icon(color="blue", icon="info-sign"),
            ).add_to(m)
            if facility_coords:
                # Route from the user to the first listed facility.
                dest_lat, dest_lon = facility_coords[0][0]
                route = _get_route(user_lat_lon[0], user_lat_lon[1], dest_lat, dest_lon)
                if route:
                    folium.PolyLine(route, color="teal", weight=4, opacity=0.8).add_to(m)
        for i, ((lat, lon), f) in enumerate(facility_coords):
            is_selected = selected_facility_name and (f.get("facility_name") or f.get("name") or "") == selected_facility_name
            name = f.get("facility_name") or f.get("name") or "Facility"
            folium.CircleMarker(
                location=[lat, lon],
                radius=12,
                popup=folium.Popup(_popup_html(f), max_width=280),
                # Leaflet renders tooltip content as HTML, so escape the data-derived name.
                tooltip=f"{i + 1}. {html_module.escape(name)}",
                color="#c62828" if is_selected else "#319795",
                fill=True,
                fill_color="#e53935" if is_selected else "#26a69a",
                fill_opacity=0.9,
                weight=2,
            ).add_to(m)
        if facility_coords:
            # Frame every facility -- and the user's marker when shown -- so nothing is off-screen.
            points = [pt for pt, _ in facility_coords]
            if user_lat_lon:
                points.append((user_lat_lon[0], user_lat_lon[1]))
            lats = [p[0] for p in points]
            lons = [p[1] for p in points]
            sw = [min(lats), min(lons)]
            ne = [max(lats), max(lons)]
            # Avoid zero-size bounds (single point): pad by a small buffer in degrees.
            buf = 0.01
            if ne[0] - sw[0] < buf:
                sw[0] -= buf
                ne[0] += buf
            if ne[1] - sw[1] < buf:
                sw[1] -= buf
                ne[1] += buf
            m.fit_bounds([sw, ne])
        html = m._repr_html_()
        wrapper = f'<div style="width:100%;height:{MAP_HEIGHT_PX}px;overflow:hidden;border-radius:12px;">{html}</div>'
        # Force Gradio to re-render: append a unique comment so the value always changes.
        if force_update_id is not None:
            wrapper += f"<!-- map-update:{force_update_id} -->"
        return wrapper
    except Exception as e:
        return (
            f'<div style="width:100%;height:{MAP_HEIGHT_PX}px;display:flex;align-items:center;justify-content:center;'
            f'background:#f5f5f5;border-radius:12px;color:#666;font-family:sans-serif;">'
            f"Map could not be loaded. ({html_module.escape(str(e)[:80])})</div>"
        )
# Static UI copy shown above the two-pane layout (styled by the .disclaimer* CSS rules).
DISCLAIMER_HTML = "".join(
    (
        '<div class="disclaimer">',
        '<p class="disclaimer-title">Disclaimer</p>',
        '<p class="disclaimer-text">Information is from SAMHSA data. Always verify with the facility or ',
        '<a href="https://findtreatment.gov" target="_blank" rel="noopener">findtreatment.gov</a> ',
        "before making decisions. This tool does not endorse any facility.</p>",
        "</div>",
    )
)
DESCRIPTION = (
    "Find treatment facilities by chatting: "
    "say where you are (city or state), "
    "what type of care you need, and payment (e.g. Medicaid, insurance). "
    "Results show on the map. Data from SAMHSA only."
)
# Clickable sample prompts that gr.Examples loads into the chat textbox.
EXAMPLES = [
    "I'm looking for outpatient alcohol treatment in Boston with Medicaid.",
    "Do you have options for veterans in Texas?",
    "Hi, I need help finding a residential program in California that accepts Medicaid.",
]
# App stylesheet. Classes match the DISCLAIMER_HTML markup (disclaimer, disclaimer-title,
# disclaimer-text) and the elem_classes used in create_demo (map-pane, map-html, try-label).
# The __main__ block passes it to demo.launch(css=...) when launch() accepts a css argument.
CSS = """
.disclaimer {
font-size: 0.9rem;
color: #2d3748;
padding: 0.875rem 1rem;
background: #e2e8f0;
border-radius: 10px;
margin-bottom: 0.75rem;
border-left: 4px solid #319795;
box-shadow: 0 1px 3px rgba(0,0,0,0.08);
}
.disclaimer-title {
font-weight: 600;
color: #1a202c;
margin: 0 0 0.35em 0;
font-size: 0.95em;
}
.disclaimer-text {
margin: 0;
line-height: 1.5;
color: #2d3748;
}
.disclaimer a {
color: #2c7a7b;
text-decoration: none;
font-weight: 600;
}
.disclaimer a:hover {
text-decoration: underline;
}
.map-pane { padding: 0.25rem 0 0 0; }
.map-pane .map-html { border-radius: 12px; overflow: hidden; box-shadow: 0 2px 12px rgba(0,0,0,0.08); }
.map-pane iframe { border-radius: 12px; }
.try-label { font-size: 0.9em; margin-bottom: 0.25rem; }
"""
| def _messages_to_tuples(history): | |
| """Convert Gradio 6 messages format to [(user, assistant), ...] for get_response.""" | |
| if not history: | |
| return [] | |
| out = [] | |
| for item in history: | |
| if isinstance(item, (list, tuple)) and len(item) >= 2: | |
| out.append([item[0], item[1]]) | |
| elif isinstance(item, dict): | |
| role, content = item.get("role"), item.get("content", "") | |
| if role == "user": | |
| out.append([content, ""]) | |
| elif role == "assistant": | |
| if out: | |
| out[-1][1] = content | |
| else: | |
| out.append(["", content]) | |
| else: | |
| out.append(["", str(item)]) | |
| return out | |
| def _tuples_to_messages(history): | |
| """Convert [(user, assistant), ...] to Gradio 6 messages format.""" | |
| out = [] | |
| for user_msg, assistant_msg in history or []: | |
| if user_msg: | |
| out.append({"role": "user", "content": user_msg}) | |
| if assistant_msg: | |
| out.append({"role": "assistant", "content": assistant_msg}) | |
| return out | |
def create_demo():
    """Build the Gradio Blocks UI: map pane (left) and chat pane (right).

    Each submitted message is answered by ``Chatbot.get_response``. The map is then rebuilt
    from the session state's "last_results", highlighting "selected_facility_name".

    Styling: the __main__ launcher only forwards css/theme to ``launch()`` when ``launch()``
    accepts them. When it does not but ``gr.Blocks`` does, they are passed to ``gr.Blocks``
    here; otherwise the stylesheet and theme would silently never be applied.
    """
    import inspect

    chatbot = Chatbot()

    blocks_kw = {"title": "SAMHSA Treatment Locator"}
    blocks_params = inspect.signature(gr.Blocks).parameters
    launch_params = inspect.signature(gr.Blocks.launch).parameters
    if "css" in blocks_params and "css" not in launch_params:
        blocks_kw["css"] = CSS
    if "theme" in blocks_params and "theme" not in launch_params and hasattr(gr, "themes"):
        blocks_kw["theme"] = gr.themes.Soft(primary_hue="teal", secondary_hue="slate")

    def map_for(st, update_id):
        """Rebuild the map HTML from a session-state dict's last results and selection."""
        facilities = list(st.get("last_results") or [])
        return _build_map_html(facilities, None, update_id, st.get("selected_facility_name"))

    with gr.Blocks(**blocks_kw) as demo:
        gr.Markdown("# SAMHSA Treatment Locator")
        gr.Markdown(DESCRIPTION)
        gr.HTML(DISCLAIMER_HTML)
        state = gr.State(DEFAULT_STATE)
        with gr.Row():
            # Left: map (Google Maps when a key is configured, otherwise Folium/OSM).
            with gr.Column(scale=5, min_width=320, elem_classes=["map-pane"]):
                gr.Markdown("**Map** — scroll over map to zoom, drag to pan. Search in chat to see facilities.")
                map_html = gr.HTML(
                    value=_build_map_html([], None),
                    elem_classes=["map-html"],
                )
            # Right: chat.
            with gr.Column(scale=5, min_width=320):
                gr.Markdown("**Chat** — tell me location, treatment type, and payment.")
                chat_kw = {
                    "label": "Conversation",
                    "placeholder": "E.g. I'm in Boston, need outpatient treatment with Medicaid.",
                    "height": 420,
                    "show_label": False,
                }
                # The handlers return "messages"-format history; request that format when this
                # Gradio version exposes a `type` option on Chatbot.
                if "type" in inspect.signature(gr.Chatbot).parameters:
                    chat_kw["type"] = "messages"
                chat = gr.Chatbot(**chat_kw)
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Type a message…",
                        show_label=False,
                        container=False,
                        scale=8,
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)
                gr.Markdown("**Try:**", elem_classes=["try-label"])
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=msg,
                    label=None,
                    examples_per_page=6,
                )

        def user_submit(message, history, state):
            """Answer one message; return (history, state, cleared textbox, map HTML)."""
            # A unique id per call so the map HTML always changes and Gradio re-renders it.
            update_id = str(time.time())
            if not message or not message.strip():
                return history, state, "", map_for(state, update_id)
            history_tuples = _messages_to_tuples(history)
            try:
                reply, new_state = chatbot.get_response(message, history_tuples, state)
                new_state = dict(new_state)
            except Exception as e:
                err_msg = str(e)[:200]
                reply = f"Sorry, something went wrong: {err_msg}"
                if "token" in err_msg.lower() or "auth" in err_msg.lower():
                    reply += " Check that HF_TOKEN is set in .env for the chat model."
                # Keep the previous state (and its map) on failure.
                new_state = state
            new_history = _tuples_to_messages(history_tuples + [[message, reply]])
            return new_history, new_state, "", map_for(new_state, update_id)

        submit_btn.click(
            user_submit,
            inputs=[msg, chat, state],
            outputs=[chat, state, msg, map_html],
        )
        msg.submit(
            user_submit,
            inputs=[msg, chat, state],
            outputs=[chat, state, msg, map_html],
        )
    return demo
if __name__ == "__main__":
    import inspect

    demo = create_demo()
    accepted = inspect.signature(demo.launch).parameters
    # Forward styling only for the options this Gradio version's launch() accepts.
    launch_options = {}
    if "css" in accepted:
        launch_options["css"] = CSS
    if "theme" in accepted and hasattr(gr, "themes"):
        launch_options["theme"] = gr.themes.Soft(primary_hue="teal", secondary_hue="slate")
    demo.launch(**launch_options)