|
|
|
|
|
import base64 |
|
|
import os |
|
|
import sys |
|
|
import csv |
|
|
import spaces |
|
|
import glob |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import rasterio as rio |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib as mpl |
|
|
from io import BytesIO |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
from matplotlib import rcParams |
|
|
from msclip.inference import run_inference_classification |
|
|
from msclip.inference.utils import build_model |
|
|
|
|
|
# Small fonts for the compact per-image score charts.
rcParams["font.size"] = 9
rcParams["axes.titlesize"] = 9

# Side length (px) of the square thumbnails; also the display width of the charts.
IMG_PX = 300

# Raise the csv module's per-field size limit as far as the platform allows.
# NOTE(review): csv is not used directly in this file; presumably a dependency
# reads CSVs with very large fields — confirm. sys.maxsize overflows the C long
# used by csv on some platforms (e.g. 64-bit Windows), so halve until accepted.
_csv_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_limit)
        break
    except OverflowError:
        _csv_limit //= 2
del _csv_limit

# Build the model, its preprocessing transform and the tokenizer once at import time.
model, preprocess, tokenizer = build_model()
|
|
|
|
|
# Bundled example sets: dataset name -> example GeoTIFF paths (relative to the
# working directory) and the class names used for zero-shot classification.
# glob.glob returns files in arbitrary, filesystem-dependent order, so the
# paths are sorted to keep the example gallery stable across machines/restarts.
EXAMPLES = {
    "EuroSAT": {
        "images": sorted(glob.glob("examples/eurosat/*.tif")),
        "classes": [
            "Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
            "Pasture", "Permanent crop", "Residential", "River", "Sea lake"
        ]
    },
    "Meter-ML": {
        "images": sorted(glob.glob("examples/meterml/*.tif")),
        "classes": [
            "Concentrated animal feeding operations",
            "Landfills",
            "Coal mines",
            "Other features",
            "Natural gas processing plants",
            "Oil refineries and petroleum terminals",
            "Wastewater treatment plants",
        ]
    },
    "TerraMesh": {
        "images": sorted(glob.glob("examples/terramesh/*.tif")),
        "classes": [
            "Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert"
        ]
    },
}
|
|
|
|
|
|
|
|
# Hex colour strings of matplotlib's qualitative "Pastel1" palette.
pastel1_hex = list(map(mpl.colors.to_hex, mpl.colormaps["Pastel1"].colors))
|
|
|
|
|
|
|
|
def build_colormap(class_names):
    """Assign a Pastel1 hex colour to every class name.

    Names are sorted alphabetically before colours are handed out, so the
    mapping does not depend on input order. The palette is reused cyclically
    when there are more classes than colours.
    """
    palette_size = len(pastel1_hex)
    mapping = {}
    for position, name in enumerate(sorted(class_names)):
        mapping[name] = pastel1_hex[position % palette_size]
    return mapping
|
|
|
|
|
|
|
|
def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """Contrast-stretch raw band values using robust quantile limits.

    Two passes are made:

    1. Soft clip: values below the ``tolerance`` quantile or above the
       ``1 - tolerance`` quantile are pulled back towards those limits by the
       factor ``scaling`` instead of being cut off. The upper limit is at
       least ``default``; the lower limit is clipped to [0, 1000].
    2. Linear stretch: the softened data is rescaled between its
       ``tolerance / 10`` and ``1 - tolerance / 10`` quantiles. The upper limit
       is clipped to [default, 20000] and the lower limit to [0, 500].

    In both passes the lower limit is forced to 0 unless the median of the
    input exceeds ``default / 2``.

    Quantiles are taken over all elements, so the axis order of ``array`` is
    irrelevant (the caller in this file passes [H, W, C]).

    Returns a float array whose values are roughly, but not strictly, in 0-1;
    callers should clip.
    """
    q_low, q_median, q_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
    keep_low = q_median > default / 2

    # Pass 1: soft clip towards the [floor, ceiling] band.
    floor = np.where(keep_low, q_low.clip(0, 1000), 0)
    ceiling = q_high.clip(default)
    array = np.where(array >= floor, array, floor + (array - floor) * scaling)
    array = np.where(array <= ceiling, array, ceiling + (array - ceiling) * scaling)

    # Pass 2: linear stretch between tighter quantiles of the softened data.
    fine = tolerance / 10
    q_low, q_high = np.quantile(array, q=[fine, 1. - fine])
    floor = np.where(keep_low, q_low.clip(0, 500), 0)
    ceiling = q_high.clip(default, 20000)
    return (array - floor) / (ceiling - floor)
|
|
|
|
|
|
|
|
def _s2_to_rgb(data, smooth_quantiles=True):
    """Turn a Sentinel-2 band stack into an 8-bit RGB image of shape [H, W, 3].

    Band indices 3, 2 and 1 (0-based) become red, green and blue (presumably
    B04/B03/B02 in standard Sentinel-2 band order). The layout is guessed from
    the first axis: more than 13 entries means channels-last [H, W, C];
    otherwise the stack is treated as channels-first [C, H, W].

    With ``smooth_quantiles`` the bands are contrast-stretched by
    ``_rgb_smooth_quantiles``; otherwise they are divided by 2000. The result
    is scaled to 0-255, rounded, clipped and cast to uint8.
    """
    channels_last = data.shape[0] > 13
    if channels_last:
        rgb = data[:, :, [3, 2, 1]]
    else:
        rgb = data[[3, 2, 1]].transpose(1, 2, 0)

    scaled = _rgb_smooth_quantiles(rgb) if smooth_quantiles else rgb / 2000
    return np.clip(np.round(scaled * 255), 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
def _img_to_b64(path: str | Path) -> str:
    """Render a Sentinel-2 GeoTIFF as a square PNG thumbnail, base64-encoded.

    The raster is converted to RGB with ``_s2_to_rgb`` and centred on a white
    square canvas whose side equals the image's longer edge. The canvas is
    then always resized to ``IMG_PX`` x ``IMG_PX``, scaling up or down as
    needed.
    """
    with rio.open(path) as src:
        bands = src.read()

    picture = Image.fromarray(_s2_to_rgb(bands))
    width, height = picture.size
    edge = max(width, height)

    square = Image.new("RGB", (edge, edge), (255, 255, 255))
    square.paste(picture, ((edge - width) // 2, (edge - height) // 2))
    square = square.resize((IMG_PX, IMG_PX))

    stream = BytesIO()
    square.save(stream, format="PNG")
    return base64.b64encode(stream.getvalue()).decode()
|
|
|
|
|
|
|
|
def _bar_chart(top_scores, img_name, cmap) -> str:
    """Render the top class scores of one image as an embeddable ``<img>`` tag.

    Args:
        top_scores: Series-like object (``.values`` / ``.index``) mapping class
            name -> score. Presumably scores are in [0, 1], since the x-axis is
            fixed to that range.
        img_name: File name used as the chart title. The extension is dropped,
            and names longer than 25 characters are truncated.
        cmap: Dict mapping class name -> bar colour. Classes missing from it
            are drawn with no fill ("none").

    Returns:
        An HTML ``<img>`` element holding the chart as a base64 PNG, displayed
        ``IMG_PX`` pixels wide.
    """
    scores = top_scores.values.tolist()
    labels = top_scores.index.tolist()
    # Pad to at least three rows so every chart has the same layout.
    while len(scores) < 3:
        scores.append(0)
        labels.append("")

    fig, ax = plt.subplots(figsize=(3, 1))
    try:
        # One bar per row; sized from the data so >3 scores cannot mismatch.
        y_pos = np.arange(len(scores))
        # Padding rows (score 0) get a fully transparent bar.
        colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
                  for cls, val in zip(labels, scores)]
        ax.barh(y_pos, scores, height=0.7, color=colors)
        ax.set_xlim(0, 1)
        ax.invert_yaxis()  # first entry on top
        ax.axis("off")

        title = os.path.splitext(img_name)[0]
        if len(title) > 25:
            title = title[:23] + "..."
        # File and class names are user-supplied: disable mathtext so a "$" is
        # drawn literally instead of being parsed as TeX (which can raise).
        ax.set_title(title, parse_math=False)

        for i, (cls, val) in enumerate(zip(labels, scores)):
            if len(cls) > 25:
                cls = cls[:23] + "..."
            if val > 0:
                ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)",
                        ha="left", va="center", parse_math=False)

        buf = BytesIO()
        fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    finally:
        # Always release the figure, even if rendering fails; pyplot keeps
        # open figures alive until they are closed.
        plt.close(fig)

    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'
|
|
|
|
|
|
|
|
@spaces.GPU
def classify(images, class_text):
    """Run zero-shot classification and render one HTML result card per image.

    Args:
        images: List of paths to Sentinel-2 GeoTIFFs, e.g. from a Gradio
            ``File`` input with ``type="filepath"``. May be ``None`` or empty
            when nothing was uploaded.
        class_text: Comma-separated class names. Surrounding whitespace is
            stripped and empty entries are ignored.

    Returns:
        An HTML string: a flex-wrapped grid of cards, each with an RGB
        thumbnail and a bar chart of up to three classes scoring above 1%.
        A short HTML message is returned instead when there are no images or
        no class names.
    """
    class_names = [c.strip() for c in (class_text or "").split(",") if c.strip()]
    if not images:
        return "<p>Please upload at least one image.</p>"
    if not class_names:
        return "<p>Please enter at least one class name.</p>"

    df = run_inference_classification(
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        image_path=images,
        class_names=class_names,
        verbose=False,
    )

    # The colour map depends only on the class names, so build it once.
    cmap = build_colormap(class_names)
    cards = []
    # NOTE(review): assumes df rows follow the order of `images` and that the
    # first two columns are non-score metadata — confirm against msclip.
    for img_path, (_, row) in zip(images, df.iterrows()):
        scores = row.iloc[2:].astype(float)
        top = scores.sort_values(ascending=False).iloc[:3]
        top = top[top > 0.01]  # hide classes with negligible scores

        cards.append(f"""
        <div style="width:{IMG_PX}px;margin:18px auto;text-align:left;">
            <img src="data:image/png;base64,{_img_to_b64(img_path)}"
                 style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
                        border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
            {_bar_chart(top, os.path.basename(img_path), cmap)}
        </div>""")

    return (
        "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
        + "".join(cards)
        + "</div>"
    )
|
|
|
|
|
|
|
|
|
|
|
def _classify_example(name):
    """Classify one bundled example set using that set's own class names."""
    example = EXAMPLES[name]
    return classify(example["images"], ", ".join(example["classes"]))


# Pre-compute results for the bundled examples at start-up so the example
# buttons can show them without running the model again.
terramesh_html = _classify_example("TerraMesh")
eurosat_html = _classify_example("EuroSAT")
meterml_html = _classify_example("Meter-ML")
|
|
|
|
|
|
|
|
def load_eurosat_example():
    """Return the EuroSAT example as (image paths, class-name text, cached result HTML)."""
    example = EXAMPLES["EuroSAT"]
    class_text = ", ".join(example["classes"])
    return example["images"], class_text, eurosat_html
|
|
|
|
|
|
|
|
def load_meterml_example():
    """Return the Meter-ML example as (image paths, class-name text, cached result HTML)."""
    example = EXAMPLES["Meter-ML"]
    class_text = ", ".join(example["classes"])
    return example["images"], class_text, meterml_html
|
|
|
|
|
|
|
|
def load_terramesh_example():
    """Return the TerraMesh example as (image paths, class-name text, cached result HTML)."""
    example = EXAMPLES["TerraMesh"]
    class_text = ", ".join(example["classes"])
    return example["images"], class_text, terramesh_html
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: file upload + class names -> HTML result cards, plus buttons that
# load the bundled example sets together with their pre-computed results.
with gr.Blocks(
    css="""
.gradio-container
#result_box,
#result_box.gr-skeleton {min-height:280px !important;}
""") as demo:
    gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
    gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. "
                "You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifying forests as sea because of the differently scaled values). "
                "We provide three sets of example images with class names and cached outputs. "
                "The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (We downloaded S-2 L2A images for the same locations). "
                "The images are classified based on the similarity between the image embeddings and text embeddings. "
                "You can find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ")
    with gr.Row():
        img_in = gr.File(
            label="Upload S-2 images", file_count="multiple", type="filepath"
        )
        cls_in = gr.Textbox(
            value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
            label="Class names (comma‑separated)",
        )

    run_btn = gr.Button("Classify", variant="primary")

    gr.Markdown("#### Load examples")
    with gr.Row():
        btn_terramesh = gr.Button("TerraMesh")
        btn_eurosat = gr.Button("EuroSAT")
        btn_meterml = gr.Button("Meter-ML")

    out_html = gr.HTML(label="Results",
                       elem_id="result_box",
                       min_height=280)

    run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html)

    # Example buttons fill the inputs and show the cached results directly.
    btn_terramesh.click(
        load_terramesh_example,
        outputs=[img_in, cls_in, out_html],
    )

    btn_eurosat.click(
        load_eurosat_example,
        outputs=[img_in, cls_in, out_html],
    )

    btn_meterml.click(
        load_meterml_example,
        outputs=[img_in, cls_in, out_html],
    )

if __name__ == "__main__":
    demo.launch()
|
|
|