import streamlit as st
import keras
import numpy as np
from PIL import Image
import io, urllib.request, os
st.set_page_config(layout="wide")
@st.cache_data(show_spinner=False, ttl=600)
def fetch_satellite_tile(lat, lng, zoom=16, size="640x640", api_key=""):
    """Download a satellite tile centred on (lat, lng) from Google Static Maps.

    Results are memoised by ``st.cache_data`` for 600 seconds, keyed on all
    arguments (including ``api_key``); the spinner is suppressed.

    Args:
        lat: Latitude of the tile centre.
        lng: Longitude of the tile centre.
        zoom: Static Maps zoom level.
        size: Requested image size as a ``"WIDTHxHEIGHT"`` string.
        api_key: Google Static Maps API key (callers read it from the
            ``goog_api`` environment variable).

    Returns:
        The downloaded tile converted to an RGB ``PIL.Image.Image``.

    Raises:
        RuntimeError: If ``api_key`` is empty.
        urllib.error.URLError: If the request fails (``HTTPError``, a
            subclass, for non-2xx responses).
        PIL.UnidentifiedImageError: If the response body cannot be decoded
            as an image.
    """
    if not api_key:
        raise RuntimeError("Missing Google Static Maps API key in env var 'goog_api'.")
    url = (
        f"https://maps.googleapis.com/maps/api/staticmap?"
        f"center={lat},{lng}&zoom={zoom}&size={size}&maptype=satellite&key={api_key}"
    )
    # Use the response as a context manager so the connection is closed
    # deterministically rather than whenever it is garbage-collected.
    with urllib.request.urlopen(url, timeout=10) as response:
        payload = response.read()
    return Image.open(io.BytesIO(payload)).convert("RGB")
@st.cache_resource
def get_model():
    """Load the trained overpass classifier from disk.

    Wrapped in ``st.cache_resource``, so the model object is loaded once and
    reused across reruns instead of being re-read from disk each time.

    Returns:
        The Keras model stored in ``0.0008-0.92.keras``, loaded uncompiled.
    """
    model_path = "0.0008-0.92.keras"
    # Inference only: skipping compilation avoids rebuilding optimizer state.
    model = keras.models.load_model(model_path, compile=False)
    return model
# Page-level HTML/CSS injection slot; as written it injects only a newline.
st.markdown("""
""", unsafe_allow_html=True)

# Title row: app name and description on the left, logo on the right.
col1, col2 = st.columns(2)
with col1:
    # The empty first nested column pushes the title toward the centre.
    _, col = st.columns(2)
    with col:
        # The original literal here was unterminated (a SyntaxError); render
        # the title as an HTML heading, matching unsafe_allow_html=True.
        st.markdown(
            "<h1>Overpass Detector</h1>",
            unsafe_allow_html=True,
        )
        st.markdown(
            "This page runs a trained machine learning model to detect road "
            "overpasses in satellite imagery.",
            unsafe_allow_html=True,
        )
with col2:
    # Centre the logo by flanking it with empty columns.
    _, col, _ = st.columns(3)
    with col:
        st.image('overpass.png')
st.write("---")
# Input resolution required by the model, in pixels.
img_height = 640
img_width = 640

state = st.session_state
# get_model() is decorated with st.cache_resource, so repeated reruns reuse
# the already-loaded model.
state.loaded_model = get_model()

# First-run defaults for this session; values already present are kept.
_session_defaults = (
    ("lat", 39.11),
    ("lng", -86.56),
    ("coords_submitted", False),
    ("img", None),
)
for _name, _value in _session_defaults:
    if _name not in state:
        state[_name] = _value

# Preload the tile for the current coordinates. If the fetch fails, img stays
# None, so this is attempted again on the next rerun.
if state.img is None:
    try:
        api_key = os.getenv("goog_api", "")
        state.img = fetch_satellite_tile(state.lat, state.lng, api_key=api_key)
    except Exception as fetch_err:
        st.info(f"Couldn’t fetch default tile: {fetch_err}")
# Main layout: prediction (col1), satellite image (col2), coordinate form (col3).
col1, col2, col3 = st.columns([0.8, 1.8, 1.0])  # adjust ratios to taste
with col3:
    st.subheader('Enter latitude/longitude coordinates:')
    with st.form("coords_form"):
        c1, c2 = st.columns(2)
        with c1:
            # key= ties the widget to the session-state entry of the same name.
            st.number_input('Latitude', key="lat", min_value=-90.0, max_value=90.0, step=0.01, format="%.2f")
            st.write('The current lat/long are:')
        with c2:
            st.number_input('Longitude', key="lng", min_value=-180.0, max_value=180.0, step=0.01, format="%.2f")
            st.write(f"{st.session_state.lat:.2f}, {st.session_state.lng:.2f}")
        submitted = st.form_submit_button("Get Image and Prediction")
    if submitted:
        # Guard only the fetch; the rerun is Streamlit control flow and is
        # kept out of the broad except so it can never be intercepted there.
        try:
            api_key = os.getenv("goog_api", "")
            state.img = fetch_satellite_tile(st.session_state.lat, st.session_state.lng, api_key=api_key)
        except Exception as e:
            st.error(f"Error fetching image: {e}")
        else:
            st.rerun()  # ensures the updated lat/lng render immediately
with col2:
    # NOTE(review): coords_submitted is only ever initialised/reset to False in
    # the visible code, so this branch looks unreachable — confirm before removing.
    if state.coords_submitted:
        state.coords_submitted = False
        try:
            api_key = os.getenv("goog_api", "")
            state.img = fetch_satellite_tile(state.lat, state.lng, api_key=api_key)
        except Exception as e:
            st.error(f"Error fetching image: {e}")
    # The former empty st.markdown("") wrapper and the unterminated closing
    # literal (a SyntaxError) around this image were removed; they rendered
    # nothing useful.
    if state.img is not None:
        st.image(state.img, width=640)
with col1:
    if state.img is not None:
        tile = state.img.convert("RGB")
        # np.reshape below raises ValueError unless the tile already has the
        # model's input size, so resize explicitly when it differs.
        # PIL sizes are (width, height).
        if tile.size != (img_width, img_height):
            tile = tile.resize((img_width, img_height))
        batch_size = 1
        img_array = np.reshape(np.array(tile), [batch_size, img_height, img_width, 3])
        with st.spinner("Running inference..."):
            # verbose=0 suppresses Keras's per-call progress bar in server logs.
            result = state.loaded_model.predict(img_array, verbose=0)
        # NOTE(review): assumes the model outputs one row per image with the
        # "overpass" probability at index 1 — confirm against training code.
        crossing_chance = result[0][1] * 100
        pct = np.round(crossing_chance, decimals=2)
        label = "Overpass Likely" if pct >= 50 else "Overpass Unlikely"
        # The f-string body stays at column 0: indented lines would render as
        # a Markdown code block.
        st.markdown(
            f"""
{label}
Model probability: {pct}%
""",
            unsafe_allow_html=True,
        )