Yao-Ting Yao commited on
Commit
8254c8e
·
1 Parent(s): e14621a

Create thumbnail_process.py

Browse files
Files changed (1) hide show
  1. src/thumbnail_process.py +116 -0
src/thumbnail_process.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from PIL import Image
4
+ import ast
5
+ import s3fs
6
+ from rasterio.io import MemoryFile
7
+ import os
8
+
9
+ # read csv
10
+ chips_df = pd.read_csv("../data/embeddings_df_v0.11_test.csv")
11
+
12
+ # set anonymous S3FileSystem to read files from public bucket
13
+ s3 = s3fs.S3FileSystem(anon=True)
14
+
15
+ ## helper function
16
+ def gen_chip_urls(row, s3_prefix):
17
+ '''
18
+ Generate S3 urls for chips
19
+ :param row: dictionary with chip_id and dates
20
+ :param s3_prefix: S3 url prefix
21
+ :return s3_urls: a list of urls
22
+ '''
23
+ s3_urls = []
24
+ dates = ast.literal_eval(row["dates"])
25
+ for date in dates:
26
+ filename = f"s2_{row['chip_id']:06}_{date}.tif"
27
+ s3_url = f"{s3_prefix}/{filename}"
28
+ s3_urls.append(s3_url)
29
+ return s3_urls
30
+
31
+ def mask_nodata(band, nodata_values=(-999,)):
32
+ '''
33
+ Mask nodata to nan
34
+ :param band
35
+ :param nodata_values:nodata values in chips is -999
36
+ :return band
37
+ '''
38
+ band = band.astype(float)
39
+ for val in nodata_values:
40
+ band[band == val] = np.nan
41
+ return band
42
+
43
+ def normalize(band):
44
+ '''
45
+ Normalize a band to 0-1 range(float)
46
+ :param band (ndarray)
47
+ return normalize band (ndarray); when max equals min, returns zeros.
48
+ '''
49
+ if np.nanmean(band) >= 4000:
50
+ band = band / 6000
51
+ else:
52
+ band = band / 4000
53
+ band = np.clip(band, None, 1)
54
+
55
+ return band
56
+
57
+ def create_thumbnail(url, output_dir):
58
+ '''
59
+ Read S3 file into memory, create and save a resized png thumbnail.
60
+ :param url: S3 file URL
61
+ :param output_dir: directory to save thumbnails
62
+ :return: saved file path (str) or "" if failed
63
+ '''
64
+ try:
65
+ os.makedirs(output_dir, exist_ok=True)
66
+
67
+ # read raw bytes from s3 file
68
+ with s3.open(url, "rb") as f:
69
+ data = f.read()
70
+
71
+ # wrap the raw bytes into an memory file
72
+ with MemoryFile(data) as memfile:
73
+
74
+ # read memory file with rasterio
75
+ with memfile.open() as src:
76
+ # mask nodata to have correct calculate normalization
77
+ # band1->blue, band2->green, band3->red
78
+
79
+ blue = src.read(1).astype(float)
80
+ green = src.read(2).astype(float)
81
+ red = src.read(3).astype(float)
82
+
83
+ blue = normalize(mask_nodata(blue))
84
+ green = normalize(mask_nodata(green))
85
+ red = normalize(mask_nodata(red))
86
+
87
+ # stack in RGB
88
+ rgb = np.dstack((red, green, blue))
89
+
90
+ # convert float(0-1) to uint8 (0-255)
91
+ rgb_8bit = (rgb * 255).astype(np.uint8)
92
+
93
+ # convert to png in memory
94
+ pil_img = Image.fromarray(rgb_8bit)
95
+
96
+ # save png to local
97
+ filename = os.path.basename(url).replace(".tif", ".png")
98
+ file_path = os.path.join(output_dir, filename)
99
+ pil_img.save(file_path, format="PNG")
100
+
101
+ return file_path
102
+
103
+ except Exception as e:
104
+ # return an empty string for Exception
105
+ return ""
106
+
107
+ # set prefix
108
+ s3_prefix="s3://gfm-bench"
109
+
110
+ # generate S3 file URLs
111
+ chips_df["urls"] = chips_df.apply(lambda row: gen_chip_urls(row, s3_prefix), axis=1)
112
+
113
+ # create thumbnail
114
+ chips_df["thumbs"] = chips_df["urls"].apply(
115
+ lambda urls: [create_thumbnail(p, output_dir="../data/thumbnails") for p in urls]
116
+ )