Spaces:
Sleeping
Sleeping
artelabsuper
commited on
Commit
·
99ee6d2
1
Parent(s):
7d56262
cached models improve speed
Browse files
app.py
CHANGED
|
@@ -8,38 +8,41 @@ from matplotlib import colors
|
|
| 8 |
|
| 9 |
if not hasattr(st, 'paths'):
|
| 10 |
st.paths = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Load Model
|
| 13 |
# @title Load pretrained weights
|
| 14 |
|
| 15 |
-
|
| 16 |
-
best_model_annual_file_name = "best_model_annual.pth"
|
| 17 |
-
|
| 18 |
-
first_input_batch = torch.zeros(71, 9, 5, 48, 48)
|
| 19 |
-
# first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
|
| 20 |
-
daily_model = FPN(opt, first_input_batch, opt.win_size)
|
| 21 |
-
annual_model = SimpleNN(opt)
|
| 22 |
-
|
| 23 |
-
if torch.cuda.is_available():
|
| 24 |
-
daily_model = torch.nn.DataParallel(daily_model).cuda()
|
| 25 |
-
annual_model = torch.nn.DataParallel(annual_model).cuda()
|
| 26 |
-
daily_model = torch.nn.DataParallel(daily_model).cuda()
|
| 27 |
-
annual_model = torch.nn.DataParallel(annual_model).cuda()
|
| 28 |
-
else:
|
| 29 |
-
daily_model = torch.nn.DataParallel(daily_model).cpu()
|
| 30 |
-
annual_model = torch.nn.DataParallel(annual_model).cpu()
|
| 31 |
-
daily_model = torch.nn.DataParallel(daily_model).cpu()
|
| 32 |
-
annual_model = torch.nn.DataParallel(annual_model).cpu()
|
| 33 |
-
|
| 34 |
-
print('trying to resume previous saved models...')
|
| 35 |
-
state = resume(
|
| 36 |
-
os.path.join(opt.resume_path, best_model_daily_file_name),
|
| 37 |
-
model=daily_model, optimizer=None)
|
| 38 |
-
state = resume(
|
| 39 |
-
os.path.join(opt.resume_path, best_model_annual_file_name),
|
| 40 |
-
model=annual_model, optimizer=None)
|
| 41 |
-
daily_model = daily_model.eval()
|
| 42 |
-
annual_model = annual_model.eval()
|
| 43 |
|
| 44 |
st.title('Lombardia Sentinel 2 daily Crop Mapping')
|
| 45 |
st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
|
|
@@ -85,14 +88,14 @@ if sample_path is not None:
|
|
| 85 |
if torch.cuda.is_available():
|
| 86 |
x_dailies = x_dailies.cuda()
|
| 87 |
|
| 88 |
-
feat_daily, outs_daily = daily_model.forward(x_dailies)
|
| 89 |
# return to original size of batch and year
|
| 90 |
outs_daily = outs_daily.view(
|
| 91 |
opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
|
| 92 |
feat_daily = feat_daily.view(
|
| 93 |
opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
|
| 94 |
|
| 95 |
-
_, out_annual = annual_model.forward(feat_daily)
|
| 96 |
pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
|
| 97 |
pred_annual = pred_annual.cpu().numpy()
|
| 98 |
# Remapping the labels
|
|
@@ -158,7 +161,7 @@ if st.paths is not None:
|
|
| 158 |
st.paths, index=st.paths.index('patch-pred-nn.tif'))
|
| 159 |
|
| 160 |
file_path = os.path.join(folder, file_picker)
|
| 161 |
-
print(file_path)
|
| 162 |
target, profile = read(file_path)
|
| 163 |
target = np.squeeze(target)
|
| 164 |
target = [classes_color_map[p] for p in target]
|
|
@@ -169,7 +172,7 @@ if st.paths is not None:
|
|
| 169 |
|
| 170 |
markdown_legend = ''
|
| 171 |
for c, l in zip(classes_color_map, labels_map):
|
| 172 |
-
print(colors.to_hex(c))
|
| 173 |
markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
|
| 174 |
|
| 175 |
col1, col2 = st.columns(2)
|
|
|
|
| 8 |
|
| 9 |
if not hasattr(st, 'paths'):
|
| 10 |
st.paths = None
|
| 11 |
+
if not hasattr(st, 'daily_model'):
|
| 12 |
+
best_model_daily_file_name = "best_model_daily.pth"
|
| 13 |
+
best_model_annual_file_name = "best_model_annual.pth"
|
| 14 |
+
|
| 15 |
+
first_input_batch = torch.zeros(71, 9, 5, 48, 48)
|
| 16 |
+
# first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
|
| 17 |
+
st.daily_model = FPN(opt, first_input_batch, opt.win_size)
|
| 18 |
+
st.annual_model = SimpleNN(opt)
|
| 19 |
+
|
| 20 |
+
if torch.cuda.is_available():
|
| 21 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
|
| 22 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
|
| 23 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
|
| 24 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
|
| 25 |
+
else:
|
| 26 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
|
| 27 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
|
| 28 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
|
| 29 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
|
| 30 |
+
|
| 31 |
+
print('trying to resume previous saved models...')
|
| 32 |
+
state = resume(
|
| 33 |
+
os.path.join(opt.resume_path, best_model_daily_file_name),
|
| 34 |
+
model=st.daily_model, optimizer=None)
|
| 35 |
+
state = resume(
|
| 36 |
+
os.path.join(opt.resume_path, best_model_annual_file_name),
|
| 37 |
+
model=st.annual_model, optimizer=None)
|
| 38 |
+
st.daily_model = st.daily_model.eval()
|
| 39 |
+
st.annual_model = st.annual_model.eval()
|
| 40 |
+
|
| 41 |
|
| 42 |
# Load Model
|
| 43 |
# @title Load pretrained weights
|
| 44 |
|
| 45 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
st.title('Lombardia Sentinel 2 daily Crop Mapping')
|
| 48 |
st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
|
|
|
|
| 88 |
if torch.cuda.is_available():
|
| 89 |
x_dailies = x_dailies.cuda()
|
| 90 |
|
| 91 |
+
feat_daily, outs_daily = st.daily_model.forward(x_dailies)
|
| 92 |
# return to original size of batch and year
|
| 93 |
outs_daily = outs_daily.view(
|
| 94 |
opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
|
| 95 |
feat_daily = feat_daily.view(
|
| 96 |
opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
|
| 97 |
|
| 98 |
+
_, out_annual = st.annual_model.forward(feat_daily)
|
| 99 |
pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
|
| 100 |
pred_annual = pred_annual.cpu().numpy()
|
| 101 |
# Remapping the labels
|
|
|
|
| 161 |
st.paths, index=st.paths.index('patch-pred-nn.tif'))
|
| 162 |
|
| 163 |
file_path = os.path.join(folder, file_picker)
|
| 164 |
+
# print(file_path)
|
| 165 |
target, profile = read(file_path)
|
| 166 |
target = np.squeeze(target)
|
| 167 |
target = [classes_color_map[p] for p in target]
|
|
|
|
| 172 |
|
| 173 |
markdown_legend = ''
|
| 174 |
for c, l in zip(classes_color_map, labels_map):
|
| 175 |
+
# print(colors.to_hex(c))
|
| 176 |
markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
|
| 177 |
|
| 178 |
col1, col2 = st.columns(2)
|