|
|
import streamlit as st |
|
|
import folium |
|
|
from folium import plugins |
|
|
from streamlit_folium import st_folium |
|
|
import rasterio |
|
|
from rasterio.warp import calculate_default_transform, reproject, Resampling |
|
|
import joblib |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import geopandas as gpd |
|
|
from pathlib import Path |
|
|
from matplotlib import colors as colors |
|
|
import time |
|
|
from rasterio.crs import CRS |
|
|
|
|
|
from worldcereal.job import INFERENCE_JOB_OPTIONS, create_embeddings_process_graph |
|
|
from openeo_gfmap import TemporalContext, BoundingBoxExtent |
|
|
from worldcereal.parameters import EmbeddingsParameters |
|
|
|
|
|
import openeo |
|
|
import os |
|
|
|
|
|
|
|
|
# Display colour (hex) for every crop label used by the app.
# Insertion order is significant: it fixes the integer id given to each label
# in the two lookup tables derived below.
crop_classes = {
    "tuz": "#32cd32",         # presumably permanent grassland (abbreviation) - TODO confirm
    "burak": "#8b008b",       # beet
    "jęczmień": "#ffd700",    # barley
    "kukurydza": "#ffa500",   # maize
    "lucerna": "#9acd32",     # alfalfa
    "mieszanka": "#daa520",   # mixture
    "owies": "#f0e68c",       # oats
    "pszenica": "#f5deb3",    # wheat
    "pszenżyto": "#bdb76b",   # triticale
    "rzepak": "#ffff00",      # rapeseed
    "sad": "#228b22",         # orchard
    "słonecznik": "#ff4500",  # sunflower
    "ziemniak": "#a0522d",    # potato
    "łubin": "#9370db",       # lupin
    "żyto": "#cd853f",        # rye
    "inne": "#808080",        # other
}

# Integer id -> label, numbered in the declaration order of ``crop_classes``.
id_to_class = dict(enumerate(crop_classes))

# Label -> integer id (the inverse of ``id_to_class``).
class_to_id = {label: idx for idx, label in id_to_class.items()}
|
|
|
|
|
# Streamlit page configuration (title and wide layout).
st.set_page_config(layout="wide", page_title="Crop Map")

# Filesystem locations, all relative to the current working directory.
demo_dir = Path("embeddings")                                        # pre-computed embedding rasters
model_path = Path("models") / "random_forest_crop_classifier_06.joblib"
temp_dir = demo_dir / "temp_analysis"                                 # downloads from on-demand jobs

# Make sure the download directory exists before any job writes into it.
temp_dir.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
def get_class_color_rgba(class_name, alpha=180):
    """Return the display colour of a crop class as an ``(r, g, b, alpha)`` tuple.

    Args:
        class_name: Crop label looked up in ``crop_classes``. Labels that are
            not in the table fall back to black (``"#000000"``).
        alpha: Value placed unchanged in the fourth slot of the result.

    Returns:
        A 4-tuple whose first three items are the colour channels scaled by
        255 and truncated to ``int``; the last item is ``alpha``.
    """
    fallback_hex = "#000000"
    unit_rgb = colors.hex2color(crop_classes.get(class_name, fallback_hex))
    red, green, blue = (int(channel * 255) for channel in unit_rgb[:3])
    return (red, green, blue, alpha)
|
|
|
|
|
|
|
|
def create_legend_html(stats_legend):
    """Render the crop legend as a single HTML string.

    Args:
        stats_legend: DataFrame read row by row; each row must provide a
            ``'Crop'`` label and a numeric ``'Percentage'``.

    Returns:
        One outer ``<div>`` containing, per row and in row order, a colour
        swatch, the crop label and the percentage formatted to one decimal.
        Labels missing from ``crop_classes`` get a black swatch.
    """

    def _legend_entry(label, share):
        # One flex row: swatch | label (stretches) | bold percentage.
        swatch = crop_classes.get(label, "#000000")
        return "".join((
            "<div style='display: flex; align-items: center; margin-bottom: 4px;'>",
            f"<div style='width: 15px; height: 15px; background-color: {swatch}; margin-right: 10px; border-radius: 3px;'></div>",
            f"<span style='font-size: 14px; flex-grow: 1;'>{label}</span>",
            f"<span style='font-weight: bold; font-size: 14px;'>{share:.1f}%</span>",
            "</div>",
        ))

    container_open = (
        "<div style='background-color: rgba(255, 255, 255, 0.1); padding: 10px; "
        "border-radius: 5px; font-family: sans-serif;'>"
    )
    entries = [
        _legend_entry(record['Crop'], record['Percentage'])
        for _, record in stats_legend.iterrows()
    ]
    return container_open + "".join(entries) + "</div>"
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the classifier stored at ``model_path``.

    Returns:
        The object deserialised by ``joblib.load``, or ``None`` when the file
        does not exist.
    """
    # NOTE(review): the function is wrapped in st.cache_resource, so a ``None``
    # result is presumably cached as well - confirm whether a model file added
    # later is picked up without clearing the cache.
    if model_path.exists():
        return joblib.load(model_path)
    return None
|
|
|
|
|
@st.cache_data
def run_prediction(tif_path, _model):
    """Classify every pixel of an embedding raster and warp the result to EPSG:4326.

    Args:
        tif_path: Path of the raster to open with rasterio; all bands are read
            and used as the per-pixel feature vector.
        _model: Fitted estimator exposing ``predict``; it must return the class
            labels used as keys of ``class_to_id``. The leading underscore
            follows Streamlit's convention for arguments excluded from the
            cache key.

    Returns:
        ``(class_map, folium_bounds)`` where ``class_map`` is a 2-D ``uint8``
        array of class ids in EPSG:4326 and ``folium_bounds`` is
        ``[[south, west], [north, east]]``. Pixels without a valid class
        (labels unknown to ``class_to_id`` and the area outside the warped
        source footprint) hold 255, which is not a key of ``id_to_class``.
    """
    # Sentinel for "no class". The class table only uses ids 0..15, and id 0 is
    # a real crop, so zero-filled arrays would silently report that crop.
    nodata_id = 255

    with rasterio.open(tif_path) as src:
        embedding = src.read()
        src_transform = src.transform
        src_crs = src.crs
        h, w = src.height, src.width

    # (bands, rows, cols) -> (rows * cols, bands): one feature row per pixel.
    n_channels = embedding.shape[0]
    reshaped = embedding.transpose(1, 2, 0).reshape(-1, n_channels)

    # Predict in fixed-size batches to bound peak memory inside the estimator.
    batch_size = 50000
    preds = []
    for i in range(0, reshaped.shape[0], batch_size):
        batch = np.nan_to_num(reshaped[i:i + batch_size])
        preds.append(_model.predict(batch))

    raw_class_map_str = np.concatenate(preds).reshape(h, w)

    # Translate string labels to integer ids; anything unmapped stays nodata.
    raw_class_map_int = np.full((h, w), nodata_id, dtype=np.uint8)
    for class_name, class_id in class_to_id.items():
        raw_class_map_int[raw_class_map_str == class_name] = class_id

    src_crs_str = src_crs.to_string()
    dst_crs = CRS.from_string('EPSG:4326')

    left, bottom, right, top = rasterio.transform.array_bounds(h, w, src_transform)
    transform, dst_width, dst_height = calculate_default_transform(
        src_crs_str, dst_crs, w, h, left=left, bottom=bottom, right=right, top=top
    )

    # Nearest-neighbour keeps ids categorical. The nodata values make the
    # corners created by the warp stay 255 instead of becoming class 0.
    destination = np.full((dst_height, dst_width), nodata_id, dtype=np.uint8)
    reproject(
        source=raw_class_map_int,
        destination=destination,
        src_transform=src_transform,
        src_crs=src_crs_str,
        src_nodata=nodata_id,
        dst_transform=transform,
        dst_crs=dst_crs,
        dst_nodata=nodata_id,
        resampling=Resampling.nearest
    )

    # array_bounds yields (west, south, east, north); folium wants
    # [[south, west], [north, east]].
    bounds_orig = rasterio.transform.array_bounds(dst_height, dst_width, transform)
    folium_bounds = [[bounds_orig[1], bounds_orig[0]], [bounds_orig[3], bounds_orig[2]]]

    return destination, folium_bounds
|
|
|
|
|
def run_openeo_job(lat, lon, size_km=1.0):
    """Run a WorldCereal embeddings job for a small box around ``lat``/``lon``.

    Authentication is chosen from the environment: ``OPENEO_REFRESH_TOKEN``
    first, then ``OPENEO_CLIENT_ID`` + ``OPENEO_CLIENT_SECRET``, otherwise an
    interactive OIDC login (refused when ``SPACE_ID`` is set, i.e. when no
    interactive login is possible). The temporal extent is fixed to the
    calendar year 2025.

    Args:
        lat: Latitude of the box centre in degrees (EPSG:4326).
        lon: Longitude of the box centre in degrees (EPSG:4326).
        size_km: Approximate edge length of the square box in kilometres.

    Returns:
        Path of the downloaded GeoTIFF inside ``temp_dir``, or ``None`` when
        connecting, authenticating, running the job or finding a TIFF asset
        fails. Failures are reported through ``st.error`` rather than raised.
    """
    import math

    try:
        con = openeo.connect("https://openeo.dataspace.copernicus.eu")
    except Exception as e:
        # Keep the documented "path or None" contract when the backend is unreachable.
        st.error(f"OpenEO Error: {str(e)}")
        return None

    refresh_token = os.environ.get("OPENEO_REFRESH_TOKEN")
    client_id = os.environ.get("OPENEO_CLIENT_ID")
    client_secret = os.environ.get("OPENEO_CLIENT_SECRET")

    try:
        if refresh_token:
            print("Using Refresh Token Auth")
            con.authenticate_oidc_refresh_token(
                refresh_token=refresh_token,
                client_id="sh-b1c3a958-52d4-40fe-a333-153595d1c71e"
            )
        elif client_id and client_secret:
            print("Using Client Credentials Auth")
            con.authenticate_oidc_client_credentials(
                client_id=client_id,
                client_secret=client_secret
            )
        else:
            if os.environ.get("SPACE_ID"):
                # No secrets configured and no way to log in interactively.
                st.error("Authentication Error: No secrets found. Please set OPENEO_REFRESH_TOKEN in Hugging Face Space settings.")
                return None

            print("Using Interactive Auth (Local)")
            con.authenticate_oidc()

    except Exception as e:
        st.error(f"Authentication Failed: {str(e)}")
        return None

    try:
        # Half the box edge in degrees. One degree of latitude is ~111 km, but
        # a degree of longitude shrinks with cos(latitude), so the east-west
        # offset is widened to keep the box roughly square on the ground.
        lat_offset = (size_km / 111) / 2
        # Clamp the cosine so the offset stays finite close to the poles.
        cos_lat = max(math.cos(math.radians(lat)), 0.01)
        lon_offset = lat_offset / cos_lat

        west, east = lon - lon_offset, lon + lon_offset
        south, north = lat - lat_offset, lat + lat_offset

        spatial_extent = BoundingBoxExtent(
            west=west, south=south, east=east, north=north, epsg=4326
        )

        temporal_extent = TemporalContext("2025-01-01", "2025-12-31")

        st.info("Building OpenEO Process Graph...")
        embedding_params = EmbeddingsParameters()
        inference_result = create_embeddings_process_graph(
            spatial_extent=spatial_extent,
            temporal_extent=temporal_extent,
            embeddings_parameters=embedding_params,
            scale_uint16=True,
            connection=con
        )

        job_title = f"thesis_demo_{lat}_{lon}"
        st.info(f"Submitting Job: {job_title}...")
        job = inference_result.create_job(
            title=job_title,
            job_options=INFERENCE_JOB_OPTIONS,
        )

        job.start()
        job_id = job.job_id
        st.success(f"Job started. ID: {job_id}")

        # Poll until the backend reports a terminal status, updating a single
        # placeholder so the page does not fill up with status lines.
        status_box = st.empty()
        while True:
            metadata = job.describe_job()
            status = metadata.get("status")
            status_box.markdown(f"**Status:** `{status}` (refreshing every 5s...)")

            if status == "finished":
                break
            elif status in ["error", "canceled"]:
                st.error(f"Job failed with status: {status}")
                return None

            time.sleep(5)

        st.info("Downloading results...")
        results = job.get_results()
        output_path = temp_dir / f"embedding_{lat}_{lon}.tif"

        # Only the first TIFF asset is downloaded.
        for asset in results.get_assets():
            if asset.metadata.get("type", "").startswith("image/tiff"):
                asset.download(str(output_path))
                return output_path

        st.error("No TIFF found in results.")
        return None

    except Exception as e:
        st.error(f"OpenEO Error: {str(e)}")
        return None
|
|
|
|
|
|
|
|
def _colorize_class_map(class_map):
    """Return an (H, W, 4) uint8 RGBA image for a 2-D array of class ids.

    Ids that are not keys of ``id_to_class`` keep the zero fill, i.e. they are
    fully transparent.
    """
    height, width = class_map.shape
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    for uid in np.unique(class_map):
        uid = int(uid)
        if uid not in id_to_class:
            continue
        rgba[class_map == uid] = get_class_color_rgba(id_to_class[uid], alpha=255)
    return rgba


def _summarize_classes(class_map):
    """Return per-crop pixel counts and percentages, largest share first.

    Percentages are relative to the pixels that carry a known class id, so ids
    missing from ``id_to_class`` do not dilute the shares.
    """
    ids, pixel_counts = np.unique(class_map, return_counts=True)
    known = [(int(uid), n) for uid, n in zip(ids, pixel_counts) if int(uid) in id_to_class]
    total = sum(n for _, n in known)
    rows = [
        {"Crop": id_to_class[uid], "Pixels": n, "Percentage": n / total * 100}
        for uid, n in known
    ]
    # Explicit columns keep sort_values working when no pixel was classified.
    stats = pd.DataFrame(rows, columns=["Crop", "Pixels", "Percentage"])
    return stats.sort_values("Percentage", ascending=False)


def _build_prediction_map(bounds, rgba_img, overlay_opacity, gdf=None):
    """Build the folium map: two base layers, the prediction overlay and optional field outlines.

    ``bounds`` is ``[[south, west], [north, east]]``; the map is centred on it.
    """
    center_lat = (bounds[0][0] + bounds[1][0]) / 2
    center_lon = (bounds[0][1] + bounds[1][1]) / 2

    m = folium.Map(location=[center_lat, center_lon], zoom_start=14, control_scale=True)

    folium.TileLayer(
        tiles='CartoDB positron',
        name='Light Map',
        overlay=False
    ).add_to(m)

    folium.TileLayer(
        tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
        attr='Esri',
        name='Satellite',
        overlay=False
    ).add_to(m)

    folium.raster_layers.ImageOverlay(
        image=rgba_img,
        bounds=bounds,
        opacity=overlay_opacity,
        name='Prediction',
        pixelated=True
    ).add_to(m)

    if gdf is not None:
        # Only attach the tooltip when the attribute it shows is present;
        # otherwise the outlines are drawn without one.
        tooltip = None
        if 'roslina' in gdf.columns:
            tooltip = folium.GeoJsonTooltip(fields=['roslina'], aliases=['Crop:'])
        folium.GeoJson(
            gdf,
            name="Fields",
            style_function=lambda x: {'color': 'white', 'weight': 1, 'fillOpacity': 0, 'dashArray': '5, 5'},
            tooltip=tooltip
        ).add_to(m)

    folium.LayerControl().add_to(m)
    plugins.Fullscreen().add_to(m)
    return m


st.title("Crop Map")

with st.sidebar:
    st.header("Control Panel")
    tif_files = list(demo_dir.glob("*.tif"))
    if not tif_files:
        st.error(f"No .tif files in {demo_dir}")
        st.stop()

    selected_tif = st.selectbox("Select Region", tif_files, format_func=lambda x: x.name)

    # A GeoJSON next to the raster, named like it minus "_embedding", is
    # treated as the matching field outlines.
    possible_name = selected_tif.stem.replace("_embedding", "") + ".geojson"
    geojson_path = selected_tif.parent / possible_name
    has_geojson = geojson_path.exists()

    if has_geojson:
        st.success(f"Linked: {geojson_path.name}")

    run_btn = st.button("Run Analysis", type="primary")

if run_btn:
    model = load_model()
    # Compare against None explicitly instead of relying on the estimator's truthiness.
    if model is None:
        st.error("Model not found")
        st.stop()

    with st.spinner("Processing..."):
        class_map, bounds = run_prediction(selected_tif, model)
        rgba_img = _colorize_class_map(class_map)

        gdf = None
        if has_geojson:
            gdf = gpd.read_file(geojson_path)
            if gdf.crs is not None and gdf.crs != "EPSG:4326":
                gdf = gdf.to_crs("EPSG:4326")
            # Tolerance is in degrees; thins the outlines sent to the browser.
            gdf['geometry'] = gdf['geometry'].simplify(tolerance=0.0001)

        stats_df = _summarize_classes(class_map)

    # Persist the results so they survive the reruns Streamlit triggers on
    # every widget interaction.
    st.session_state['analysis_results'] = {
        "bounds": bounds,
        "rgba_img": rgba_img,
        "gdf": gdf,
        "stats_df": stats_df
    }

tab1, tab2 = st.tabs(["Pre-loaded Regions", "Analyze New Area"])

with tab1:
    if 'analysis_results' in st.session_state:
        data = st.session_state['analysis_results']
        bounds = data['bounds']
        rgba_img = data['rgba_img']
        gdf = data['gdf']
        stats_df = data['stats_df']

        c1, c2 = st.columns([3, 1])

        with c1:
            overlay_opacity = st.slider("Overlay Opacity", 0.0, 1.0, 0.7, 0.1, key="opacity_tab1")
            m = _build_prediction_map(bounds, rgba_img, overlay_opacity, gdf=gdf)
            st_folium(m, height=600, use_container_width=True)

        with c2:
            st.subheader("Legend")
            st.markdown(create_legend_html(stats_df), unsafe_allow_html=True)
            st.dataframe(stats_df[["Crop", "Percentage"]], hide_index=True)

with tab2:
    c1, c2 = st.columns([1, 2])

    if 'tab2_results' not in st.session_state:
        st.session_state['tab2_results'] = None

    with c1:
        st.markdown("### 1. Select Area")
        lat = st.number_input("Latitude", value=50.93131691432723, format="%.4f")
        lon = st.number_input("Longitude", value=22.781513694631702, format="%.4f")

        if st.button("Generate the embedding and classify"):
            with st.spinner("Talking to Satellites... (This takes ~5 mins)"):
                tif_path = run_openeo_job(lat, lon)

            if tif_path:
                st.success("Embedding Generated!")

                model = load_model()
                if model is None:
                    # Without this guard the prediction step would fail on
                    # ``None.predict`` after the long-running job finished.
                    st.error("Model not found")
                else:
                    class_map, bounds = run_prediction(tif_path, model)

                    st.session_state['tab2_results'] = {
                        "bounds": bounds,
                        "rgba_img": _colorize_class_map(class_map),
                        "stats_df": _summarize_classes(class_map)
                    }

                    st.success("Classification Complete")

    with c2:
        if st.session_state['tab2_results']:
            data = st.session_state['tab2_results']
            bounds = data['bounds']
            rgba_img = data['rgba_img']
            stats_df = data['stats_df']

            st.markdown("### 2. Analysis Results")

            overlay_opacity = st.slider("Overlay Opacity", 0.0, 1.0, 0.7, 0.1, key="opacity_tab2")
            m = _build_prediction_map(bounds, rgba_img, overlay_opacity)
            st_folium(m, height=500, use_container_width=True)

            st.divider()
            col_leg, col_df = st.columns(2)
            with col_leg:
                st.subheader("Legend")
                st.markdown(create_legend_html(stats_df), unsafe_allow_html=True)