Spaces:
Running
Running
fix
Browse files- app.py +12 -5
- js/interactive_grid.js +10 -2
app.py
CHANGED
|
@@ -38,6 +38,9 @@ def parse_bool_string(s):
|
|
| 38 |
|
| 39 |
def get_intervention_vector(selected_cells_bef, selected_cells_aft):
|
| 40 |
|
|
|
|
|
|
|
|
|
|
| 41 |
left_map = np.zeros((1, 14 * 14 + 1))
|
| 42 |
right_map = np.zeros((1, 14 * 14 + 1))
|
| 43 |
|
|
@@ -47,12 +50,14 @@ def get_intervention_vector(selected_cells_bef, selected_cells_aft):
|
|
| 47 |
|
| 48 |
if np.count_nonzero(selected_cells_bef) == 0:
|
| 49 |
left_map[0, 0] = 1.0
|
|
|
|
| 50 |
|
| 51 |
if np.count_nonzero(selected_cells_aft) == 0:
|
| 52 |
right_map[0, 0] = 1.0
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
-
return left_map, right_map
|
| 56 |
|
| 57 |
def _get_rawimage(image_path):
|
| 58 |
# Pair x L x T x 3 x H x W
|
|
@@ -103,7 +108,7 @@ def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
|
|
| 103 |
selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
|
| 104 |
selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
|
| 105 |
|
| 106 |
-
left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
|
| 107 |
|
| 108 |
left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
|
| 109 |
|
|
@@ -130,8 +135,10 @@ def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
|
|
| 130 |
pred = f"{decode_text}"
|
| 131 |
|
| 132 |
# Include information about selected cells
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
| 135 |
|
| 136 |
return pred, selected_info_bef, selected_info_aft
|
| 137 |
|
|
@@ -289,7 +296,7 @@ with gr.Blocks() as demo:
|
|
| 289 |
fn=predict_image,
|
| 290 |
inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
|
| 291 |
outputs=[prediction, selected_info_bef, selected_info_aft],
|
| 292 |
-
_js="(sel_attn_bef, sel_attn_aft) => { return [
|
| 293 |
)
|
| 294 |
|
| 295 |
image_bef.change(
|
|
|
|
| 38 |
|
| 39 |
def get_intervention_vector(selected_cells_bef, selected_cells_aft):
|
| 40 |
|
| 41 |
+
first_ = True
|
| 42 |
+
second_ = True
|
| 43 |
+
|
| 44 |
left_map = np.zeros((1, 14 * 14 + 1))
|
| 45 |
right_map = np.zeros((1, 14 * 14 + 1))
|
| 46 |
|
|
|
|
| 50 |
|
| 51 |
if np.count_nonzero(selected_cells_bef) == 0:
|
| 52 |
left_map[0, 0] = 1.0
|
| 53 |
+
first_ = False
|
| 54 |
|
| 55 |
if np.count_nonzero(selected_cells_aft) == 0:
|
| 56 |
right_map[0, 0] = 1.0
|
| 57 |
+
second_ = False
|
| 58 |
|
| 59 |
|
| 60 |
+
return left_map, right_map, first_, second_
|
| 61 |
|
| 62 |
def _get_rawimage(image_path):
|
| 63 |
# Pair x L x T x 3 x H x W
|
|
|
|
| 108 |
selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
|
| 109 |
selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
|
| 110 |
|
| 111 |
+
left_map, right_map, first_, second_ = get_intervention_vector(selected_cells_bef, selected_cells_aft)
|
| 112 |
|
| 113 |
left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
|
| 114 |
|
|
|
|
| 135 |
pred = f"{decode_text}"
|
| 136 |
|
| 137 |
# Include information about selected cells
|
| 138 |
+
i, j = np.nonzero(selected_cells_bef)
|
| 139 |
+
selected_info_bef = f"{list(zip(i, j))}" if first_ else "No image patch was selected"
|
| 140 |
+
i, j = np.nonzero(selected_cells_aft)
|
| 141 |
+
selected_info_aft = f"{list(zip(i, j))}" if second_ else "No image patch was selected"
|
| 142 |
|
| 143 |
return pred, selected_info_bef, selected_info_aft
|
| 144 |
|
|
|
|
| 296 |
fn=predict_image,
|
| 297 |
inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
|
| 298 |
outputs=[prediction, selected_info_bef, selected_info_aft],
|
| 299 |
+
_js="(image_bef, image_aft, sel_attn_bef, sel_attn_aft) => { return [image_bef, image_aft, read_js_Data_bef(), read_js_Data_aft()]; }"
|
| 300 |
)
|
| 301 |
|
| 302 |
image_bef.change(
|
js/interactive_grid.js
CHANGED
|
@@ -298,15 +298,23 @@ function importBackgroundAfter(image_after) {
|
|
| 298 |
}
|
| 299 |
}
|
| 300 |
|
| 301 |
-
function
|
| 302 |
console.log("read_js_Data");
|
| 303 |
console.log("read_js_Data");
|
| 304 |
console.log("read_js_Data");
|
| 305 |
console.log("read_js_Data");
|
| 306 |
console.log("read_js_Data");
|
| 307 |
-
return grid_bef
|
| 308 |
}
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
function set_grid_from_data(data) {
|
| 312 |
if (data.length !== gridSize || data[0].length !== gridSize) {
|
|
|
|
| 298 |
}
|
| 299 |
}
|
| 300 |
|
| 301 |
+
function read_js_Data_bef() {
|
| 302 |
console.log("read_js_Data");
|
| 303 |
console.log("read_js_Data");
|
| 304 |
console.log("read_js_Data");
|
| 305 |
console.log("read_js_Data");
|
| 306 |
console.log("read_js_Data");
|
| 307 |
+
return grid_bef;
|
| 308 |
}
|
| 309 |
|
| 310 |
+
function read_js_Data_aft() {
|
| 311 |
+
console.log("read_js_Data");
|
| 312 |
+
console.log("read_js_Data");
|
| 313 |
+
console.log("read_js_Data");
|
| 314 |
+
console.log("read_js_Data");
|
| 315 |
+
console.log("read_js_Data");
|
| 316 |
+
return grid_aft;
|
| 317 |
+
}
|
| 318 |
|
| 319 |
function set_grid_from_data(data) {
|
| 320 |
if (data.length !== gridSize || data[0].length !== gridSize) {
|