Spaces:
Running
Running
Yao-Ting Yao
commited on
Commit
·
ec0e823
1
Parent(s):
8254c8e
Update streamlit_app.py
Browse filesPre-generate thumbnails and save in AWS S3 bucket. Read thumbnails' url in click function.
- src/streamlit_app.py +17 -105
src/streamlit_app.py
CHANGED
|
@@ -142,94 +142,6 @@ chips_df = pd.read_csv("data/embeddings_df_v0.11_test.csv")
|
|
| 142 |
# set anonymous S3FileSystem to read files from public bucket
|
| 143 |
s3 = s3fs.S3FileSystem(anon=True)
|
| 144 |
|
| 145 |
-
## helper function
|
| 146 |
-
def gen_chip_urls(row, s3_prefix):
|
| 147 |
-
'''
|
| 148 |
-
Generate S3 urls for chips
|
| 149 |
-
:param row: dictionary with chip_id and dates
|
| 150 |
-
:param s3_prefix: S3 url prefix
|
| 151 |
-
:return s3_urls: a list of urls
|
| 152 |
-
'''
|
| 153 |
-
s3_urls = []
|
| 154 |
-
dates = ast.literal_eval(row["dates"])
|
| 155 |
-
for date in dates:
|
| 156 |
-
filename = f"s2_{row['chip_id']:06}_{date}.tif"
|
| 157 |
-
s3_url = f"{s3_prefix}/{filename}"
|
| 158 |
-
s3_urls.append(s3_url)
|
| 159 |
-
return s3_urls
|
| 160 |
-
|
| 161 |
-
def mask_nodata(band, nodata_values=(-999,)):
|
| 162 |
-
'''
|
| 163 |
-
Mask nodata to nan
|
| 164 |
-
:param band
|
| 165 |
-
:param nodata_values:nodata values in chips is -999
|
| 166 |
-
:return band
|
| 167 |
-
'''
|
| 168 |
-
band = band.astype(float)
|
| 169 |
-
for val in nodata_values:
|
| 170 |
-
band[band == val] = np.nan
|
| 171 |
-
return band
|
| 172 |
-
|
| 173 |
-
def normalize(band):
|
| 174 |
-
'''
|
| 175 |
-
Normalize a band to 0-1 range(float)
|
| 176 |
-
:param band (ndarray)
|
| 177 |
-
return normalize band (ndarray); when max equals min, returns zeros.
|
| 178 |
-
'''
|
| 179 |
-
if np.nanmean(band) >= 4000:
|
| 180 |
-
band = band / 6000
|
| 181 |
-
else:
|
| 182 |
-
band = band / 4000
|
| 183 |
-
band = np.clip(band, None, 1)
|
| 184 |
-
|
| 185 |
-
return band
|
| 186 |
-
|
| 187 |
-
def create_thumbnail(url):
|
| 188 |
-
'''
|
| 189 |
-
Read S3 file into memory, using rasterio to create a png thumbnail then encode as a base64 string url
|
| 190 |
-
:param url: chip url
|
| 191 |
-
:return a base64-encoded png string, returns an empty string when an error occurs
|
| 192 |
-
'''
|
| 193 |
-
try:
|
| 194 |
-
# read raw bytes from s3 file
|
| 195 |
-
with s3.open(url, "rb") as f:
|
| 196 |
-
data = f.read()
|
| 197 |
-
|
| 198 |
-
# wrap the raw bytes into an memory file
|
| 199 |
-
with MemoryFile(data) as memfile:
|
| 200 |
-
|
| 201 |
-
# read memory file with rasterio
|
| 202 |
-
with memfile.open() as src:
|
| 203 |
-
# mask nodata to have correct calculate normalization
|
| 204 |
-
# band1->blue, band2->green, band3->red
|
| 205 |
-
|
| 206 |
-
blue = src.read(1).astype(float)
|
| 207 |
-
green = src.read(2).astype(float)
|
| 208 |
-
red = src.read(3).astype(float)
|
| 209 |
-
|
| 210 |
-
blue = normalize(mask_nodata(blue))
|
| 211 |
-
green = normalize(mask_nodata(green))
|
| 212 |
-
red = normalize(mask_nodata(red))
|
| 213 |
-
|
| 214 |
-
# stack in RGB
|
| 215 |
-
rgb = np.dstack((red, green, blue))
|
| 216 |
-
|
| 217 |
-
# convert float(0-1) to uint8 (0-255)
|
| 218 |
-
rgb_8bit = (rgb * 255).astype(np.uint8)
|
| 219 |
-
|
| 220 |
-
# convert to png in memory
|
| 221 |
-
pil_img = Image.fromarray(rgb_8bit)
|
| 222 |
-
buf = io.BytesIO()
|
| 223 |
-
pil_img.save(buf, format='PNG')
|
| 224 |
-
|
| 225 |
-
# encoded into base64 HTML format
|
| 226 |
-
encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
|
| 227 |
-
return f"data:image/png;base64,{encoded}"
|
| 228 |
-
|
| 229 |
-
except Exception as e:
|
| 230 |
-
# return an empty string for Exception
|
| 231 |
-
return ""
|
| 232 |
-
|
| 233 |
def get_lat(geometry):
|
| 234 |
lat = wkt.loads(geometry).coords.xy[1][0]
|
| 235 |
|
|
@@ -240,7 +152,6 @@ def get_lon(geometry):
|
|
| 240 |
|
| 241 |
return lon
|
| 242 |
|
| 243 |
-
|
| 244 |
## generate json
|
| 245 |
# title: plot title
|
| 246 |
# xaxis_title: x axis title
|
|
@@ -255,12 +166,6 @@ title_js = json.dumps(config["title"])
|
|
| 255 |
xaxis_js = json.dumps(config["xaxis_title"])
|
| 256 |
yaxis_js = json.dumps(config["yaxis_title"])
|
| 257 |
|
| 258 |
-
# set prefix
|
| 259 |
-
s3_prefix="s3://gfm-bench"
|
| 260 |
-
|
| 261 |
-
# generate S3 file URLs
|
| 262 |
-
chips_df["urls"] = chips_df.apply(lambda row: gen_chip_urls(row, s3_prefix), axis=1)
|
| 263 |
-
|
| 264 |
# set lc(str) for categorical data for plotting
|
| 265 |
chips_df["lc"] = chips_df["lc"].astype(str)
|
| 266 |
# add latitude and longitude
|
|
@@ -300,13 +205,21 @@ color_dict_label = {
|
|
| 300 |
'Rangeland': '#f7980a'
|
| 301 |
}
|
| 302 |
|
| 303 |
-
# create thumbnail
|
| 304 |
-
chips_df["thumbs"] = chips_df["urls"].apply(
|
| 305 |
-
lambda urls: [create_thumbnail(p) for p in urls]
|
| 306 |
-
)
|
| 307 |
# create dates Python list
|
| 308 |
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
# build a list of points dictionary
|
| 311 |
points = (
|
| 312 |
chips_df
|
|
@@ -314,12 +227,13 @@ points = (
|
|
| 314 |
"cls_dim1": "x",
|
| 315 |
"cls_dim2": "y",
|
| 316 |
"Land Cover": "category"
|
| 317 |
-
})[["x","y","chip_id", "latitude", "longitude","category","
|
| 318 |
.assign(
|
| 319 |
id = chips_df["chip_id"],
|
| 320 |
lat = chips_df["latitude"],
|
| 321 |
lon = chips_df["longitude"],
|
| 322 |
-
color=chips_df["Land Cover"].map(color_dict_label)
|
|
|
|
| 323 |
.to_dict(orient="records")
|
| 324 |
)
|
| 325 |
|
|
@@ -380,8 +294,7 @@ plot_html = f"""
|
|
| 380 |
customdata:pts.map(p=>[
|
| 381 |
p.id,
|
| 382 |
p.lat,
|
| 383 |
-
p.lon
|
| 384 |
-
p.thumbs
|
| 385 |
]),
|
| 386 |
mode: 'markers',
|
| 387 |
type: 'scatter',
|
|
@@ -438,8 +351,7 @@ plot_html = f"""
|
|
| 438 |
gd.on('plotly_click', evt => {{
|
| 439 |
// grab thumbs and dates through point index
|
| 440 |
const idx = evt.points[0].pointIndex;
|
| 441 |
-
const
|
| 442 |
-
const thumbs = cds[3];
|
| 443 |
const dates = points[idx].dates_list;
|
| 444 |
// grab image container and clear out old thumbs
|
| 445 |
const container = document.getElementById('image-container');
|
|
|
|
| 142 |
# set anonymous S3FileSystem to read files from public bucket
|
| 143 |
s3 = s3fs.S3FileSystem(anon=True)
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
def get_lat(geometry):
|
| 146 |
lat = wkt.loads(geometry).coords.xy[1][0]
|
| 147 |
|
|
|
|
| 152 |
|
| 153 |
return lon
|
| 154 |
|
|
|
|
| 155 |
## generate json
|
| 156 |
# title: plot title
|
| 157 |
# xaxis_title: x axis title
|
|
|
|
| 166 |
xaxis_js = json.dumps(config["xaxis_title"])
|
| 167 |
yaxis_js = json.dumps(config["yaxis_title"])
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
# set lc(str) for categorical data for plotting
|
| 170 |
chips_df["lc"] = chips_df["lc"].astype(str)
|
| 171 |
# add latitude and longitude
|
|
|
|
| 205 |
'Rangeland': '#f7980a'
|
| 206 |
}
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
# create dates Python list
|
| 209 |
chips_df["dates_list"] = chips_df["dates"].apply(ast.literal_eval)
|
| 210 |
|
| 211 |
+
# set prefix
|
| 212 |
+
s3_url="https://gfm-bench.s3.amazonaws.com/thumbnails"
|
| 213 |
+
|
| 214 |
+
# create thumb_urls column
|
| 215 |
+
chips_df["thumb_urls"] = chips_df.apply(
|
| 216 |
+
lambda r: [
|
| 217 |
+
f"{s3_url}/s2_{r.chip_id:06}_{date}.png"
|
| 218 |
+
for date in r.dates_list
|
| 219 |
+
],
|
| 220 |
+
axis=1
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
# build a list of points dictionary
|
| 224 |
points = (
|
| 225 |
chips_df
|
|
|
|
| 227 |
"cls_dim1": "x",
|
| 228 |
"cls_dim2": "y",
|
| 229 |
"Land Cover": "category"
|
| 230 |
+
})[["x","y","chip_id", "latitude", "longitude","category","dates_list"]]
|
| 231 |
.assign(
|
| 232 |
id = chips_df["chip_id"],
|
| 233 |
lat = chips_df["latitude"],
|
| 234 |
lon = chips_df["longitude"],
|
| 235 |
+
color=chips_df["Land Cover"].map(color_dict_label),
|
| 236 |
+
thumbs = chips_df["thumb_urls"])
|
| 237 |
.to_dict(orient="records")
|
| 238 |
)
|
| 239 |
|
|
|
|
| 294 |
customdata:pts.map(p=>[
|
| 295 |
p.id,
|
| 296 |
p.lat,
|
| 297 |
+
p.lon
|
|
|
|
| 298 |
]),
|
| 299 |
mode: 'markers',
|
| 300 |
type: 'scatter',
|
|
|
|
| 351 |
gd.on('plotly_click', evt => {{
|
| 352 |
// grab thumbs and dates through point index
|
| 353 |
const idx = evt.points[0].pointIndex;
|
| 354 |
+
const thumbs = points[idx].thumbs;
|
|
|
|
| 355 |
const dates = points[idx].dates_list;
|
| 356 |
// grab image container and clear out old thumbs
|
| 357 |
const container = document.getElementById('image-container');
|