Spaces:
Build error
Build error
Jakub Kwiatkowski
commited on
Commit
·
e986ee1
1
Parent(s):
9396266
Add model.
Browse files- main.py +49 -0
- models.py +13 -0
- raven_utils/__init__.py +10 -0
- raven_utils/config/__init__.py +0 -0
- raven_utils/config/constant.py +54 -0
- raven_utils/config/models.py +9 -0
- raven_utils/const.py +2 -0
- raven_utils/constant.py +53 -0
- raven_utils/data.py +46 -0
- raven_utils/decode.py +100 -0
- raven_utils/depricated/__init__.py +0 -0
- raven_utils/depricated/old_raven.py +490 -0
- raven_utils/draw.py +174 -0
- raven_utils/entity.py +6 -0
- raven_utils/group.py +11 -0
- raven_utils/inference.py +15 -0
- raven_utils/models/__init__.py +0 -0
- raven_utils/models/attn.py +187 -0
- raven_utils/models/attn2.py +187 -0
- raven_utils/models/augment.py +0 -0
- raven_utils/models/body.py +276 -0
- raven_utils/models/class_.py +31 -0
- raven_utils/models/head.py +159 -0
- raven_utils/models/loss.py +630 -0
- raven_utils/models/loss_3.py +638 -0
- raven_utils/models/multi_transformer.py +274 -0
- raven_utils/models/raven.py +239 -0
- raven_utils/models/trans.py +74 -0
- raven_utils/models/transformer.py +133 -0
- raven_utils/models/transformer_2.py +146 -0
- raven_utils/models/transformer_3.py +206 -0
- raven_utils/models/uitls_.py +16 -0
- raven_utils/output.py +16 -0
- raven_utils/params.py +110 -0
- raven_utils/properties.py +16 -0
- raven_utils/range_mask.py +16 -0
- raven_utils/render/__init__.py +0 -0
- raven_utils/render/const.py +86 -0
- raven_utils/render/rendering.py +304 -0
- raven_utils/render_.py +104 -0
- raven_utils/rules.py +21 -0
- raven_utils/target.py +50 -0
- raven_utils/uitls.py +64 -0
- saved_model/1/keras_metadata.pb +3 -0
- saved_model/1/saved_model.pb +3 -0
- saved_model/1/variables/variables.data-00000-of-00001 +3 -0
- saved_model/1/variables/variables.index +3 -0
- utils.py +84 -0
main.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from utils import load_example, run_nn, load_model_, next_, prev_
|
| 4 |
+
|
| 5 |
+
demo = gr.Blocks()
|
| 6 |
+
import models
|
| 7 |
+
|
| 8 |
+
with demo:
|
| 9 |
+
headline = gr.Markdown("## Raven resolver ")
|
| 10 |
+
markdown = gr.Markdown("Below we show all 9 images from raven matrix. "
|
| 11 |
+
"Model gets 8 images and predicts the properties of last one. "
|
| 12 |
+
"Based on this properties the answer image is render in the right panel. <br />"
|
| 13 |
+
"Note that angle rotation is only used as a noise. "
|
| 14 |
+
"There are not rules applied to angle property, so angle rotation of final output do not need to be the same as in example. "
|
| 15 |
+
"Additionally there are cases that other properties could be used as noise.")
|
| 16 |
+
with gr.Row():
|
| 17 |
+
with gr.Column():
|
| 18 |
+
with gr.Row():
|
| 19 |
+
text = gr.Textbox(models.START_IMAGE,
|
| 20 |
+
label="Write the example number from validation dataset (0, 14,000). You can also paste here matrix representation from generator.")
|
| 21 |
+
with gr.Row():
|
| 22 |
+
prev = gr.Button("Prev")
|
| 23 |
+
show = gr.Button("Show")
|
| 24 |
+
next = gr.Button("Next")
|
| 25 |
+
# button = gr.Button("Run")
|
| 26 |
+
with gr.Row():
|
| 27 |
+
image = gr.Image(value=load_example(models.START_IMAGE)[0], label="Raven matrix")
|
| 28 |
+
desc = gr.Markdown(value=load_example(models.START_IMAGE)[1])
|
| 29 |
+
|
| 30 |
+
with gr.Column():
|
| 31 |
+
with gr.Row():
|
| 32 |
+
output = gr.Image(label="Generated image", shape=(200, 200))
|
| 33 |
+
with gr.Row():
|
| 34 |
+
button = gr.Button("Run")
|
| 35 |
+
|
| 36 |
+
# text.change(load_example, inputs=text, outputs=[image, desc])
|
| 37 |
+
show.click(load_example, inputs=text, outputs=[image, desc])
|
| 38 |
+
# button.click(run_nn, inputs=image, outputs=output)
|
| 39 |
+
button.click(run_nn, inputs=text, outputs=output)
|
| 40 |
+
|
| 41 |
+
# next.click(next_, inputs=text, outputs=text)
|
| 42 |
+
# next.click(load_example, inputs=text, outputs=[image, desc])
|
| 43 |
+
next.click(next_, inputs=text, outputs=[text, image, desc])
|
| 44 |
+
|
| 45 |
+
# prev.click(prev_, inputs=text, outputs=text)
|
| 46 |
+
# prev.click(load_example, inputs=text, outputs=[image, desc])
|
| 47 |
+
prev.click(prev_, inputs=text, outputs=[text, image, desc])
|
| 48 |
+
|
| 49 |
+
demo.launch(debug=True)
|
models.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
START_IMAGE = 12000
|
| 4 |
+
|
| 5 |
+
from tensorflow.keras.models import load_model
|
| 6 |
+
model = load_model("saved_model/1")
|
| 7 |
+
|
| 8 |
+
from data_utils import nload, ims, DataSetFromFolder
|
| 9 |
+
data = nload("/home/jkwiatkowski/all/dataset/arr/val.npy")
|
| 10 |
+
indexes = nload("/home/jkwiatkowski/all/dataset/arr/val_target.npy")
|
| 11 |
+
|
| 12 |
+
folders = DataSetFromFolder("/home/jkwiatkowski/all/dataset/arr/RAVEN-10000-release/RAVEN-10000", file_type="dir")
|
| 13 |
+
properties = DataSetFromFolder(folders[:], file_type="xml", extension="val")
|
raven_utils/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import raven_utils.group as group
|
| 2 |
+
import raven_utils.entity as entity
|
| 3 |
+
import raven_utils.properties as properties
|
| 4 |
+
import raven_utils.target as target
|
| 5 |
+
import raven_utils.rules as rules
|
| 6 |
+
import raven_utils.output as output
|
| 7 |
+
import raven_utils.inference as inference
|
| 8 |
+
import raven_utils.decode as decode
|
| 9 |
+
import raven_utils.render_ as render_
|
| 10 |
+
import raven_utils.draw as draw
|
raven_utils/config/__init__.py
ADDED
|
File without changes
|
raven_utils/config/constant.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
RAVEN = "arr"
|
| 2 |
+
RAVEN_BIG = "arrb"
|
| 3 |
+
INDEX = "index"
|
| 4 |
+
LABELS = "labels"
|
| 5 |
+
TARGET_LABELS = "target_labels"
|
| 6 |
+
FEATURES = "features"
|
| 7 |
+
ACC_SAME = "acc_same"
|
| 8 |
+
ACC_CHOOSE_LOWER = "acc_choose_lower"
|
| 9 |
+
ACC_CHOOSE_UPPER = "acc_choose_upper"
|
| 10 |
+
ACC_NO_GROUP = "acc_NO_group"
|
| 11 |
+
CLASSIFICATION = "classification"
|
| 12 |
+
INFERENCE = "inference"
|
| 13 |
+
# PROPERTIES = "properties"
|
| 14 |
+
PROPERTY = "property"
|
| 15 |
+
MEMORY = "memory"
|
| 16 |
+
CONTROL = "control"
|
| 17 |
+
LATENT = "latent"
|
| 18 |
+
TARGET = "target"
|
| 19 |
+
INPUTS = "inputs"
|
| 20 |
+
RES = "res"
|
| 21 |
+
RESULT = "result"
|
| 22 |
+
MERGE = "merge"
|
| 23 |
+
MEMORY_STATE = "memory_state"
|
| 24 |
+
CONTROL_STATE = "control_state"
|
| 25 |
+
CONCAT = "concat"
|
| 26 |
+
FLATTEN = "flatten"
|
| 27 |
+
CROSS_ENTROPY = "cross_entropy"
|
| 28 |
+
SLOT = "slot"
|
| 29 |
+
PROPERTIES = "properties"
|
| 30 |
+
ACC = "acc"
|
| 31 |
+
GROUP = 'group'
|
| 32 |
+
NUMBER = 'number'
|
| 33 |
+
TRANS = 'trans'
|
| 34 |
+
TAIL = "tail"
|
| 35 |
+
MASK = "mask"
|
| 36 |
+
|
| 37 |
+
RAV_METRICS = [
|
| 38 |
+
ACC_NO_GROUP,
|
| 39 |
+
ACC_SAME,
|
| 40 |
+
ACC_CHOOSE_UPPER,
|
| 41 |
+
ACC_CHOOSE_LOWER,
|
| 42 |
+
"acc",
|
| 43 |
+
"c_acc_NO_group",
|
| 44 |
+
"c_acc",
|
| 45 |
+
"loss",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
IMP_RAV_METRICS = [
|
| 49 |
+
ACC_NO_GROUP,
|
| 50 |
+
ACC_SAME,
|
| 51 |
+
ACC_CHOOSE_UPPER,
|
| 52 |
+
ACC_CHOOSE_LOWER,
|
| 53 |
+
ACC,
|
| 54 |
+
]
|
raven_utils/config/models.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
AVAILABLE_MODELS = [
|
| 3 |
+
"197-0.31",
|
| 4 |
+
"53-0.48",
|
| 5 |
+
"74-0.50",
|
| 6 |
+
"21-0.48",
|
| 7 |
+
"10-0.52",
|
| 8 |
+
"179-0.50"
|
| 9 |
+
]
|
raven_utils/const.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
HORIZONTAL = "horizontal"
|
| 2 |
+
VERTICAL = "vertical"
|
raven_utils/constant.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
RAVEN = "arr"
|
| 2 |
+
RAVEN_BIG = "arrb"
|
| 3 |
+
INDEX = "index"
|
| 4 |
+
LABELS = "labels"
|
| 5 |
+
TARGET_LABELS = "target_labels"
|
| 6 |
+
FEATURES = "features"
|
| 7 |
+
ACC_SAME = "acc_same"
|
| 8 |
+
ACC_CHOOSE_LOWER = "acc_choose_lower"
|
| 9 |
+
ACC_CHOOSE_UPPER = "acc_choose_upper"
|
| 10 |
+
ACC_NO_GROUP = "acc_NO_group"
|
| 11 |
+
CLASSIFICATION = "classification"
|
| 12 |
+
INFERENCE = "inference"
|
| 13 |
+
# PROPERTIES = "properties"
|
| 14 |
+
PROPERTY = "property"
|
| 15 |
+
MEMORY = "memory"
|
| 16 |
+
CONTROL = "control"
|
| 17 |
+
LATENT = "latent"
|
| 18 |
+
TARGET = "target"
|
| 19 |
+
INPUTS = "inputs"
|
| 20 |
+
RES = "res"
|
| 21 |
+
RESULT = "result"
|
| 22 |
+
MERGE = "merge"
|
| 23 |
+
MEMORY_STATE = "memory_state"
|
| 24 |
+
CONTROL_STATE = "control_state"
|
| 25 |
+
CONCAT = "concat"
|
| 26 |
+
FLATTEN = "flatten"
|
| 27 |
+
CROSS_ENTROPY = "cross_entropy"
|
| 28 |
+
SLOT = "slot"
|
| 29 |
+
PROPERTIES = "properties"
|
| 30 |
+
ACC = "acc"
|
| 31 |
+
GROUP = 'group'
|
| 32 |
+
NUMBER = 'number'
|
| 33 |
+
TRANS = 'trans'
|
| 34 |
+
TAIL = "tail"
|
| 35 |
+
MASK = "mask"
|
| 36 |
+
|
| 37 |
+
RAV_METRICS = [
|
| 38 |
+
ACC_NO_GROUP,
|
| 39 |
+
ACC_SAME,
|
| 40 |
+
ACC_CHOOSE_UPPER,
|
| 41 |
+
ACC_CHOOSE_LOWER,
|
| 42 |
+
"acc",
|
| 43 |
+
"c_acc_NO_group",
|
| 44 |
+
"c_acc",
|
| 45 |
+
"loss",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
IMP_RAV_METRICS = [
|
| 49 |
+
ACC_NO_GROUP,
|
| 50 |
+
ACC_SAME,
|
| 51 |
+
ACC_CHOOSE_UPPER,
|
| 52 |
+
ACC_CHOOSE_LOWER,
|
| 53 |
+
]
|
raven_utils/data.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
from models_utils import INPUTS, TARGET
|
| 6 |
+
|
| 7 |
+
from raven_utils.config.constant import RAVEN, LABELS, INDEX, FEATURES, RAV_METRICS, IMP_RAV_METRICS, ACC_NO_GROUP
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
from data_utils import pre, Data, gather, vec, resize
|
| 13 |
+
from data_utils.data_generator import DataGenerator
|
| 14 |
+
from funcy import identity
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_data(data, batch_size, steps=None, val_steps=None):
|
| 18 |
+
if val_steps is None:
|
| 19 |
+
val_steps = steps
|
| 20 |
+
fn = identity
|
| 21 |
+
train_target_index = data[4] + 8
|
| 22 |
+
train_generator = DataGenerator({
|
| 23 |
+
INPUTS: Data(data[0], fn),
|
| 24 |
+
TARGET: Data(data[2], identity),
|
| 25 |
+
LABELS: Data(data[2], identity),
|
| 26 |
+
INDEX: train_target_index[:, None],
|
| 27 |
+
# FEATURES: data[6]
|
| 28 |
+
},
|
| 29 |
+
batch=batch_size,
|
| 30 |
+
steps=steps
|
| 31 |
+
)
|
| 32 |
+
val_target_index = data[5] + 8
|
| 33 |
+
val_data = {
|
| 34 |
+
INPUTS: Data(data[1], fn),
|
| 35 |
+
TARGET: Data(data[3], identity),
|
| 36 |
+
LABELS: Data(data[3], identity),
|
| 37 |
+
INDEX: val_target_index[:, None],
|
| 38 |
+
# FEATURES: data[7]
|
| 39 |
+
}
|
| 40 |
+
val_generator = DataGenerator(
|
| 41 |
+
val_data,
|
| 42 |
+
batch=batch_size,
|
| 43 |
+
sampler="val",
|
| 44 |
+
steps=val_steps
|
| 45 |
+
)
|
| 46 |
+
return train_generator, val_generator
|
raven_utils/decode.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from data_utils import np_split
|
| 3 |
+
from ml_utils import lw
|
| 4 |
+
from models_utils.ops import ibin
|
| 5 |
+
|
| 6 |
+
import raven_utils as rv
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def output(x, split_fn=np_split, predict_fn_1=np.argmax, predict_fn_2=ibin):
|
| 10 |
+
res = output_divide(x, split_fn=split_fn)
|
| 11 |
+
res = output_predict(res, predict_fn_1=predict_fn_1, predict_fn_2=predict_fn_2)
|
| 12 |
+
return (res[0], res[1]) + tuple(output_properties(res[2], predict_fn=predict_fn_1))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def output_divide(output, split_fn=np_split):
|
| 16 |
+
group_output = output[..., rv.output.GROUP_SLICE_END]
|
| 17 |
+
slot_output = output[..., rv.output.SLOT_SLICE_END]
|
| 18 |
+
properties_output = output[..., rv.output.PROPERTIES_SLICE_END]
|
| 19 |
+
properties_output_splited = split_fn(properties_output, list(rv.properties.INDEX.values()), axis=-1)
|
| 20 |
+
return group_output, slot_output, properties_output_splited
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def output_predict(output, predict_fn_1=np.argmax, predict_fn_2=ibin):
|
| 24 |
+
return predict_fn_1(output[0]), predict_fn_2(output[1]), output[2]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def output_properties(x, predict_fn=np.argmax):
|
| 28 |
+
out_reshaped = []
|
| 29 |
+
for i, out in enumerate(x):
|
| 30 |
+
shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
|
| 31 |
+
out_reshaped.append(predict_fn(out.reshape(shape)))
|
| 32 |
+
return out_reshaped
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def output_result(output, split_fn=np_split, arg_max=np.argmax):
|
| 36 |
+
result = output_properties(output, predict_fn=split_fn)
|
| 37 |
+
res = []
|
| 38 |
+
for i, r in enumerate(result):
|
| 39 |
+
if i == 1:
|
| 40 |
+
res.append(r)
|
| 41 |
+
else:
|
| 42 |
+
res.append(arg_max(r, axis=-1))
|
| 43 |
+
return tuple(res)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def decode_inference(inference, reshape=np.reshape):
|
| 47 |
+
return reshape(inference[rv.inference.SLOT_SLICE],
|
| 48 |
+
[-1, rv.group.NO, rv.inference.PROPERTY_TRANSFORMATION_NO]), reshape(
|
| 49 |
+
inference[rv.inference.PROPERTIES_SLICE],
|
| 50 |
+
[-1, rv.properties.NO, rv.entity.SUM, rv.inference.PROPERTY_TRANSFORMATION_NO])
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def decode_target(target):
|
| 54 |
+
target_group = target[..., 0]
|
| 55 |
+
target_slot = target[..., 1:rv.target.INDEX[0]]
|
| 56 |
+
target_properties = target[..., rv.target.INDEX[0]:rv.target.END_INDEX]
|
| 57 |
+
target_properties_splited = [
|
| 58 |
+
target_properties[..., ::rv.properties.NO],
|
| 59 |
+
target_properties[..., 1::rv.properties.NO],
|
| 60 |
+
target_properties[..., 2::rv.properties.NO]
|
| 61 |
+
]
|
| 62 |
+
return target_group, target_slot, target_properties_splited
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def decode_target_flat(target):
|
| 68 |
+
t = decode_target(target)
|
| 69 |
+
return t[0], t[1], t[2][0], t[2][1], t[2][2]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def demask(target, mask=None, group=None, zeroes=None):
|
| 73 |
+
if mask is None:
|
| 74 |
+
if group is None:
|
| 75 |
+
group = target[0]
|
| 76 |
+
# todo Use numpy range Mask
|
| 77 |
+
from models.uitls_ import RangeMask
|
| 78 |
+
mask = RangeMask()(group).numpy()
|
| 79 |
+
if zeroes is None:
|
| 80 |
+
return np.concatenate([t[mask] for t in lw(target[1:])])
|
| 81 |
+
return np.concatenate([target[0][None]] + [t * mask for t in lw(target[1:])],axis=-1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def target_mask(mask,right=1):
|
| 85 |
+
shape = mask.shape
|
| 86 |
+
return np.concatenate([np.ones([shape[0], 1]) ,mask, np.repeat(mask,3,axis=1), np.ones([shape[0], right])],axis=1)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_full_range_mask(mask):
|
| 90 |
+
return np.concatenate([mask, np.repeat(mask, 3, axis=-1)], axis=-1)
|
| 91 |
+
|
| 92 |
+
def compare(target, predict, mask):
|
| 93 |
+
target_comp = target[:, 1:rv.target.END_INDEX]
|
| 94 |
+
predict_comp = predict[:, 1:rv.target.END_INDEX]
|
| 95 |
+
|
| 96 |
+
mask = get_full_range_mask(mask)
|
| 97 |
+
|
| 98 |
+
target_masked = target_comp * mask
|
| 99 |
+
predict_masked = predict_comp * mask
|
| 100 |
+
return target_masked == predict_masked
|
raven_utils/depricated/__init__.py
ADDED
|
File without changes
|
raven_utils/depricated/old_raven.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from data_utils import take, EXIST, COR
|
| 5 |
+
from data_utils.image import draw_images, add_text
|
| 6 |
+
from data_utils.op import np_split
|
| 7 |
+
from ml_utils import lu, dict_from_list2, filter_keys, none
|
| 8 |
+
from data_utils import ops as K
|
| 9 |
+
|
| 10 |
+
from config.constant import PROPERTY, TARGET, INPUTS
|
| 11 |
+
# from raven_utils.render.rendering import render_panels
|
| 12 |
+
|
| 13 |
+
RENDER_POSITIONS = [
|
| 14 |
+
[(0.5, 0.5, 1, 1)],
|
| 15 |
+
# ...
|
| 16 |
+
[(0.25, 0.25, 0.5, 0.5),
|
| 17 |
+
(0.25, 0.75, 0.5, 0.5),
|
| 18 |
+
(0.75, 0.25, 0.5, 0.5),
|
| 19 |
+
(0.75, 0.75, 0.5, 0.5)],
|
| 20 |
+
# ...
|
| 21 |
+
[(0.16, 0.16, 0.33, 0.33),
|
| 22 |
+
(0.16, 0.5, 0.33, 0.33),
|
| 23 |
+
(0.16, 0.83, 0.33, 0.33),
|
| 24 |
+
(0.5, 0.16, 0.33, 0.33),
|
| 25 |
+
(0.5, 0.5, 0.33, 0.33),
|
| 26 |
+
(0.5, 0.83, 0.33, 0.33),
|
| 27 |
+
(0.83, 0.16, 0.33, 0.33),
|
| 28 |
+
(0.83, 0.5, 0.33, 0.33),
|
| 29 |
+
(0.83, 0.83, 0.33, 0.33)],
|
| 30 |
+
# ...
|
| 31 |
+
[(0.5, 0.25, 0.5, 0.5)],
|
| 32 |
+
[(0.5, 0.75, 0.5, 0.5)],
|
| 33 |
+
# ...
|
| 34 |
+
[(0.25, 0.5, 0.5, 0.5)],
|
| 35 |
+
[(0.75, 0.5, 0.5, 0.5)],
|
| 36 |
+
# ...
|
| 37 |
+
[(0.5, 0.5, 1, 1)],
|
| 38 |
+
[(0.5, 0.5, 0.33, 0.33)],
|
| 39 |
+
# ...
|
| 40 |
+
[(0.5, 0.5, 1, 1)],
|
| 41 |
+
[(0.42, 0.42, 0.15, 0.15),
|
| 42 |
+
(0.42, 0.58, 0.15, 0.15),
|
| 43 |
+
(0.58, 0.42, 0.15, 0.15),
|
| 44 |
+
(0.58, 0.58, 0.15, 0.15)],
|
| 45 |
+
# ...
|
| 46 |
+
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
HORIZONTAL = "horizontal"
|
| 50 |
+
VERTICAL = "vertical"
|
| 51 |
+
|
| 52 |
+
NAMES = ['center_single',
|
| 53 |
+
'distribute_four',
|
| 54 |
+
'distribute_nine',
|
| 55 |
+
'in_center_single_out_center_single',
|
| 56 |
+
'in_distribute_four_out_center_single',
|
| 57 |
+
'left_center_single_right_center_single',
|
| 58 |
+
'up_center_single_down_center_single']
|
| 59 |
+
|
| 60 |
+
PROPERTIES_NAMES = [
|
| 61 |
+
'Color',
|
| 62 |
+
'Size',
|
| 63 |
+
'Type',
|
| 64 |
+
|
| 65 |
+
]
|
| 66 |
+
PROPERTIES = dict_from_list2(PROPERTIES_NAMES, [10, 6, 5])
|
| 67 |
+
ANGLE_MAX = 7
|
| 68 |
+
|
| 69 |
+
PROPERTIES_NO = len(PROPERTIES)
|
| 70 |
+
|
| 71 |
+
RULES_COMBINE = "Number/Position"
|
| 72 |
+
|
| 73 |
+
RULES_ATTRIBUTES = [
|
| 74 |
+
"Number",
|
| 75 |
+
"Position",
|
| 76 |
+
"Color",
|
| 77 |
+
"Size",
|
| 78 |
+
"Type"
|
| 79 |
+
]
|
| 80 |
+
RULES_ATTRIBUTES_LEN = len(RULES_ATTRIBUTES)
|
| 81 |
+
|
| 82 |
+
RULES_ATTRIBUTES_INDEX = dict_from_list2(RULES_ATTRIBUTES)
|
| 83 |
+
|
| 84 |
+
RULES_TYPES = [
|
| 85 |
+
"Constant",
|
| 86 |
+
"Arithmetic",
|
| 87 |
+
"Progression",
|
| 88 |
+
"Distribute_Three"
|
| 89 |
+
]
|
| 90 |
+
RULES_TYPES_INDEX = dict_from_list2(RULES_TYPES)
|
| 91 |
+
RULES_TYPES_LEN = len(RULES_ATTRIBUTES)
|
| 92 |
+
|
| 93 |
+
GROUPS_NO = len(NAMES)
|
| 94 |
+
ENTITY_NO = dict(zip(NAMES, [1, 4, 9, 2, 5, 2, 2]))
|
| 95 |
+
ENTITY_SUM = sum(list(ENTITY_NO.values()))
|
| 96 |
+
ENTITY_INDEX = np.concatenate([[0], np.cumsum(list(ENTITY_NO.values()))])
|
| 97 |
+
ENTITY_INDEX_TARGET = ENTITY_INDEX + 1
|
| 98 |
+
ENTITY_DICT = dict(zip(NAMES, ENTITY_INDEX_TARGET[:-1]))
|
| 99 |
+
NAMES_ORDER = dict(zip(NAMES, np.arange(len(NAMES))))
|
| 100 |
+
PROPERTIES_INDEXES = np.cumsum(np.array(list(ENTITY_NO.values())) * len(PROPERTIES))
|
| 101 |
+
INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + ENTITY_SUM + 1 # +2 type and uniformity
|
| 102 |
+
|
| 103 |
+
SECOND_LAYOUT = [i - 1 for i in [
|
| 104 |
+
ENTITY_DICT["in_center_single_out_center_single"] + 1,
|
| 105 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 1,
|
| 106 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 2,
|
| 107 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 3,
|
| 108 |
+
ENTITY_DICT["left_center_single_right_center_single"] + 1,
|
| 109 |
+
ENTITY_DICT["up_center_single_down_center_single"] + 1
|
| 110 |
+
]]
|
| 111 |
+
|
| 112 |
+
FIRST_LAYOUT = list(set(range(ENTITY_SUM)) - set(SECOND_LAYOUT))
|
| 113 |
+
LAYOUT_NO = 2
|
| 114 |
+
|
| 115 |
+
START_INDEX = dict(zip(NAMES, INDEX[:-1]))
|
| 116 |
+
END_INDEX = INDEX[-1]
|
| 117 |
+
|
| 118 |
+
RULES_ATTRIBUTES_ALL_LEN = RULES_ATTRIBUTES_LEN * LAYOUT_NO
|
| 119 |
+
UNIFORMITY_NO = 2
|
| 120 |
+
UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN
|
| 121 |
+
|
| 122 |
+
FEATURE_NO = UNIFORMITY_INDEX + UNIFORMITY_NO
|
| 123 |
+
MAPPING = {
|
| 124 |
+
"distribute_nine":
|
| 125 |
+
{0.16: 0,
|
| 126 |
+
0.5: 1,
|
| 127 |
+
0.83: 2},
|
| 128 |
+
"distribute_four":
|
| 129 |
+
{0.25: 0,
|
| 130 |
+
0.75: 1},
|
| 131 |
+
'in_distribute_four_out_center_single':
|
| 132 |
+
{0.42: 0,
|
| 133 |
+
0.58: 1}
|
| 134 |
+
}
|
| 135 |
+
MUL = {
|
| 136 |
+
"distribute_nine": 3,
|
| 137 |
+
"distribute_four": 2,
|
| 138 |
+
'in_distribute_four_out_center_single': 2
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
# SIZES = np.linspace(0.4, 0.9, 6)
|
| 142 |
+
TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
|
| 143 |
+
# TYPES = ["triangle", "square", "pentagon", "circle", "circle"]
|
| 144 |
+
SIZES = ["vs", "s", "m", "h", "vh", "e"]
|
| 145 |
+
COLORS = ["vs", "s", "m", "h", "vh", "e"]
|
| 146 |
+
# TYPES = ["", "", "circle", "hexagon", "square"]
|
| 147 |
+
|
| 148 |
+
ENTITY_PROPERTIES_VALUES = list(PROPERTIES.values())
|
| 149 |
+
ENTITY_PROPERTIES_KEYS = list(PROPERTIES.keys())
|
| 150 |
+
ENTITY_PROPERTIES_NO = len(PROPERTIES)
|
| 151 |
+
INDEX = dict(zip(PROPERTIES, np.array(ENTITY_PROPERTIES_VALUES) * ENTITY_SUM))
|
| 152 |
+
ENTITY_PROPERTIES_SUM = sum(list(PROPERTIES.values()))
|
| 153 |
+
|
| 154 |
+
OUTPUT_SIZE = ENTITY_SUM * ENTITY_PROPERTIES_SUM + GROUPS_NO + ENTITY_SUM
|
| 155 |
+
|
| 156 |
+
SLOT_AND_GROUP = ENTITY_SUM + GROUPS_NO
|
| 157 |
+
|
| 158 |
+
OUTPUT_GROUP_SLICE = np.s_[:, -GROUPS_NO:]
|
| 159 |
+
OUTPUT_SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-GROUPS_NO]
|
| 160 |
+
OUTPUT_PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
|
| 161 |
+
|
| 162 |
+
OUTPUT_GROUP_SLICE_END = np.s_[-GROUPS_NO:]
|
| 163 |
+
OUTPUT_SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-GROUPS_NO]
|
| 164 |
+
OUTPUT_PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
|
| 165 |
+
|
| 166 |
+
# Transformation
|
| 167 |
+
# constant
|
| 168 |
+
# progression -2, -1,1 ,2
|
| 169 |
+
# arithmetic -/+ Position set arithmetic
|
| 170 |
+
# distribute three
|
| 171 |
+
|
| 172 |
+
# todo
|
| 173 |
+
SLOTS_GROUPS = GROUPS_NO
|
| 174 |
+
|
| 175 |
+
SLOT_TRANSFORMATION_NO = 4
|
| 176 |
+
PROPERTY_TRANSFORMATION_NO = 8
|
| 177 |
+
PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * PROPERTIES_NO
|
| 178 |
+
PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * ENTITY_SUM
|
| 179 |
+
|
| 180 |
+
SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * SLOTS_GROUPS
|
| 181 |
+
INFERENCE_SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE
|
| 182 |
+
|
| 183 |
+
INFERENCE_SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
|
| 184 |
+
INFERENCE_PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
|
| 185 |
+
from operator import add
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# todo Refactor
|
| 189 |
+
# Maybe properties should be on same level as rest.
|
| 190 |
+
def decode_output(output, split_fn=np_split):
|
| 191 |
+
group_output = output[..., OUTPUT_GROUP_SLICE_END]
|
| 192 |
+
slot_output = output[..., OUTPUT_SLOT_SLICE_END]
|
| 193 |
+
properties_output = output[..., OUTPUT_PROPERTIES_SLICE_END]
|
| 194 |
+
properties_output_splited = split_fn(properties_output, list(rv.properties.INDEX.values()), axis=-1)
|
| 195 |
+
return group_output, slot_output, properties_output_splited
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def decode_inference(inference, reshape=np.reshape):
|
| 199 |
+
return reshape(inference[INFERENCE_SLOT_SLICE],
|
| 200 |
+
[-1, SLOTS_GROUPS, PROPERTY_TRANSFORMATION_NO]), reshape(
|
| 201 |
+
inference[INFERENCE_PROPERTIES_SLICE],
|
| 202 |
+
[-1, PROPERTIES_NO, ENTITY_SUM, PROPERTY_TRANSFORMATION_NO])
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def decode_output_reshape(output, split_fn=np_split):
|
| 206 |
+
result = decode_output(output, split_fn=split_fn)
|
| 207 |
+
out_reshaped = []
|
| 208 |
+
for i, out in enumerate(result[2]):
|
| 209 |
+
shape = (-1, ENTITY_SUM, ENTITY_PROPERTIES_VALUES[i])
|
| 210 |
+
out_reshaped.append(out.reshape(shape))
|
| 211 |
+
return result[:2] + tuple(out_reshaped)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def take_target(target):
|
| 215 |
+
return target[1], target[2]
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def create_target(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
|
| 219 |
+
return [images[:, pattern_index[0]], images[:, pattern_index[1]],
|
| 220 |
+
images[arrange(shape(index)[0]), (0 if full_index else 8) + index[:, 0]]]
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def take_target_simple(target):
|
| 224 |
+
return target[1], target[0]
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def create_target_simple(images, target, index=slice(None), pattern_index=(2, 5)):
|
| 228 |
+
return [images[:, pattern_index[0]], images[:, pattern_index[1]], target][index]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def decode_output_result(output, split_fn=np_split, arg_max=np.argmax):
|
| 232 |
+
result = decode_output_reshape(output, split_fn=split_fn)
|
| 233 |
+
res = []
|
| 234 |
+
for i, r in enumerate(result):
|
| 235 |
+
if i == 1:
|
| 236 |
+
res.append(r)
|
| 237 |
+
else:
|
| 238 |
+
res.append(arg_max(r, axis=-1))
|
| 239 |
+
return tuple(res)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def decode_target(target):
|
| 243 |
+
target_group = target[..., 0]
|
| 244 |
+
target_slot = target[..., 1:INDEX[0]]
|
| 245 |
+
target_properties = target[..., INDEX[0]:END_INDEX]
|
| 246 |
+
target_properties_splited = [
|
| 247 |
+
target_properties[..., ::PROPERTIES_NO],
|
| 248 |
+
target_properties[..., 1::PROPERTIES_NO],
|
| 249 |
+
target_properties[..., 2::PROPERTIES_NO]
|
| 250 |
+
]
|
| 251 |
+
return target_group, target_slot, target_properties_splited
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def decode_target_flat(target):
|
| 255 |
+
t = decode_target(target)
|
| 256 |
+
return t[0], t[1], t[2][0], t[2][1], t[2][2]
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def draw_board(images, target=None, predict=None,image=None, desc=None, layout=None, break_=20):
|
| 260 |
+
if image != "target" and predict is not None:
|
| 261 |
+
image = images[predict:predict + 1]
|
| 262 |
+
elif images is None and target is not None:
|
| 263 |
+
image = images[target:target + 1]
|
| 264 |
+
# image = False to not draw anything
|
| 265 |
+
border = [{COR: target - 8, EXIST: (1, 3)}] + [{COR: p, EXIST: (0, 2)} for p in none(predict)]
|
| 266 |
+
|
| 267 |
+
boards = []
|
| 268 |
+
boards.append(draw_images(np.concatenate([images[:8], image[None] if len(image.shape)==3 else image]) if image is not None else images[:8]))
|
| 269 |
+
if layout == 1:
|
| 270 |
+
i = draw_images(images[8:], column=4, border=border)
|
| 271 |
+
if break_:
|
| 272 |
+
i = np.concatenate([np.zeros([ break_, i.shape[1],1]),i ],axis=0)
|
| 273 |
+
boards.append(i)
|
| 274 |
+
|
| 275 |
+
else:
|
| 276 |
+
boards.append(
|
| 277 |
+
draw_images(np.concatenate([images[8:], predict]) if predict is not None else images[8:], column=4,
|
| 278 |
+
border=target - 8))
|
| 279 |
+
full_board = draw_images(boards, grid=False)
|
| 280 |
+
if desc:
|
| 281 |
+
full_board = add_text(full_board, desc)
|
| 282 |
+
return full_board
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def draw_boards(images, target=None, predict=None, image=None, desc=None, no=1, layout=None):
|
| 286 |
+
boards = []
|
| 287 |
+
for i, image in enumerate(images):
|
| 288 |
+
boards.append(draw_board(image, target[i][0] if target is not None else None,
|
| 289 |
+
predict[i] if predict is not None else None,
|
| 290 |
+
image[i] if image is not None else None,
|
| 291 |
+
desc[i] if desc is not None else None, layout=layout))
|
| 292 |
+
return boards
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def draw_raven(generator, predict=None, no=1, add_target_desc=True, indexes=None, types=TYPES,
|
| 296 |
+
layout=1):
|
| 297 |
+
if indexes is None:
|
| 298 |
+
indexes = val_sample(no)
|
| 299 |
+
data = generator.data[indexes]
|
| 300 |
+
if is_model(predict):
|
| 301 |
+
d = filter_keys(data, PROPERTY,reverse=True)
|
| 302 |
+
# tmp change
|
| 303 |
+
pro = predict(d)['predict']
|
| 304 |
+
print(pro)
|
| 305 |
+
predict = render_panels(pro, target=False)
|
| 306 |
+
# if target is not None:
|
| 307 |
+
target = data[TARGET]
|
| 308 |
+
target_index = data["index"]
|
| 309 |
+
images = data[INPUTS]
|
| 310 |
+
|
| 311 |
+
if hasattr(predict, "shape"):
|
| 312 |
+
if len(predict.shape) > 3:
|
| 313 |
+
# iamges
|
| 314 |
+
image = predict
|
| 315 |
+
# todo create index and output based on image
|
| 316 |
+
predict = None
|
| 317 |
+
predict_index = None
|
| 318 |
+
elif len(predict.shape) == 3:
|
| 319 |
+
image = render_panels(predict, target=False)
|
| 320 |
+
# Create index based on predict.
|
| 321 |
+
predict_index = None
|
| 322 |
+
else:
|
| 323 |
+
image = images[predict]
|
| 324 |
+
predict_index = predict
|
| 325 |
+
predict = target
|
| 326 |
+
else:
|
| 327 |
+
image = K.gather(images, target_index[:, 0])
|
| 328 |
+
predict_index = None
|
| 329 |
+
predict = None
|
| 330 |
+
|
| 331 |
+
# elif not(hasattr(target,"shape") and len(target.shape) > 3):
|
| 332 |
+
# if hasattr(target,"shape") and target.shape[-1] == OUTPUT_SIZE:
|
| 333 |
+
# pro = target
|
| 334 |
+
# predict = render_panels(pro)
|
| 335 |
+
# elif hasattr(target,"shape") and target.shape[-1] == FEATURE_NO:
|
| 336 |
+
# # pro = target
|
| 337 |
+
# pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
|
| 338 |
+
# else:
|
| 339 |
+
# pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
|
| 340 |
+
# # predict = [None] * no
|
| 341 |
+
# predict = render_panels(data[TARGET])
|
| 342 |
+
|
| 343 |
+
all_rules = []
|
| 344 |
+
for d in data[PROPERTY]:
|
| 345 |
+
rules = []
|
| 346 |
+
for j, rule_group in enumerate(d.findAll("Rule_Group")):
|
| 347 |
+
# rules_all.append(rule_group['id'])
|
| 348 |
+
for j, rule in enumerate(rule_group.findAll("Rule")):
|
| 349 |
+
rules.append(f"{rule['attr']} - {rule['name']}")
|
| 350 |
+
rules.append("")
|
| 351 |
+
all_rules.append(rules)
|
| 352 |
+
target_desc = get_desc(target)
|
| 353 |
+
if predict is not None:
|
| 354 |
+
predict_desc = decode_output_result(predict) if predict.shape[-1] == OUTPUT_SIZE else get_desc(predict)
|
| 355 |
+
else:
|
| 356 |
+
predict_desc = [None] * len(target_desc)
|
| 357 |
+
for a, po, to in zip(all_rules, predict_desc, target_desc):
|
| 358 |
+
# fl(predict_desc[-1])
|
| 359 |
+
if po is None:
|
| 360 |
+
po = [None] * len(to)
|
| 361 |
+
for p, t in zip(po, to):
|
| 362 |
+
a.extend(
|
| 363 |
+
[" ".join([str(i) for i in t])] + (
|
| 364 |
+
[" ".join([str(i) for i in p]), ""] if p is not None else []
|
| 365 |
+
)
|
| 366 |
+
)
|
| 367 |
+
# a.extend([""] + [] + [""] + [" ".join(fl(p))])
|
| 368 |
+
|
| 369 |
+
# image = draw_boards(data[INPUTS],target=data["index"], predict=predict[:no], desc=all_rules, no=no,layer=layer)
|
| 370 |
+
image = draw_boards(images, target=target_index, predict=predict_index, image=image, desc=None, no=no,
|
| 371 |
+
layout=layout)
|
| 372 |
+
return lu([(i, j) for i, j in zip(image, all_rules)])
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def val_sample(no=GROUPS_NO, base=3):
|
| 376 |
+
indexes = np.arange(no) * 2000 + base
|
| 377 |
+
return indexes
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
|
| 381 |
+
decoded = decode_target(target)
|
| 382 |
+
exist = decoded[1] if exist is None else exist
|
| 383 |
+
taken = np.stack(take(decoded[2], np.array(exist, dtype=bool))).T
|
| 384 |
+
|
| 385 |
+
figures_no = np.sum(exist, axis=-1)
|
| 386 |
+
desc = np.split(taken, np.cumsum(figures_no))[:-1]
|
| 387 |
+
# figures_no = np.sum(exist, axis=-1)
|
| 388 |
+
# div = np.split(desc, np.cumsum(figures_no))[:-1]
|
| 389 |
+
result = []
|
| 390 |
+
for pd in desc:
|
| 391 |
+
r = []
|
| 392 |
+
for p in pd:
|
| 393 |
+
r.append([p[0], sizes[p[1]], types[p[2]]])
|
| 394 |
+
result.append(r)
|
| 395 |
+
|
| 396 |
+
return result
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
# def get
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def get_description(inputs, predict, pro, no, types=TYPES, sizes=SIZES):
|
| 403 |
+
# target = inputs[1][2][:no]
|
| 404 |
+
target = inputs[TARGET]
|
| 405 |
+
target_group = target[:, 0]
|
| 406 |
+
target_exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
|
| 407 |
+
target_rest = target[:, ENTITY_SUM + 1:ENTITY_SUM + 1 + ENTITY_SUM * PROPERTIES_NO]
|
| 408 |
+
pro_reshaped = np.reshape(pro, (pro.shape[0], -1, PROPERTIES_NO))
|
| 409 |
+
target_reshaped = np.reshape(target_rest, (target_rest.shape[0], -1, PROPERTIES_NO))
|
| 410 |
+
|
| 411 |
+
# mask = np.repeat(target_exist, [4] * ENTITY_SUM, axis=-1)
|
| 412 |
+
# masked_result = np.repeat(target_exist, [4] * ENTITY_SUM, axis=-1)
|
| 413 |
+
pro_res = pro_reshaped[target_exist]
|
| 414 |
+
target_res = target_reshaped[target_exist]
|
| 415 |
+
figures_no = np.sum(target_exist, axis=-1)
|
| 416 |
+
|
| 417 |
+
pro_div = np.split(pro_res, np.cumsum(figures_no))[:-1]
|
| 418 |
+
target_div = np.split(target_res, np.cumsum(figures_no))[:-1]
|
| 419 |
+
pro_result_full = []
|
| 420 |
+
target_result_full = []
|
| 421 |
+
for pd, td in zip(pro_div, target_div):
|
| 422 |
+
pro_result = []
|
| 423 |
+
target_result = []
|
| 424 |
+
for p in pd:
|
| 425 |
+
pro_result.append([p[0], sizes[p[1]], types[p[2]]])
|
| 426 |
+
for t in td:
|
| 427 |
+
target_result.append([t[0], sizes[t[1]], types[t[2]]])
|
| 428 |
+
pro_result_full.append(pro_result)
|
| 429 |
+
target_result_full.append(target_result)
|
| 430 |
+
|
| 431 |
+
return pro_result_full, target_result_full
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def get_properties(target, types=TYPES, sizes=SIZES):
|
| 435 |
+
target_exist = np.asarray(target[:, 1:ENTITY_SUM + 1], dtype="bool")
|
| 436 |
+
target_rest = target[:, ENTITY_SUM + 1:ENTITY_SUM + 1 + ENTITY_SUM * PROPERTIES_NO]
|
| 437 |
+
target_reshaped = np.reshape(target_rest, (target_rest.shape[0], -1, PROPERTIES_NO))
|
| 438 |
+
target_res = target_reshaped[target_exist]
|
| 439 |
+
figures_no = np.sum(target_exist, axis=-1)
|
| 440 |
+
target_div = np.split(target_res, np.cumsum(figures_no))[:-1]
|
| 441 |
+
target_result_full = []
|
| 442 |
+
for td in target_div:
|
| 443 |
+
target_result = []
|
| 444 |
+
for t in td:
|
| 445 |
+
target_result.append([t[0], sizes[t[1]], types[t[2]]])
|
| 446 |
+
target_result_full.append(target_result)
|
| 447 |
+
return target_result_full
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def desc_properties(target, decode_fn=None, types=TYPES, sizes=SIZES):
|
| 451 |
+
if decode_fn is None:
|
| 452 |
+
if target.shape[1] == OUTPUT_SIZE:
|
| 453 |
+
decode_fn = decode_output_result
|
| 454 |
+
else:
|
| 455 |
+
decode_fn = decode_target
|
| 456 |
+
|
| 457 |
+
target_div = decode_fn(target)[2:]
|
| 458 |
+
target_result_full = []
|
| 459 |
+
for td in target_div:
|
| 460 |
+
target_result = []
|
| 461 |
+
for t in td:
|
| 462 |
+
target_result.append([t[0], sizes[t[1]], types[t[2]]])
|
| 463 |
+
target_result_full.append(target_result)
|
| 464 |
+
return target_result_full
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def get_pro(t, types=TYPES, sizes=SIZES):
|
| 468 |
+
return [int(t[0]), sizes[t[1]], types[t[2]]]
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def get_pro2(td, types=TYPES, sizes=SIZES):
|
| 472 |
+
target_result = []
|
| 473 |
+
for t in td:
|
| 474 |
+
target_result.append([int(t[0]), sizes[t[1]], types[t[2]]])
|
| 475 |
+
return target_result
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def get_pro3(target_div, types=TYPES, sizes=SIZES):
|
| 479 |
+
target_result_full = []
|
| 480 |
+
for td in target_div.to_list():
|
| 481 |
+
target_result = []
|
| 482 |
+
for t in td:
|
| 483 |
+
target_result.append([int(t[0]), sizes[t[1]], types[t[2]]])
|
| 484 |
+
target_result_full.append(target_result)
|
| 485 |
+
return target_result_full
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
from models_utils import init_image as def_init_image, is_model
|
| 489 |
+
|
| 490 |
+
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
|
raven_utils/draw.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from data_utils import take, EXIST, COR
|
| 3 |
+
from data_utils.image import draw_images, add_text
|
| 4 |
+
from funcy import identity
|
| 5 |
+
from ml_utils import none, filter_keys, lu
|
| 6 |
+
from models_utils import is_model
|
| 7 |
+
from models_utils import ops as K
|
| 8 |
+
|
| 9 |
+
from raven_utils.constant import PROPERTY, TARGET, INPUTS
|
| 10 |
+
from raven_utils.decode import decode_target, target_mask
|
| 11 |
+
from raven_utils.render.rendering import render_panels
|
| 12 |
+
from raven_utils.render_ import TYPES, SIZES
|
| 13 |
+
from raven_utils.uitls import get_val_index
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def draw_board(images, target=None, predict=None, image=None, desc=None, layout=None, break_=20):
|
| 17 |
+
if image != "target" and predict is not None:
|
| 18 |
+
image = images[predict:predict + 1]
|
| 19 |
+
elif images is None and target is not None:
|
| 20 |
+
image = images[target:target + 1]
|
| 21 |
+
# image = False to not draw anything
|
| 22 |
+
border = [{COR: target - 8, EXIST: list(range(4)) if predict is None else (1, 3)}] + [{COR: p, EXIST: (0, 2)} for p
|
| 23 |
+
in none(predict)]
|
| 24 |
+
|
| 25 |
+
boards = []
|
| 26 |
+
boards.append(draw_images(
|
| 27 |
+
np.concatenate([images[:8], image[None] if len(image.shape) == 3 else image]) if image is not None else images[
|
| 28 |
+
:8]))
|
| 29 |
+
if layout == 1:
|
| 30 |
+
i = draw_images(images[8:], column=4, border=border)
|
| 31 |
+
if break_:
|
| 32 |
+
i = np.concatenate([np.zeros([break_, i.shape[1], 1]), i], axis=0)
|
| 33 |
+
boards.append(i)
|
| 34 |
+
|
| 35 |
+
else:
|
| 36 |
+
boards.append(
|
| 37 |
+
draw_images(np.concatenate([images[8:], predict]) if predict is not None else images[8:], column=4,
|
| 38 |
+
border=target - 8))
|
| 39 |
+
full_board = draw_images(boards, grid=False)
|
| 40 |
+
if desc:
|
| 41 |
+
full_board = add_text(full_board, desc)
|
| 42 |
+
return full_board
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def draw_boards(images, target=None, predict=None,image=None, desc=None, layout=None):
|
| 46 |
+
boards = []
|
| 47 |
+
for i, im in enumerate(images):
|
| 48 |
+
boards.append(draw_board(im, target[i][0] if target is not None else None,
|
| 49 |
+
predict[i] if predict is not None else None,
|
| 50 |
+
image[i] if image is not None else None,
|
| 51 |
+
desc[i] if desc is not None else None, layout=layout))
|
| 52 |
+
return boards
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def draw_from_generator(generator, predict=None, no=1, indexes=None, layout=1):
|
| 56 |
+
data,_ = val_sample(generator, no, indexes)
|
| 57 |
+
return draw_raven(data, predict=predict, pre_fn=generator.data.data["inputs"].fn, layout=layout)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def val_sample(generator, no=1, indexes=None):
|
| 61 |
+
if indexes is None:
|
| 62 |
+
indexes = get_val_index(base=no)
|
| 63 |
+
data = generator.data[indexes]
|
| 64 |
+
return data, indexes
|
| 65 |
+
|
| 66 |
+
def render_from_model(data,predict,pre_fn=identity):
|
| 67 |
+
data = filter_keys(data, PROPERTY, reverse=True)
|
| 68 |
+
if is_model(predict):
|
| 69 |
+
predict = predict(data)
|
| 70 |
+
pro = np.array(target_mask(predict['predict_mask'].numpy()) * predict["predict"].numpy(), dtype=np.int8)
|
| 71 |
+
return pre_fn(render_panels(pro, target=False)[None])[0]
|
| 72 |
+
|
| 73 |
+
def draw_raven(data, predict=None, pre_fn=identity, layout=1):
|
| 74 |
+
if is_model(predict):
|
| 75 |
+
d = filter_keys(data, PROPERTY, reverse=True)
|
| 76 |
+
# tmp change
|
| 77 |
+
res = predict(d)
|
| 78 |
+
pro = np.array(target_mask(res['mask'].numpy()) * res["predict"].numpy(),dtype=np.int8)
|
| 79 |
+
predict = pre_fn(render_panels(pro, target=False)[None])[0]
|
| 80 |
+
# from data_utils import ims
|
| 81 |
+
# ims(1 - predict[0])
|
| 82 |
+
# if target is not None:
|
| 83 |
+
target = data[TARGET]
|
| 84 |
+
target_index = data["index"]
|
| 85 |
+
images = data[INPUTS]
|
| 86 |
+
# np.equal(res['predict'], pro[:,:102]).sum()
|
| 87 |
+
|
| 88 |
+
if hasattr(predict, "shape"):
|
| 89 |
+
if len(predict.shape) > 3:
|
| 90 |
+
# iamges
|
| 91 |
+
image = predict
|
| 92 |
+
# todo create index and output based on image
|
| 93 |
+
predict = None
|
| 94 |
+
predict_index = None
|
| 95 |
+
elif len(predict.shape) == 3:
|
| 96 |
+
image = render_panels(predict, target=False)
|
| 97 |
+
# Create index based on predict.
|
| 98 |
+
predict_index = None
|
| 99 |
+
else:
|
| 100 |
+
image = images[predict]
|
| 101 |
+
predict_index = predict
|
| 102 |
+
predict = target
|
| 103 |
+
else:
|
| 104 |
+
image = K.gather(images, target_index[:, 0])
|
| 105 |
+
predict_index = None
|
| 106 |
+
predict = None
|
| 107 |
+
|
| 108 |
+
# elif not(hasattr(target,"shape") and len(target.shape) > 3):
|
| 109 |
+
# if hasattr(target,"shape") and target.shape[-1] == OUTPUT_SIZE:
|
| 110 |
+
# pro = target
|
| 111 |
+
# predict = render_panels(pro)
|
| 112 |
+
# elif hasattr(target,"shape") and target.shape[-1] == FEATURE_NO:
|
| 113 |
+
# # pro = target
|
| 114 |
+
# pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
|
| 115 |
+
# else:
|
| 116 |
+
# pro = np.zeros([no, OUTPUT_SIZE], dtype="int")
|
| 117 |
+
# # predict = [None] * no
|
| 118 |
+
# predict = render_panels(data[TARGET])
|
| 119 |
+
|
| 120 |
+
image = draw_boards(images, target=target_index, predict=predict_index,image=image, desc=None,
|
| 121 |
+
layout=layout)
|
| 122 |
+
|
| 123 |
+
all_rules = extract_rules(data[PROPERTY])
|
| 124 |
+
target_desc = get_desc(target)
|
| 125 |
+
if predict is not None:
|
| 126 |
+
predict_desc = get_desc(predict)
|
| 127 |
+
else:
|
| 128 |
+
predict_desc = [None] * len(target_desc)
|
| 129 |
+
for a, po, to in zip(all_rules, predict_desc, target_desc):
|
| 130 |
+
# fl(predict_desc[-1])
|
| 131 |
+
if po is None:
|
| 132 |
+
po = [None] * len(to)
|
| 133 |
+
for p, t in zip(po, to):
|
| 134 |
+
a.extend(
|
| 135 |
+
[" ".join([str(i) for i in t])] + (
|
| 136 |
+
[" ".join([str(i) for i in p]), ""] if p is not None else []
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
# a.extend([""] + [] + [""] + [" ".join(fl(p))])
|
| 140 |
+
|
| 141 |
+
# image = draw_boards(data[INPUTS],target=data["index"], predict=predict[:no], desc=all_rules, no=no,layer=layer)
|
| 142 |
+
return lu([(i, j) for i, j in zip(image, all_rules)])
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def extract_rules(data):
|
| 146 |
+
all_rules = []
|
| 147 |
+
for d in data:
|
| 148 |
+
rules = []
|
| 149 |
+
for j, rule_group in enumerate(d.findAll("Rule_Group")):
|
| 150 |
+
# rules_all.append(rule_group['id'])
|
| 151 |
+
for j, rule in enumerate(rule_group.findAll("Rule")):
|
| 152 |
+
rules.append(f"{rule['attr']} - {rule['name']}")
|
| 153 |
+
rules.append("")
|
| 154 |
+
all_rules.append(rules)
|
| 155 |
+
return all_rules
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_desc(target, exist=None, types=TYPES, sizes=SIZES):
|
| 159 |
+
decoded = decode_target(target)
|
| 160 |
+
exist = decoded[1] if exist is None else exist
|
| 161 |
+
taken = np.stack(take(decoded[2], np.array(exist, dtype=bool))).T
|
| 162 |
+
|
| 163 |
+
figures_no = np.sum(exist, axis=-1)
|
| 164 |
+
desc = np.split(taken, np.cumsum(figures_no))[:-1]
|
| 165 |
+
# figures_no = np.sum(exist, axis=-1)
|
| 166 |
+
# div = np.split(desc, np.cumsum(figures_no))[:-1]
|
| 167 |
+
result = []
|
| 168 |
+
for pd in desc:
|
| 169 |
+
r = []
|
| 170 |
+
for p in pd:
|
| 171 |
+
r.append([p[0], sizes[p[1]], types[p[2]]])
|
| 172 |
+
result.append(r)
|
| 173 |
+
|
| 174 |
+
return result
|
raven_utils/entity.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import raven_utils.group as group
|
| 2 |
+
import numpy as np
|
| 3 |
+
NO = dict(zip(group.NAMES, [1, 4, 9, 2, 5, 2, 2]))
|
| 4 |
+
SUM = sum(list(NO.values()))
|
| 5 |
+
|
| 6 |
+
INDEX = np.concatenate([[0], np.cumsum(list(NO.values()))])
|
raven_utils/group.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
NAMES = ['center_single',
|
| 4 |
+
'distribute_four',
|
| 5 |
+
'distribute_nine',
|
| 6 |
+
'in_center_single_out_center_single',
|
| 7 |
+
'in_distribute_four_out_center_single',
|
| 8 |
+
'left_center_single_right_center_single',
|
| 9 |
+
'up_center_single_down_center_single']
|
| 10 |
+
|
| 11 |
+
NO = len(NAMES)
|
raven_utils/inference.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import raven_utils.properties as properties
|
| 3 |
+
import raven_utils.group as group
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
SLOT_TRANSFORMATION_NO = 4
|
| 7 |
+
PROPERTY_TRANSFORMATION_NO = 8
|
| 8 |
+
PROPERTIES_TRANSFORMATION_NO = PROPERTY_TRANSFORMATION_NO * properties.NO
|
| 9 |
+
PROPERTIES_TRANSFORMATION_SIZE = PROPERTIES_TRANSFORMATION_NO * group.NO
|
| 10 |
+
|
| 11 |
+
SLOT_TRANSFORMATION_SIZE = PROPERTY_TRANSFORMATION_NO * group.NO
|
| 12 |
+
SIZE = SLOT_TRANSFORMATION_SIZE + PROPERTIES_TRANSFORMATION_SIZE
|
| 13 |
+
|
| 14 |
+
SLOT_SLICE = np.s_[:, :SLOT_TRANSFORMATION_SIZE]
|
| 15 |
+
PROPERTIES_SLICE = np.s_[:, -PROPERTIES_TRANSFORMATION_SIZE:]
|
raven_utils/models/__init__.py
ADDED
|
File without changes
|
raven_utils/models/attn.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from tensorflow.keras import backend as K
|
| 5 |
+
from tensorflow.keras.layers import LSTMCell
|
| 6 |
+
from tensorflow.keras.models import Model
|
| 7 |
+
from tensorflow.keras.layers import Conv2D, Dense
|
| 8 |
+
from tensorflow.keras.losses import mse
|
| 9 |
+
from tensorflow.keras.models import clone_model
|
| 10 |
+
from tensorflow.layers.base import InputSpec, Layer
|
| 11 |
+
|
| 12 |
+
from models.dense import create_conv_model
|
| 13 |
+
from models.utils import broadcast
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ReflectionPadding2D(Layer):
|
| 17 |
+
def __init__(self, padding=(1, 1), **kwargs):
|
| 18 |
+
self.padding = tuple(padding)
|
| 19 |
+
self.input_spec = [InputSpec(ndim=4)]
|
| 20 |
+
super(ReflectionPadding2D, self).__init__(**kwargs)
|
| 21 |
+
|
| 22 |
+
def compute_output_shape(self, s):
|
| 23 |
+
""" If you are using "channels_last" configuration"""
|
| 24 |
+
return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
|
| 25 |
+
|
| 26 |
+
def call(self, x, mask=None):
|
| 27 |
+
w_pad, h_pad = self.padding
|
| 28 |
+
return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Conv2Ref(Layer):
|
| 32 |
+
def __init__(self, padding=(1, 1), **kwargs):
|
| 33 |
+
self.padding = tuple(padding)
|
| 34 |
+
self.input_spec = [InputSpec(ndim=4)]
|
| 35 |
+
super(ReflectionPadding2D, self).__init__(**kwargs)
|
| 36 |
+
|
| 37 |
+
def compute_output_shape(self, s):
|
| 38 |
+
""" If you are using "channels_last" configuration"""
|
| 39 |
+
return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
|
| 40 |
+
|
| 41 |
+
def call(self, x, mask=None):
|
| 42 |
+
w_pad, h_pad = self.padding
|
| 43 |
+
return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SegmentationNetwork(Model):
|
| 47 |
+
|
| 48 |
+
def __init__(self, filters=64, kernels=(3, 3)):
|
| 49 |
+
super(RecAE, self).__init__()
|
| 50 |
+
self.conv_1 = Conv2D(filters, kernels, padding=SAME)
|
| 51 |
+
self.conv_2 = Conv2D(filters, kernels, padding=SAME)
|
| 52 |
+
|
| 53 |
+
def call(self, inputs):
|
| 54 |
+
x = K.relu(inputs)
|
| 55 |
+
x = self.conv_1(x)
|
| 56 |
+
x = K.relu(x)
|
| 57 |
+
x = self.conv_2(x)
|
| 58 |
+
return x + inputs
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class QueryNetwork(Model):
|
| 62 |
+
|
| 63 |
+
def __init__(self, units=64):
|
| 64 |
+
super(RecAE, self).__init__()
|
| 65 |
+
self.conv_1 = Dense(units)
|
| 66 |
+
self.conv_2 = Dense(units)
|
| 67 |
+
|
| 68 |
+
def call(self, inputs):
|
| 69 |
+
x = K.relu(inputs)
|
| 70 |
+
x = self.conv_1(x)
|
| 71 |
+
x = K.relu(x)
|
| 72 |
+
x = self.conv_2(x)
|
| 73 |
+
return x + inputs
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class RecAE(Model):
|
| 77 |
+
|
| 78 |
+
def __init__(self, head, bottle, decoder):
|
| 79 |
+
super(RecAE, self).__init__()
|
| 80 |
+
self.head = head
|
| 81 |
+
self.bottle = bottle
|
| 82 |
+
self.base = clone_model(bottle)
|
| 83 |
+
self.decoder = decoder
|
| 84 |
+
self.segmentation_network = SegmentationNetwork()
|
| 85 |
+
self.query_network = QueryNetwork()
|
| 86 |
+
self.control = LSTMCell(64)
|
| 87 |
+
self.memory = LSTMCell(64)
|
| 88 |
+
|
| 89 |
+
def call(self, inputs):
|
| 90 |
+
feature = self.head(inputs)
|
| 91 |
+
segmentation = self.segmentation_network(feature)
|
| 92 |
+
control_base = self.base(feature)
|
| 93 |
+
h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 94 |
+
h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 95 |
+
shape = K.shape(feature)[:-1]
|
| 96 |
+
full_attention = tf.zeros(shape)[..., tf.newaxis]
|
| 97 |
+
full_image = tf.zeros(K.shape(inputs))
|
| 98 |
+
masks = []
|
| 99 |
+
ff = tf.zeros(K.shape(inputs))
|
| 100 |
+
scope = tf.ones(shape)[..., tf.newaxis]
|
| 101 |
+
for i in range(4):
|
| 102 |
+
r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
|
| 103 |
+
query = self.query_network(h_c[0])
|
| 104 |
+
log_attention = image_attention(segmentation, query)
|
| 105 |
+
attention = K.sigmoid(log_attention)
|
| 106 |
+
mask = attention * scope
|
| 107 |
+
scope = scope - mask
|
| 108 |
+
im = feature * mask
|
| 109 |
+
# im = feature
|
| 110 |
+
latent = self.bottle(im)
|
| 111 |
+
decoded = self.decoder(latent)
|
| 112 |
+
# self.add_loss(K.mean(-mse(full_attention, attention)))
|
| 113 |
+
# self.add_loss(K.mean(-mse(tf.ones(attention.shape), attention)))
|
| 114 |
+
full_attention += attention
|
| 115 |
+
big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
|
| 116 |
+
ff += K.sigmoid(decoded)
|
| 117 |
+
full_image += K.sigmoid(decoded) * big_mask
|
| 118 |
+
r_m, h_m = self.memory(latent, h_m)
|
| 119 |
+
masks.append(big_mask)
|
| 120 |
+
self.add_loss(K.mean(mse(inputs, full_image)))
|
| 121 |
+
return full_image, masks
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# def image_attention(image, query, scale=True):
|
| 125 |
+
@tf.function
|
| 126 |
+
def image_attention(image, query):
|
| 127 |
+
log_attention = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
|
| 128 |
+
# if scale is not None:
|
| 129 |
+
log_attention /= tf.sqrt(tf.cast(K.shape(image)[-1], dtype=float))
|
| 130 |
+
return log_attention
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class RecAE_2(Model):
|
| 134 |
+
|
| 135 |
+
def __init__(self, head, bottle, decoder):
|
| 136 |
+
super(RecAE_2, self).__init__()
|
| 137 |
+
self.head = head
|
| 138 |
+
self.bottle = bottle
|
| 139 |
+
# self.base = clone_model(bottle)
|
| 140 |
+
self.base = self.bottle
|
| 141 |
+
self.decoder = decoder
|
| 142 |
+
self.segmentation_network = create_conv_model((64, 64, 1))
|
| 143 |
+
self.control = LSTMCell(64)
|
| 144 |
+
self.memory = LSTMCell(64)
|
| 145 |
+
|
| 146 |
+
def call(self, inputs):
|
| 147 |
+
feature = self.head(inputs)
|
| 148 |
+
control_base = self.base(feature)
|
| 149 |
+
h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 150 |
+
h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 151 |
+
shape = K.shape(feature)[:-1]
|
| 152 |
+
full_attention = tf.zeros(shape)[..., tf.newaxis]
|
| 153 |
+
full_image = tf.zeros(K.shape(inputs))
|
| 154 |
+
big_masks = []
|
| 155 |
+
masks = []
|
| 156 |
+
ff = tf.zeros(K.shape(inputs))
|
| 157 |
+
scope = tf.ones(shape)[..., tf.newaxis]
|
| 158 |
+
for i in range(4):
|
| 159 |
+
if i ==3:
|
| 160 |
+
mask = scope
|
| 161 |
+
else:
|
| 162 |
+
r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
|
| 163 |
+
query = broadcast(h_c[0], feature.shape[1:])
|
| 164 |
+
log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
|
| 165 |
+
attention = K.sigmoid(log_attention)
|
| 166 |
+
mask = attention * scope
|
| 167 |
+
scope = scope - mask
|
| 168 |
+
masks.append(mask)
|
| 169 |
+
im = feature * mask
|
| 170 |
+
# im = feature
|
| 171 |
+
latent = self.bottle(im)
|
| 172 |
+
decoded = self.decoder(latent)
|
| 173 |
+
# self.add_loss(K.mean(-mse(scope, mask)))
|
| 174 |
+
sum = K.sum(tf.ones(K.shape(mask)))
|
| 175 |
+
self.add_loss(K.abs((sum/4)-K.sum(mask))/sum)
|
| 176 |
+
# self.add_loss(K.mean(-mse(tf.zeros(K.shape(mask)), mask)))
|
| 177 |
+
for m in masks:
|
| 178 |
+
self.add_loss(K.mean(-mse(m,mask)))
|
| 179 |
+
|
| 180 |
+
full_attention += mask
|
| 181 |
+
big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
|
| 182 |
+
ff += K.sigmoid(decoded)
|
| 183 |
+
full_image += K.sigmoid(decoded) * big_mask
|
| 184 |
+
r_m, h_m = self.memory(latent, h_m)
|
| 185 |
+
big_masks.append(big_mask)
|
| 186 |
+
self.add_loss(K.mean(mse(inputs, full_image)))
|
| 187 |
+
return full_image, big_masks
|
raven_utils/models/attn2.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from tensorflow.keras import backend as K
|
| 5 |
+
from tensorflow.keras.layers import LSTMCell
|
| 6 |
+
from tensorflow.keras.models import Model
|
| 7 |
+
from tensorflow.keras.layers import Conv2D, Dense
|
| 8 |
+
from tensorflow.keras.losses import mse
|
| 9 |
+
from tensorflow.keras.models import clone_model
|
| 10 |
+
from tensorflow.layers.base import InputSpec, Layer
|
| 11 |
+
|
| 12 |
+
from models.dense import create_conv_model
|
| 13 |
+
from models.utils import broadcast
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ReflectionPadding2D(Layer):
|
| 17 |
+
def __init__(self, padding=(1, 1), **kwargs):
|
| 18 |
+
self.padding = tuple(padding)
|
| 19 |
+
self.input_spec = [InputSpec(ndim=4)]
|
| 20 |
+
super(ReflectionPadding2D, self).__init__(**kwargs)
|
| 21 |
+
|
| 22 |
+
def compute_output_shape(self, s):
|
| 23 |
+
""" If you are using "channels_last" configuration"""
|
| 24 |
+
return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
|
| 25 |
+
|
| 26 |
+
def call(self, x, mask=None):
|
| 27 |
+
w_pad, h_pad = self.padding
|
| 28 |
+
return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Conv2Ref(Layer):
|
| 32 |
+
def __init__(self, padding=(1, 1), **kwargs):
|
| 33 |
+
self.padding = tuple(padding)
|
| 34 |
+
self.input_spec = [InputSpec(ndim=4)]
|
| 35 |
+
super(ReflectionPadding2D, self).__init__(**kwargs)
|
| 36 |
+
|
| 37 |
+
def compute_output_shape(self, s):
|
| 38 |
+
""" If you are using "channels_last" configuration"""
|
| 39 |
+
return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])
|
| 40 |
+
|
| 41 |
+
def call(self, x, mask=None):
|
| 42 |
+
w_pad, h_pad = self.padding
|
| 43 |
+
return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SegmentationNetwork(Model):
|
| 47 |
+
|
| 48 |
+
def __init__(self, filters=64, kernels=(3, 3)):
|
| 49 |
+
super(RecAE, self).__init__()
|
| 50 |
+
self.conv_1 = Conv2D(filters, kernels)
|
| 51 |
+
self.conv_2 = Conv2D(filters, kernels)
|
| 52 |
+
|
| 53 |
+
def call(self, inputs):
|
| 54 |
+
x = K.relu(inputs)
|
| 55 |
+
x = self.conv_1(x)
|
| 56 |
+
x = K.relu(x)
|
| 57 |
+
x = self.conv_2(x)
|
| 58 |
+
return x + inputs
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class QueryNetwork(Model):
|
| 62 |
+
|
| 63 |
+
def __init__(self, units=64):
|
| 64 |
+
super(RecAE, self).__init__()
|
| 65 |
+
self.conv_1 = Dense(units)
|
| 66 |
+
self.conv_2 = Dense(units)
|
| 67 |
+
|
| 68 |
+
def call(self, inputs):
|
| 69 |
+
x = K.relu(inputs)
|
| 70 |
+
x = self.conv_1(x)
|
| 71 |
+
x = K.relu(x)
|
| 72 |
+
x = self.conv_2(x)
|
| 73 |
+
return x + inputs
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class RecAE(Model):
|
| 77 |
+
|
| 78 |
+
def __init__(self, head, bottle, decoder):
|
| 79 |
+
super(RecAE, self).__init__()
|
| 80 |
+
self.head = head
|
| 81 |
+
self.bottle = bottle
|
| 82 |
+
self.base = clone_model(bottle)
|
| 83 |
+
self.decoder = decoder
|
| 84 |
+
self.segmentation_network = SegmentationNetwork()
|
| 85 |
+
self.query_network = QueryNetwork()
|
| 86 |
+
self.control = LSTMCell(64)
|
| 87 |
+
self.memory = LSTMCell(64)
|
| 88 |
+
|
| 89 |
+
def call(self, inputs):
|
| 90 |
+
feature = self.head(inputs)
|
| 91 |
+
segmentation = self.segmentation_network(feature)
|
| 92 |
+
control_base = self.base(feature)
|
| 93 |
+
h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 94 |
+
h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 95 |
+
shape = K.shape(feature)[:-1]
|
| 96 |
+
full_attention = tf.zeros(shape)[..., tf.newaxis]
|
| 97 |
+
full_image = tf.zeros(K.shape(inputs))
|
| 98 |
+
masks = []
|
| 99 |
+
ff = tf.zeros(K.shape(inputs))
|
| 100 |
+
scope = tf.ones(shape)[..., tf.newaxis]
|
| 101 |
+
for i in range(10):
|
| 102 |
+
r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
|
| 103 |
+
query = self.query_network(h_c[0])
|
| 104 |
+
log_attention = image_attention(segmentation, query)
|
| 105 |
+
attention = K.softmax(log_attention)
|
| 106 |
+
mask = attention * scope
|
| 107 |
+
scope = scope - mask
|
| 108 |
+
im = feature * mask
|
| 109 |
+
# im = feature
|
| 110 |
+
latent = self.bottle(im)
|
| 111 |
+
decoded = self.decoder(latent)
|
| 112 |
+
# self.add_loss(K.mean(-mse(full_attention, attention)))
|
| 113 |
+
# self.add_loss(K.mean(-mse(tf.ones(attention.shape), attention)))
|
| 114 |
+
full_attention += attention
|
| 115 |
+
big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
|
| 116 |
+
ff += K.sigmoid(decoded)
|
| 117 |
+
full_image += K.sigmoid(decoded) * big_mask
|
| 118 |
+
r_m, h_m = self.memory(latent, h_m)
|
| 119 |
+
masks.append(big_mask)
|
| 120 |
+
self.add_loss(K.mean(mse(inputs, full_image)))
|
| 121 |
+
return full_image, masks
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# def image_attention(image, query, scale=True):
|
| 125 |
+
@tf.function
|
| 126 |
+
def image_attention(image, query):
|
| 127 |
+
log_attention = K.sum(query[:, tf.newaxis, tf.newaxis, :] * image, axis=-1, keepdims=True)
|
| 128 |
+
# if scale is not None:
|
| 129 |
+
log_attention /= tf.sqrt(tf.cast(K.shape(image)[-1], dtype=float))
|
| 130 |
+
return log_attention
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class RecAE_2(Model):
|
| 134 |
+
|
| 135 |
+
def __init__(self, head, bottle, decoder):
|
| 136 |
+
super(RecAE_2, self).__init__()
|
| 137 |
+
self.head = head
|
| 138 |
+
self.bottle = bottle
|
| 139 |
+
# self.base = clone_model(bottle)
|
| 140 |
+
self.base = self.bottle
|
| 141 |
+
self.decoder = decoder
|
| 142 |
+
self.segmentation_network = create_conv_model((64, 64, 1))
|
| 143 |
+
self.control = LSTMCell(64)
|
| 144 |
+
self.memory = LSTMCell(64)
|
| 145 |
+
|
| 146 |
+
def call(self, inputs):
|
| 147 |
+
feature = self.head(inputs)
|
| 148 |
+
control_base = self.base(feature)
|
| 149 |
+
h_c = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 150 |
+
h_m = [tf.random.normal([K.shape(inputs)[0], self.control.units])] * 2
|
| 151 |
+
shape = K.shape(feature)[:-1]
|
| 152 |
+
full_attention = tf.zeros(shape)[..., tf.newaxis]
|
| 153 |
+
full_image = tf.zeros(K.shape(inputs))
|
| 154 |
+
big_masks = []
|
| 155 |
+
masks = []
|
| 156 |
+
ff = tf.zeros(K.shape(inputs))
|
| 157 |
+
scope = tf.ones(shape)[..., tf.newaxis]
|
| 158 |
+
for i in range(4):
|
| 159 |
+
if i ==3:
|
| 160 |
+
mask = scope
|
| 161 |
+
else:
|
| 162 |
+
r_c, h_c = self.control(tf.concat([control_base, h_m[0]], 1), h_c)
|
| 163 |
+
query = broadcast(h_c[0], feature.shape[1:])
|
| 164 |
+
log_attention = self.segmentation_network(tf.concat([feature, query], axis=-1))
|
| 165 |
+
attention = K.sigmoid(log_attention)
|
| 166 |
+
mask = attention * scope
|
| 167 |
+
scope = scope - mask
|
| 168 |
+
masks.append(mask)
|
| 169 |
+
im = feature * mask
|
| 170 |
+
# im = feature
|
| 171 |
+
latent = self.bottle(im)
|
| 172 |
+
decoded = self.decoder(latent)
|
| 173 |
+
# self.add_loss(K.mean(-mse(scope, mask)))
|
| 174 |
+
sum = K.sum(tf.ones(K.shape(mask)))
|
| 175 |
+
self.add_loss(K.abs((sum/4)-K.sum(mask))/sum)
|
| 176 |
+
# self.add_loss(K.mean(-mse(tf.zeros(K.shape(mask)), mask)))
|
| 177 |
+
for m in masks:
|
| 178 |
+
self.add_loss(K.mean(-mse(m,mask)))
|
| 179 |
+
|
| 180 |
+
full_attention += mask
|
| 181 |
+
big_mask = tf.image.resize(mask, K.shape(inputs)[1:-1])
|
| 182 |
+
ff += K.sigmoid(decoded)
|
| 183 |
+
full_image += K.sigmoid(decoded) * big_mask
|
| 184 |
+
r_m, h_m = self.memory(latent, h_m)
|
| 185 |
+
big_masks.append(big_mask)
|
| 186 |
+
self.add_loss(K.mean(mse(inputs, full_image)))
|
| 187 |
+
return full_image, big_masks
|
raven_utils/models/augment.py
ADDED
|
File without changes
|
raven_utils/models/body.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from ml_utils import self_product, lw
|
| 5 |
+
|
| 6 |
+
from models_utils import DictModel, ListModel, Flat, bm, Base, Cat, Res, Flat2, conv, KERNEL_SIZE, FILTERS, SAME, \
|
| 7 |
+
Get, SM, bs, RELU, ACTIVATION, dense, bd, HardBlock, MaxBlock
|
| 8 |
+
import models_utils.ops as K
|
| 9 |
+
from models_utils import Merge, SoftBlock
|
| 10 |
+
from models_utils.build import build_multi_dense, build_multi_conv, build_conv_model, build_encoder
|
| 11 |
+
from tensorflow.keras.layers import Lambda, Dense
|
| 12 |
+
from tensorflow.keras.layers import Conv2D
|
| 13 |
+
|
| 14 |
+
from config.constant import MEMORY, CONTROL, LATENT, MERGE, CONCAT, INFERENCE, FLATTEN
|
| 15 |
+
from models_utils.config import config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RavRes(Res):
|
| 19 |
+
def __init__(self, model="v2", latent=256, act=RELU):
|
| 20 |
+
super().__init__(model=model)
|
| 21 |
+
self.latent = latent
|
| 22 |
+
|
| 23 |
+
def call(self, inputs):
|
| 24 |
+
return self.model(inputs) + inputs[0][:, ..., self.latent:]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# not working
|
| 28 |
+
class RavResConv(Res):
|
| 29 |
+
def __init__(self, model="v2", latent=256, act=RELU):
|
| 30 |
+
super().__init__(model=model)
|
| 31 |
+
self.latent = latent
|
| 32 |
+
self.conv = conv(latent, (1, 1), activation=act)
|
| 33 |
+
|
| 34 |
+
def call(self, inputs):
|
| 35 |
+
return self.model(inputs) + self.conv(inputs[0])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class RavResDense(Res):
|
| 39 |
+
def __init__(self, model="v2", latent=256, act=config.DEF_DENSE.activation):
|
| 40 |
+
super().__init__(model=model)
|
| 41 |
+
self.latent = latent
|
| 42 |
+
self.conv = dense(latent, activation=act)
|
| 43 |
+
|
| 44 |
+
def call(self, inputs):
|
| 45 |
+
return self.model(inputs) + self.conv(inputs[0])
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def create_dense_block(latent=256, loop=1):
|
| 49 |
+
soft_block = Res(SoftBlock(build_multi_dense(latent), add_identity=None,
|
| 50 |
+
score_activation=tf.sigmoid), latent=latent)
|
| 51 |
+
cells = [
|
| 52 |
+
(lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
|
| 53 |
+
(None, CONCAT, MEMORY),
|
| 54 |
+
(Dense(latent), CONCAT, MERGE),
|
| 55 |
+
(Merge(latent), [INFERENCE, MERGE], CONTROL),
|
| 56 |
+
(soft_block, [MEMORY, CONTROL], MEMORY)
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
return ListModel([DictModel(*cell) for cell in cells] * loop, [LATENT, INFERENCE], MEMORY)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def build_multi_conv(filters=32, end_filters=64, padding="same",mul=1, norm=None, **kwargs):
|
| 63 |
+
base = [(1, 3), (3, 1), (3, 3)]
|
| 64 |
+
block = list(self_product(base))
|
| 65 |
+
block2 = [b + b[0:1] for b in block]
|
| 66 |
+
block3 = [b + b for b in block]
|
| 67 |
+
block4 = ([[(3, 3)]] + [[(3, 3), (3, 3)]] + [[(3, 3), (3, 3), (3, 3)]]) * 2
|
| 68 |
+
block5 = [[], []]
|
| 69 |
+
all_blocks = [s for b in [block, block2, block3, block4, block5] for s in b]
|
| 70 |
+
start = {
|
| 71 |
+
FILTERS: filters,
|
| 72 |
+
KERNEL_SIZE: (1, 1)
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
end = {
|
| 76 |
+
FILTERS: end_filters,
|
| 77 |
+
KERNEL_SIZE: (1, 1),
|
| 78 |
+
ACTIVATION: None
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
all_arch = []
|
| 82 |
+
for ab in all_blocks:
|
| 83 |
+
arch = [{
|
| 84 |
+
FILTERS: filters,
|
| 85 |
+
KERNEL_SIZE: a,
|
| 86 |
+
**kwargs
|
| 87 |
+
} for a in ab]
|
| 88 |
+
all_arch.append([start] + arch + [end])
|
| 89 |
+
|
| 90 |
+
all_arch = all_arch * mul
|
| 91 |
+
|
| 92 |
+
return [
|
| 93 |
+
build_encoder(a, add_norm=norm if norm else None, padding=padding, name=f"b{i}", order=(1, 0) if norm else None)
|
| 94 |
+
for i, a in enumerate(all_arch)]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def create_block(latent=256, simpler=0, loop=1, padding=SAME, norm=None, trans_div=2, act="pass", type_="conv",
|
| 98 |
+
block_=SoftBlock,max_k=16,
|
| 99 |
+
**kwargs):
|
| 100 |
+
trans_size = int(latent / trans_div)
|
| 101 |
+
# if block_ == HardBlock:
|
| 102 |
+
# mul = 2
|
| 103 |
+
# elif block_ == MaxBlock:
|
| 104 |
+
# mul = int(38/max_k)
|
| 105 |
+
# else:
|
| 106 |
+
# mul = 1
|
| 107 |
+
|
| 108 |
+
if act == "pass":
|
| 109 |
+
res_class = RavRes
|
| 110 |
+
else:
|
| 111 |
+
if type_ == "dense":
|
| 112 |
+
res_class = RavResDense
|
| 113 |
+
else:
|
| 114 |
+
res_class = RavResConv
|
| 115 |
+
|
| 116 |
+
if type_ == "dense":
|
| 117 |
+
build_res = lambda: Res(model="dv2")
|
| 118 |
+
# build_reduction = lambda: bm([dense(latent), "IN"])
|
| 119 |
+
build_reduction = lambda: dense(latent)
|
| 120 |
+
build_flatten = lambda: bd([latent] * 2)
|
| 121 |
+
else:
|
| 122 |
+
build_res = lambda: Res(padding=padding)
|
| 123 |
+
build_reduction = lambda: bm([conv(trans_size if simpler else latent, 1, padding=padding), "BN"])
|
| 124 |
+
# build_reduction = lambda: bm([conv(latent, 1, padding=padding), "BN"])
|
| 125 |
+
# build_reduction = lambda: bm([conv(trans_size, 1, padding=padding), "BN"])
|
| 126 |
+
# build_reduction = lambda: conv(trans_size, 1, padding=padding)
|
| 127 |
+
# build_flatten = lambda: Flat2(filters=trans_size,res_no=2, padding=padding, units=64)
|
| 128 |
+
build_flatten = lambda: Flat2(filters=trans_size,padding=padding, units=64)
|
| 129 |
+
|
| 130 |
+
if simpler == 1:
|
| 131 |
+
cells = [
|
| 132 |
+
(lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT,"concatenation"),
|
| 133 |
+
# (None, CONCAT, MEMORY),
|
| 134 |
+
(build_reduction(), CONCAT, MERGE,"Start_resnet_block"),
|
| 135 |
+
# (Get(), INFERENCE, INFERENCE),
|
| 136 |
+
(K.cat, [INFERENCE, MERGE], CONTROL,"concatenation"),
|
| 137 |
+
]
|
| 138 |
+
else:
|
| 139 |
+
cells = [
|
| 140 |
+
(lambda x: K.cat([x[:, 0], x[:, 1]]), LATENT, CONCAT),
|
| 141 |
+
(build_reduction(), CONCAT, MEMORY),
|
| 142 |
+
(build_reduction(), INFERENCE, CONTROL),
|
| 143 |
+
]
|
| 144 |
+
for i, l in enumerate(lw(loop)):
|
| 145 |
+
if l:
|
| 146 |
+
concat = K.cat
|
| 147 |
+
control_reduction = build_reduction()
|
| 148 |
+
control_res = build_res()
|
| 149 |
+
control_flatten = build_flatten()
|
| 150 |
+
if i == 0 and simpler == 1:
|
| 151 |
+
rest_params = {
|
| 152 |
+
"latent": latent,
|
| 153 |
+
"act": act
|
| 154 |
+
}
|
| 155 |
+
else:
|
| 156 |
+
rest_params = {
|
| 157 |
+
"latent": 0
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if block_ == SoftBlock:
|
| 162 |
+
block_params = {
|
| 163 |
+
}
|
| 164 |
+
else:
|
| 165 |
+
block_params = {
|
| 166 |
+
"trans_output_shape": latent
|
| 167 |
+
}
|
| 168 |
+
if block_ == MaxBlock:
|
| 169 |
+
block_params['max_k'] = max_k
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# todo change name
|
| 173 |
+
soft_block = res_class(
|
| 174 |
+
block_(
|
| 175 |
+
build_multi_dense(latent) if type_ == "dense" else build_multi_conv(trans_size, end_filters=latent,
|
| 176 |
+
norm=norm, padding=padding,
|
| 177 |
+
**kwargs),
|
| 178 |
+
add_identity=None,
|
| 179 |
+
score_activation=tf.sigmoid,
|
| 180 |
+
**block_params
|
| 181 |
+
|
| 182 |
+
),
|
| 183 |
+
**rest_params)
|
| 184 |
+
|
| 185 |
+
if i == 0 and simpler == 1:
|
| 186 |
+
cells.extend([
|
| 187 |
+
(control_reduction, CONTROL, CONTROL,"Reduction"),
|
| 188 |
+
(control_res, CONTROL, CONTROL,"Control_resnet_block"),
|
| 189 |
+
(control_flatten, CONTROL, FLATTEN,"Weights"),
|
| 190 |
+
(soft_block, [CONCAT, FLATTEN], MEMORY,"Transformation"),
|
| 191 |
+
# (soft_block, [MEMORY, FLATTEN], MEMORY,"Transformation"),
|
| 192 |
+
])
|
| 193 |
+
else:
|
| 194 |
+
if l:
|
| 195 |
+
memory_res = build_res()
|
| 196 |
+
|
| 197 |
+
cells.extend([
|
| 198 |
+
(memory_res, MEMORY, MEMORY,"Memory_resnet_block"),
|
| 199 |
+
(concat, [CONTROL, MEMORY], CONTROL,"concatenation"),
|
| 200 |
+
(control_reduction, CONTROL, CONTROL,"Reduction"),
|
| 201 |
+
(control_res, CONTROL, CONTROL,"Control_resnet_block"),
|
| 202 |
+
(control_flatten, CONTROL, FLATTEN,"Weights"),
|
| 203 |
+
(soft_block, [MEMORY, FLATTEN], MEMORY, "Transformation"),
|
| 204 |
+
])
|
| 205 |
+
return ListModel([DictModel(*cell) for cell in cells], [LATENT, INFERENCE], MEMORY, debug_=False)
|
| 206 |
+
|
| 207 |
+
#
|
| 208 |
+
#
|
| 209 |
+
# def test(x):
|
| 210 |
+
# np.zeros(4)
|
| 211 |
+
# self_product((1, 3))
|
| 212 |
+
#
|
| 213 |
+
#
|
| 214 |
+
# list(itertools.product())
|
| 215 |
+
# u.layers[0].layers[-1].model.layers[1]
|
| 216 |
+
|
| 217 |
+
# class RecurrentBodyDict(Model):
|
| 218 |
+
# # def __init__(self, start=None, cell=None, output_network=None, output_activation="tanh", latent=64, loop_no=5):
|
| 219 |
+
# def __init__(self, start=None, cell=None, output_network=None, output_activation=None, latent=64, loop_no=5):
|
| 220 |
+
# super().__init__()
|
| 221 |
+
# self.start = sm(start, lambda: SubClassingModel([StartLSTMControl(latent), StartLSTMMemory(latent)]),
|
| 222 |
+
# latent=latent)
|
| 223 |
+
# self.cell = sm(cell, lambda: SubClassingModel([LSTMControl(latent), LSTMMemory(latent)]), latent=latent)
|
| 224 |
+
# self.output_network = sm(output_network, lf(take_memory_states))
|
| 225 |
+
# self.loop_no = loop_no
|
| 226 |
+
# # tmp
|
| 227 |
+
# self.activation = Activation(output_activation)
|
| 228 |
+
#
|
| 229 |
+
# def call(self, inputs):
|
| 230 |
+
# outputs = []
|
| 231 |
+
# for j in range(3):
|
| 232 |
+
# outputs.append(self.start({"latent": inputs[0][j], "inference": inputs[1]}))
|
| 233 |
+
# for i in range(self.loop_no):
|
| 234 |
+
# for j in range(3):
|
| 235 |
+
# outputs[j] = self.cell(outputs[j])
|
| 236 |
+
#
|
| 237 |
+
# return self.activation(self.output_network(outputs))
|
| 238 |
+
#
|
| 239 |
+
#
|
| 240 |
+
# class RecurrentBodySimpleMix4Dict(RecurrentBodyDict):
|
| 241 |
+
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
| 242 |
+
# super().__init__(
|
| 243 |
+
# start=SubClassingModel(
|
| 244 |
+
# [ConcatCell(), DenseCell(latent), InfMergeCell(latent),
|
| 245 |
+
# WeigthCell(latent, layer_no=np.repeat([1, 2, 3, 4, 5, 6, 7, 8], 4),
|
| 246 |
+
# add_identity=Lambda(lambda x: x[:, latent:]))]),
|
| 247 |
+
# cell=False,
|
| 248 |
+
# output_network=output_network, loop_no=0)
|
| 249 |
+
# class RecurrentBodySimpleMix4Conv(RecurrentBodyDict):
|
| 250 |
+
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
| 251 |
+
# super().__init__(
|
| 252 |
+
# start=SubClassingModel(
|
| 253 |
+
# [ConcatCell(), ConvCell(latent), ReduceCell(latent), InfMergeCell(latent),
|
| 254 |
+
# ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
|
| 255 |
+
# WeigthCell(latent,
|
| 256 |
+
# transformation_network=[build_conv_model2([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
|
| 257 |
+
# range(1, 5) for _ in range(1)],
|
| 258 |
+
# add_identity=Lambda(lambda x: x[:, ..., latent:]))
|
| 259 |
+
# ]),
|
| 260 |
+
# cell=False,
|
| 261 |
+
# output_network=output_network, loop_no=0)
|
| 262 |
+
#
|
| 263 |
+
#
|
| 264 |
+
# class RecurrentBodySimpleMix4Conv2(RecurrentBodyDict):
|
| 265 |
+
# def __init__(self, latent=64, output_network=None, loop_no=5):
|
| 266 |
+
# super().__init__(
|
| 267 |
+
# start=SubClassingModel(
|
| 268 |
+
# [ConcatCell(), ConvCell(latent), ReduceCell2(latent), InfMergeCell(latent),
|
| 269 |
+
# ModelCell(latent=latent, layers_no=2, input_name=CONTROL, result_name=CONTROL),
|
| 270 |
+
# WeigthCell(latent,
|
| 271 |
+
# transformation_network=[bc([latent] * i, kernels=(j, j)) for i in range(1, 7) for j in
|
| 272 |
+
# range(1, 5) for _ in range(1)],
|
| 273 |
+
# add_identity=Lambda(lambda x: x[:, ..., latent:]))
|
| 274 |
+
# ]),
|
| 275 |
+
# cell=False,
|
| 276 |
+
# output_network=output_network, loop_no=0)
|
raven_utils/models/class_.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ml_utils import lw
|
| 2 |
+
from models_utils import SubClassingModel, ops as K, Base
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Merge(SubClassingModel):
|
| 7 |
+
def call(self, inputs):
|
| 8 |
+
results = []
|
| 9 |
+
for i, model in enumerate(self.model[:-1]):
|
| 10 |
+
results.append(model(inputs[i]))
|
| 11 |
+
# todo why K.cat not working
|
| 12 |
+
results = self.model[-1](tf.concat(results, axis=-1))
|
| 13 |
+
return results
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RavenClass(Base):
|
| 17 |
+
def __init__(self, model, scales=None, no=3, name=None):
|
| 18 |
+
super().__init__(model=model, name=name)
|
| 19 |
+
self.scales = scales
|
| 20 |
+
self.no = no
|
| 21 |
+
|
| 22 |
+
def call(self, inputs):
|
| 23 |
+
inputs = lw(inputs)
|
| 24 |
+
class_res = []
|
| 25 |
+
# for i in range(inputs[0].shape[1]):
|
| 26 |
+
for i in range(self.no):
|
| 27 |
+
# d = [r[:, i] if r.ndim == 5 else r for r in inputs]
|
| 28 |
+
d = [inputs[s][:, i] if inputs[s].ndim > 2 else inputs for s in self.scales]
|
| 29 |
+
class_res.append(self.model(d))
|
| 30 |
+
# return tf.stack(class_res,axis=1)
|
| 31 |
+
return [class_res]
|
raven_utils/models/head.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from ml_utils import set_default
|
| 3 |
+
from models_utils import build_dense_model, bm, ActivationModel, sm, large_conv_dense_encoder, Pass
|
| 4 |
+
from models_utils import res
|
| 5 |
+
from tensorflow.keras import Model
|
| 6 |
+
from models_utils import ops as K
|
| 7 |
+
from tensorflow.keras.layers import Dense, Conv2D, Flatten
|
| 8 |
+
from keras.backend import batch_flatten
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# todo Refactoring
|
| 12 |
+
class HeadModel(Model):
|
| 13 |
+
def __init__(self, encoder=None, inference_network=None, output_size=64, inference_output_size=None,
|
| 14 |
+
inference_activation="relu", stem=None, images_no=8, inference_image_no=None):
|
| 15 |
+
super().__init__()
|
| 16 |
+
# self.encoder = sm(encoder, bm([en.large_conv_dense_encoder(), Dense(output_size)], False))
|
| 17 |
+
self.encoder = encoder or bm([large_conv_dense_encoder(), Dense(output_size)])
|
| 18 |
+
# self.head = head or HeadBatch(encoder=encoder, output_size=output_size)
|
| 19 |
+
inference_output_size = inference_output_size or output_size
|
| 20 |
+
self.inference_network = inference_network or bm([
|
| 21 |
+
K.flat,
|
| 22 |
+
build_dense_model([1028, 512, 512, inference_output_size],
|
| 23 |
+
last_activation=inference_activation)]
|
| 24 |
+
)
|
| 25 |
+
self.stem = stem or Pass()
|
| 26 |
+
self.images_no = images_no
|
| 27 |
+
self.inference_image_no = self.images_no if inference_image_no is None else inference_image_no
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LatentHeadModel(HeadModel):
|
| 31 |
+
def call(self, inputs):
|
| 32 |
+
result = K.map_batch(inputs[:, :self.images_no], self.encoder)
|
| 33 |
+
inference = self.inference_network(result[:, :self.inference_image_no])
|
| 34 |
+
latents = self.stem(result)
|
| 35 |
+
return [latents, inference,result]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# # todo use map_batch
|
| 39 |
+
# class HeadBatch(Model):
|
| 40 |
+
# def __init__(self, encoder=None, output_size=64):
|
| 41 |
+
# super().__init__()
|
| 42 |
+
# self.encoder = sm(encoder, bm([large_conv_dense_encoder(), Dense(output_size)], False))
|
| 43 |
+
#
|
| 44 |
+
# def call(self, inputs):
|
| 45 |
+
# shape = tf.shape(inputs)
|
| 46 |
+
# latents = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
|
| 47 |
+
# latents = K.reshape(latents, tf.concat([[-1, shape[1]], latents.shape[1:]], axis=-1))
|
| 48 |
+
# return latents
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Not working
|
| 52 |
+
class DuoHeadModel(HeadModel):
|
| 53 |
+
def __init__(self, encoder=None, inference_network=None, images_no=8, filters=-4):
|
| 54 |
+
super().__init__(encoder=encoder, inference_network=inference_network, images_no=images_no)
|
| 55 |
+
self.encoder = ActivationModel(self.encoder, filters=filters, include_input=False)
|
| 56 |
+
|
| 57 |
+
def call(self, inputs):
|
| 58 |
+
shape = inputs.shape
|
| 59 |
+
result = reversed(self.encoder(K.reshape(inputs, shape=[-1] + list(shape[2:]))))
|
| 60 |
+
latents = K.reshape(result[0], [-1, self.images_no] + [result[0].shape[-1]])
|
| 61 |
+
inference = self.inference_network(K.flat(result[1]))
|
| 62 |
+
return [latents, inference]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MultiHeadModel(Model):
|
| 66 |
+
def __init__(self, encoder=None, images_no=8, filters=(1, 3, 6)):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
|
| 69 |
+
self.merge = MergeSacles()
|
| 70 |
+
self.images_no = images_no
|
| 71 |
+
|
| 72 |
+
def call(self, inputs):
|
| 73 |
+
shape = tf.shape(inputs)
|
| 74 |
+
results = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
|
| 75 |
+
latents = [tf.reshape(result, shape=tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)) for result
|
| 76 |
+
in results]
|
| 77 |
+
|
| 78 |
+
l1 = tf.transpose(latents[0], (0, 2, 3, 1, 4))
|
| 79 |
+
# l1 = tf.reshape(l1, tuple(list(l1.shape[:3]) + [l1.shape[-2] * l1.shape[-1]]))
|
| 80 |
+
shape = tf.shape(l1)
|
| 81 |
+
l1 = tf.reshape(l1, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
|
| 82 |
+
|
| 83 |
+
l2 = tf.transpose(latents[1], (0, 2, 3, 1, 4))
|
| 84 |
+
# l2 = tf.reshape(l2, [-1] + list(l2.shape[1:3]) + [l2.shape[-2] * l2.shape[-1]])
|
| 85 |
+
shape = tf.shape(l2)
|
| 86 |
+
l2 = tf.reshape(l2, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
|
| 87 |
+
|
| 88 |
+
l3 = latents[2]
|
| 89 |
+
shape = tf.shape(l3)
|
| 90 |
+
# l3 = tf.reshape(l3, [-1] + [l3.shape[-2] * l3.shape[-1]])
|
| 91 |
+
l3 = tf.reshape(l3, tf.concat([[-1], [shape[-2] * shape[-1]]], axis=-1))
|
| 92 |
+
|
| 93 |
+
inference = self.merge([l1, l2, l3])
|
| 94 |
+
return [latents, inference]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class MergeSacles(Model):
|
| 98 |
+
def __init__(self):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.inf_1 = bm([Conv2D(64, 1, activation="relu"), res(64),
|
| 101 |
+
Conv2D(64, 3, strides=2, padding=SAME, activation="relu"),
|
| 102 |
+
res(64),
|
| 103 |
+
Flatten(),
|
| 104 |
+
Dense(256, "relu")])
|
| 105 |
+
self.inf_2 = bm([Conv2D(128, 1, activation="relu"),
|
| 106 |
+
res(128),
|
| 107 |
+
Flatten(),
|
| 108 |
+
Dense(256, "relu")])
|
| 109 |
+
self.inf_3 = Dense(256, "relu")
|
| 110 |
+
|
| 111 |
+
def call(self, inputs):
|
| 112 |
+
il1 = self.inf_1(inputs[0])
|
| 113 |
+
il2 = self.inf_2(inputs[1])
|
| 114 |
+
il3 = self.inf_3(inputs[2])
|
| 115 |
+
inference = tf.concat([il1, il2, il3], axis=1)
|
| 116 |
+
return inference
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class MultiHeadModel2(Model):
|
| 120 |
+
def __init__(self, encoder=None, images_no=8, filters=(3, 6)):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.encoder = ActivationModel(encoder, filters=filters, include_input=False)
|
| 123 |
+
self.merge = MergeSacles2()
|
| 124 |
+
self.images_no = images_no
|
| 125 |
+
|
| 126 |
+
def call(self, inputs):
|
| 127 |
+
shape = tf.shape(inputs)
|
| 128 |
+
results = self.encoder(tf.reshape(inputs, shape=tf.concat([[-1], shape[2:]], axis=-1)))
|
| 129 |
+
latents = [tf.reshape(result, shape=tf.concat([[-1, self.images_no], tf.shape(result)[1:]], axis=-1)) for result
|
| 130 |
+
in results]
|
| 131 |
+
|
| 132 |
+
l2 = tf.transpose(latents[0], (0, 2, 3, 1, 4))
|
| 133 |
+
# l2 = tf.reshape(l2, [-1] + list(l2.shape[1:3]) + [l2.shape[-2] * l2.shape[-1]])
|
| 134 |
+
shape = tf.shape(l2)
|
| 135 |
+
l2 = tf.reshape(l2, tf.concat([[-1], shape[1:3], [shape[-2] * shape[-1]]], axis=-1))
|
| 136 |
+
|
| 137 |
+
l3 = latents[1]
|
| 138 |
+
shape = tf.shape(l3)
|
| 139 |
+
# l3 = tf.reshape(l3, [-1] + [l3.shape[-2] * l3.shape[-1]])
|
| 140 |
+
l3 = tf.reshape(l3, tf.concat([[-1], [shape[-2] * shape[-1]]], axis=-1))
|
| 141 |
+
|
| 142 |
+
inference = self.merge([l2, l3])
|
| 143 |
+
return [latents, inference]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class MergeSacles2(Model):
|
| 147 |
+
def __init__(self):
|
| 148 |
+
super().__init__()
|
| 149 |
+
self.inf_1 = bm([Conv2D(128, 1, activation="relu"),
|
| 150 |
+
res(128),
|
| 151 |
+
Flatten(),
|
| 152 |
+
Dense(256, "relu")])
|
| 153 |
+
self.inf_2 = Dense(256, "relu")
|
| 154 |
+
|
| 155 |
+
def call(self, inputs):
|
| 156 |
+
il1 = self.inf_1(inputs[0])
|
| 157 |
+
il2 = self.inf_2(inputs[1])
|
| 158 |
+
inference = tf.concat([il1, il2], axis=1)
|
| 159 |
+
return inference
|
raven_utils/models/loss.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import tensorflow.experimental.numpy as tnp
|
| 5 |
+
from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
|
| 6 |
+
from models_utils import SubClassingModel
|
| 7 |
+
from models_utils.models.utils import interleave
|
| 8 |
+
from models_utils.op import reshape
|
| 9 |
+
from tensorflow.keras import Model
|
| 10 |
+
# from tensorflow.keras import backend as K
|
| 11 |
+
from tensorflow.keras.layers import Lambda
|
| 12 |
+
from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
|
| 13 |
+
from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
|
| 14 |
+
import models_utils.ops as K
|
| 15 |
+
|
| 16 |
+
import raven_utils.decode
|
| 17 |
+
import raven_utils as rv
|
| 18 |
+
from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
|
| 19 |
+
SLOT, \
|
| 20 |
+
PROPERTIES, ACC, GROUP, NUMBER, MASK
|
| 21 |
+
from raven_utils.models.uitls_ import RangeMask
|
| 22 |
+
from raven_utils.const import VERTICAL, HORIZONTAL
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_properties_mask(target):
|
| 26 |
+
return target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX] > 0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_change_mask(target):
|
| 30 |
+
properties_mask = get_properties_mask(target)
|
| 31 |
+
return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_uniform_mask(target):
|
| 35 |
+
u_mask = lambda i: tf.tile(target[:, rv.target.UNIFORMITY_INDEX + i, None] == 3, [1, rv.rules.ATTRIBUTES_LEN])
|
| 36 |
+
properties_mask = tf.concat([u_mask(0), u_mask(1)], axis=-1) | get_properties_mask(target)
|
| 37 |
+
return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def create_all_mask(target):
|
| 41 |
+
return [
|
| 42 |
+
tf.cast(tf.ones(tf.stack([tf.shape(target)[0], rv.entity.SUM])), dtype=tf.bool) for i, _ in
|
| 43 |
+
enumerate(rv.rules.ATTRIBUTES)]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class BaselineClassificationLossModel(Model):
|
| 47 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
|
| 50 |
+
self.loss_fn = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 51 |
+
group_loss=group_loss)
|
| 52 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 53 |
+
|
| 54 |
+
def call(self, inputs):
|
| 55 |
+
losses = []
|
| 56 |
+
output = inputs[1]
|
| 57 |
+
losses.append(self.loss_fn([inputs[0][0], output]))
|
| 58 |
+
losses.append(self.metric_fn([inputs[0][2], inputs[3][0], inputs[0][1][:, 8:]]))
|
| 59 |
+
return losses
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class RavenLoss(Model):
|
| 63 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
|
| 64 |
+
classification=False, trans=True, anneal=False):
|
| 65 |
+
super().__init__()
|
| 66 |
+
if anneal:
|
| 67 |
+
self.weight_scheduler
|
| 68 |
+
self.classification = classification
|
| 69 |
+
self.trans = trans
|
| 70 |
+
self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
|
| 71 |
+
out=[PREDICT, MASK], name="pred")
|
| 72 |
+
if self.trans:
|
| 73 |
+
self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 74 |
+
group_loss=group_loss, enable_metrics=False, lw=lw[0]),
|
| 75 |
+
name="main_loss")
|
| 76 |
+
self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 77 |
+
group_loss=group_loss), name="add_loss")
|
| 78 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 79 |
+
if self.classification:
|
| 80 |
+
self.loss_fn_3 = add_loss(
|
| 81 |
+
ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
|
| 82 |
+
group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
|
| 83 |
+
name="class_loss")
|
| 84 |
+
|
| 85 |
+
def call(self, inputs):
|
| 86 |
+
losses = []
|
| 87 |
+
output = inputs[OUTPUT]
|
| 88 |
+
target = inputs[TARGET]
|
| 89 |
+
labels = inputs[LABELS]
|
| 90 |
+
|
| 91 |
+
if self.trans:
|
| 92 |
+
losses.append(self.loss_fn([labels[:, 2], output[0]]))
|
| 93 |
+
losses.append(self.loss_fn([labels[:, 5], output[1]]))
|
| 94 |
+
losses.append(self.loss_fn_2([target, output[2]]))
|
| 95 |
+
losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
|
| 96 |
+
if self.classification:
|
| 97 |
+
for i in range(8):
|
| 98 |
+
losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
|
| 99 |
+
return {**inputs, LOSS: losses}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class VTRavenLoss(Model):
|
| 103 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
|
| 104 |
+
classification=False, trans=True, anneal=False, plw=None):
|
| 105 |
+
super().__init__()
|
| 106 |
+
if anneal:
|
| 107 |
+
self.weight_scheduler
|
| 108 |
+
self.classification = classification
|
| 109 |
+
self.trans = trans
|
| 110 |
+
self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
|
| 111 |
+
out=[PREDICT, MASK], name="pred")
|
| 112 |
+
self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 113 |
+
group_loss=group_loss, plw=plw), lw=lw[0] , name="add_loss")
|
| 114 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 115 |
+
if self.classification:
|
| 116 |
+
self.loss_fn_2 = add_loss(
|
| 117 |
+
ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
|
| 118 |
+
group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")
|
| 119 |
+
|
| 120 |
+
def call(self, inputs):
|
| 121 |
+
losses = []
|
| 122 |
+
output = inputs[OUTPUT]
|
| 123 |
+
target = inputs[TARGET]
|
| 124 |
+
labels = inputs[LABELS]
|
| 125 |
+
|
| 126 |
+
for i in range(9):
|
| 127 |
+
losses.append(self.loss_fn_2([labels[:, i], output[:, i]]))
|
| 128 |
+
losses.append(self.loss_fn([target, output[:, 8]]))
|
| 129 |
+
losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
|
| 130 |
+
return {**inputs, LOSS: losses}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class SingleVTRavenLoss(Model):
|
| 134 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
|
| 135 |
+
classification=False, trans=True, anneal=False):
|
| 136 |
+
super().__init__()
|
| 137 |
+
if anneal:
|
| 138 |
+
self.weight_scheduler
|
| 139 |
+
self.classification = classification
|
| 140 |
+
self.trans = trans
|
| 141 |
+
self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
|
| 142 |
+
self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 143 |
+
group_loss=group_loss), lw=lw[0], name="add_loss")
|
| 144 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 145 |
+
|
| 146 |
+
def call(self, inputs):
|
| 147 |
+
losses = []
|
| 148 |
+
output = inputs[OUTPUT]
|
| 149 |
+
target = inputs[TARGET]
|
| 150 |
+
labels = inputs[LABELS]
|
| 151 |
+
|
| 152 |
+
losses.append(self.loss_fn([target, output]))
|
| 153 |
+
losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
|
| 154 |
+
return {**inputs, LOSS: losses}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class ClassRavenModel(Model):
|
| 158 |
+
def __init__(self, mode=create_all_mask,plw=None, number_loss=False, slot_loss=True, group_loss=True, enable_metrics=True,
|
| 159 |
+
lw=1.0):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.number_loss = number_loss
|
| 162 |
+
self.group_loss = group_loss
|
| 163 |
+
self.enable_metrics = enable_metrics
|
| 164 |
+
self.slot_loss = slot_loss
|
| 165 |
+
self.predict_fn = PredictModel()
|
| 166 |
+
self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
|
| 167 |
+
if self.slot_loss:
|
| 168 |
+
self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
|
| 169 |
+
if self.enable_metrics:
|
| 170 |
+
self.enable_metrics = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
|
| 171 |
+
self.metric_fn = [
|
| 172 |
+
SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}") for property_ in
|
| 173 |
+
rv.properties.NAMES]
|
| 174 |
+
if self.group_loss:
|
| 175 |
+
self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
|
| 176 |
+
if self.slot_loss:
|
| 177 |
+
self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
|
| 178 |
+
self.range_mask = RangeMask()
|
| 179 |
+
self.mode = mode
|
| 180 |
+
self.lw = lw
|
| 181 |
+
if not plw:
|
| 182 |
+
plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
|
| 183 |
+
elif isinstance(plw, int) or isinstance(plw, float):
|
| 184 |
+
plw = [1., plw, 2.83426987, 0.85212836, 1.096005, 1.21943385]
|
| 185 |
+
# plw = [plw] * 6
|
| 186 |
+
self.plw = plw
|
| 187 |
+
|
| 188 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 189 |
+
|
| 190 |
+
def call(self, inputs):
|
| 191 |
+
losses = []
|
| 192 |
+
metrics = {}
|
| 193 |
+
target = inputs[0]
|
| 194 |
+
output = inputs[1]
|
| 195 |
+
|
| 196 |
+
target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
|
| 197 |
+
|
| 198 |
+
group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)
|
| 199 |
+
|
| 200 |
+
# group
|
| 201 |
+
if self.group_loss:
|
| 202 |
+
group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
|
| 203 |
+
losses.append(group_loss)
|
| 204 |
+
|
| 205 |
+
if isinstance(self.enable_metrics, str):
|
| 206 |
+
group_metric = self.metric_fn_group(target_group, group_output)
|
| 207 |
+
# metrics[GROUP] = group_metric
|
| 208 |
+
self.add_metric(group_metric)
|
| 209 |
+
self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")
|
| 210 |
+
|
| 211 |
+
# setting uniformity mask
|
| 212 |
+
full_properties_musks = self.mode(target)
|
| 213 |
+
|
| 214 |
+
range_mask = self.range_mask(target_group)
|
| 215 |
+
|
| 216 |
+
if self.slot_loss:
|
| 217 |
+
# number
|
| 218 |
+
number_mask = range_mask & full_properties_musks[0]
|
| 219 |
+
number_mask = tf.cast(number_mask, tf.float32)
|
| 220 |
+
target_number = tf.reduce_sum(
|
| 221 |
+
tf.cast(target_slot, "float32") * number_mask, axis=-1)
|
| 222 |
+
output_number = tf.reduce_sum(
|
| 223 |
+
tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)
|
| 224 |
+
|
| 225 |
+
# output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
|
| 226 |
+
if self.number_loss:
|
| 227 |
+
scale = 1 / 9
|
| 228 |
+
if self.number_loss == 2:
|
| 229 |
+
output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
|
| 230 |
+
else:
|
| 231 |
+
output_number_2 = output_number
|
| 232 |
+
number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale, output_number_2 * scale)
|
| 233 |
+
losses.append(number_loss)
|
| 234 |
+
|
| 235 |
+
# metrics[NUMBER] = number_acc
|
| 236 |
+
|
| 237 |
+
if isinstance(self.enable_metrics, str):
|
| 238 |
+
number_acc = tf.reduce_mean(
|
| 239 |
+
tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
|
| 240 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
|
| 241 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
|
| 242 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 243 |
+
|
| 244 |
+
# position/slot
|
| 245 |
+
slot_mask = range_mask & full_properties_musks[1]
|
| 246 |
+
# tf.boolean_mask(target_slot,slot_mask)
|
| 247 |
+
|
| 248 |
+
if tf.reduce_any(slot_mask):
|
| 249 |
+
# if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
|
| 250 |
+
target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
|
| 251 |
+
output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
|
| 252 |
+
loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
|
| 253 |
+
self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
|
| 254 |
+
if isinstance(self.enable_metrics, str):
|
| 255 |
+
acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
|
| 256 |
+
self.add_metric(acc_slot)
|
| 257 |
+
self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
|
| 258 |
+
self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 259 |
+
else:
|
| 260 |
+
loss_slot = 0.0
|
| 261 |
+
acc_slot = -1.0
|
| 262 |
+
|
| 263 |
+
losses.append(loss_slot)
|
| 264 |
+
# metrics[SLOT] = acc_slot
|
| 265 |
+
# if loss_slot != 0:
|
| 266 |
+
|
| 267 |
+
# if tf.reduce_any(slot_mask):
|
| 268 |
+
|
| 269 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
|
| 270 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
|
| 271 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 272 |
+
|
| 273 |
+
# properties
|
| 274 |
+
for i, out in enumerate(outputs):
|
| 275 |
+
shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
|
| 276 |
+
out_reshaped = tf.reshape(out, shape)
|
| 277 |
+
properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]
|
| 278 |
+
|
| 279 |
+
if tf.reduce_any(properties_mask):
|
| 280 |
+
out_masked = tf.boolean_mask(out_reshaped, properties_mask)
|
| 281 |
+
out_target = tf.boolean_mask(target_all[i], properties_mask)
|
| 282 |
+
loss = self.lw * self.plw[3+i] * self.loss_fn(out_target, out_masked)
|
| 283 |
+
if isinstance(self.enable_metrics, str):
|
| 284 |
+
metric = self.metric_fn[i](out_target, out_masked)
|
| 285 |
+
self.add_metric(metric)
|
| 286 |
+
# self.add_metric(metric, f"{self.enable_metrics}{ACC}")
|
| 287 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
|
| 288 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
|
| 289 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 290 |
+
else:
|
| 291 |
+
loss = 0.0
|
| 292 |
+
metric = -1.0
|
| 293 |
+
|
| 294 |
+
losses.append(loss)
|
| 295 |
+
return losses
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class FullMask(Model):
|
| 299 |
+
def __init__(self, mode=create_uniform_mask):
|
| 300 |
+
super().__init__()
|
| 301 |
+
self.range_mask = RangeMask()
|
| 302 |
+
self.mode = mode
|
| 303 |
+
|
| 304 |
+
def call(self, inputs):
|
| 305 |
+
target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
|
| 306 |
+
full_properties_musks = self.mode(inputs)
|
| 307 |
+
range_mask = self.range_mask(target_group)
|
| 308 |
+
|
| 309 |
+
number_mask = range_mask & full_properties_musks[0]
|
| 310 |
+
|
| 311 |
+
slot_mask = range_mask & full_properties_musks[1]
|
| 312 |
+
properties_mask = []
|
| 313 |
+
for property_mask in full_properties_musks[2:]:
|
| 314 |
+
properties_mask.append(tf.cast(target_slot, "bool") & property_mask)
|
| 315 |
+
return [slot_mask, properties_mask, number_mask]
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def create_mask(rules, i):
|
| 319 |
+
mask_1 = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
|
| 320 |
+
mask_2 = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])
|
| 321 |
+
shape = tf.shape(rules)
|
| 322 |
+
full_mask_1 = tf.scatter_nd(tnp.array(rv.target.FIRST_LAYOUT)[:, None], mask_1, shape=(rv.entity.SUM, shape[0]))
|
| 323 |
+
full_mask_2 = tf.tensor_scatter_nd_update(full_mask_1, tnp.array(rv.target.SECOND_LAYOUT)[:, None], mask_2)
|
| 324 |
+
return tf.transpose(full_mask_2)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# class PredictModel(Model):
|
| 328 |
+
# def __init__(self):
|
| 329 |
+
# super().__init__()
|
| 330 |
+
# self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
| 331 |
+
# self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
| 332 |
+
# self.range_mask = RangeMask()
|
| 333 |
+
#
|
| 334 |
+
# # self.predict_fn = partial(tf.argmax, axis=-1)
|
| 335 |
+
#
|
| 336 |
+
# def call(self, inputs):
|
| 337 |
+
# group_output = inputs[rv.OUTPUT_GROUP_SLICE]
|
| 338 |
+
# group_loss = self.predict_fn(group_output)[:, None]
|
| 339 |
+
#
|
| 340 |
+
# output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
|
| 341 |
+
# range_mask = self.range_mask(group_loss[:, 0])
|
| 342 |
+
# loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
|
| 343 |
+
#
|
| 344 |
+
# properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
|
| 345 |
+
# properties = []
|
| 346 |
+
# outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
| 347 |
+
# for i, out in enumerate(outputs):
|
| 348 |
+
# shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
| 349 |
+
# out_reshaped = tf.reshape(out, shape)
|
| 350 |
+
# properties.append(self.predict_fn(out_reshaped))
|
| 351 |
+
# number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
| 352 |
+
#
|
| 353 |
+
# result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
| 354 |
+
#
|
| 355 |
+
# return [result, range_mask, range_mask, range_mask, range_mask]
|
| 356 |
+
|
| 357 |
+
class PredictModel(Model):
|
| 358 |
+
def __init__(self):
|
| 359 |
+
super().__init__()
|
| 360 |
+
self.predict_fn = Predict()
|
| 361 |
+
self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
| 362 |
+
self.range_mask = RangeMask()
|
| 363 |
+
|
| 364 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 365 |
+
|
| 366 |
+
def call(self, inputs):
|
| 367 |
+
group_output, output_slot, *properties = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
|
| 368 |
+
number_loss = K.int64(K.sum(output_slot))
|
| 369 |
+
result = tf.concat(
|
| 370 |
+
[group_output[:, None], tf.cast(output_slot, dtype=tf.int64), interleave(properties), number_loss[:, None]],
|
| 371 |
+
axis=-1)
|
| 372 |
+
|
| 373 |
+
range_mask = self.range_mask(group_output)
|
| 374 |
+
return [result, range_mask]
|
| 375 |
+
# return [result, range_mask, range_mask, range_mask, range_mask]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# todo change slices
|
| 379 |
+
class PredictModelMasked(Model):
|
| 380 |
+
def __init__(self):
|
| 381 |
+
super().__init__()
|
| 382 |
+
self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
| 383 |
+
self.loss_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
| 384 |
+
self.range_mask = RangeMask()
|
| 385 |
+
|
| 386 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 387 |
+
|
| 388 |
+
def call(self, inputs):
|
| 389 |
+
group_output = inputs[:, -rv.GROUPS_NO:]
|
| 390 |
+
group_loss = self.predict_fn(group_output)[:, None]
|
| 391 |
+
|
| 392 |
+
output_slot = inputs[:, :rv.ENTITY_SUM]
|
| 393 |
+
range_mask = self.range_mask(group_loss[:, 0])
|
| 394 |
+
loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)
|
| 395 |
+
|
| 396 |
+
properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]
|
| 397 |
+
|
| 398 |
+
properties = []
|
| 399 |
+
outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
| 400 |
+
for i, out in enumerate(outputs):
|
| 401 |
+
shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
| 402 |
+
out_reshaped = tf.reshape(out, shape)
|
| 403 |
+
out_masked = out_reshaped * loss_slot[..., None]
|
| 404 |
+
properties.append(self.predict_fn(out_masked))
|
| 405 |
+
# out_masked[0].numpy()
|
| 406 |
+
number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
| 407 |
+
|
| 408 |
+
result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
| 409 |
+
|
| 410 |
+
return result
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def final_predict_mask(x, mask):
|
| 414 |
+
r = reshape(x[0][:, rv.INDEX[0]:-1], [-1, 3])
|
| 415 |
+
return tf.ragged.boolean_mask(r, mask)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def final_predict(x, mode=False):
|
| 419 |
+
m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
|
| 420 |
+
return final_predict_mask(x[0], m)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def final_predict_2(x):
|
| 424 |
+
ones = tf.cast(tf.ones(tf.shape(x[0])[0]), tf.bool)[:, None]
|
| 425 |
+
mask = tf.concat([ones, tf.tile(x[1], [1, 4]), ones], axis=-1)
|
| 426 |
+
return tf.ragged.boolean_mask(x[0], mask)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class PredictModelOld(Model):
|
| 430 |
+
|
| 431 |
+
def call(self, inputs):
|
| 432 |
+
output = inputs[-2]
|
| 433 |
+
|
| 434 |
+
rest_output = output[:, :-rv.GROUPS_NO]
|
| 435 |
+
|
| 436 |
+
result_all = []
|
| 437 |
+
outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-3)
|
| 438 |
+
for i, out in enumerate(outputs):
|
| 439 |
+
shape = (-3, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
| 440 |
+
out_reshaped = tf.reshape(out, shape)
|
| 441 |
+
|
| 442 |
+
result = tf.cast(tf.argmax(out_reshaped, axis=-3), dtype="int8")
|
| 443 |
+
result_all.append(result)
|
| 444 |
+
|
| 445 |
+
result_all = interleave(result_all)
|
| 446 |
+
return result_all
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def get_matches(diff, target_index):
|
| 450 |
+
diff_sum = K.sum(diff)
|
| 451 |
+
db_argsort = tf.argsort(diff_sum, axis=-1)
|
| 452 |
+
db_sorted = tf.sort(diff_sum)
|
| 453 |
+
db_mask = db_sorted[:, 0, None] == db_sorted
|
| 454 |
+
db_same = tf.where(db_mask, db_argsort, -1 * tf.ones_like(db_argsort))
|
| 455 |
+
matched_index = db_same == target_index
|
| 456 |
+
# setting shape needed for TensorFlow graph
|
| 457 |
+
matched_index.set_shape(db_same.shape)
|
| 458 |
+
matches = K.any(matched_index)
|
| 459 |
+
more_matches = K.sum(db_mask) > 1
|
| 460 |
+
once_matches = K.sum(matches & tf.math.logical_not(more_matches))
|
| 461 |
+
return matches, more_matches, once_matches
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class SimilarityRaven(Model):
|
| 465 |
+
def __init__(self, mode=create_all_mask, number_loss=False):
|
| 466 |
+
super().__init__()
|
| 467 |
+
self.range_mask = RangeMask()
|
| 468 |
+
self.mode = mode
|
| 469 |
+
|
| 470 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 471 |
+
|
| 472 |
+
# INDEX, PREDICT, LABELS
|
| 473 |
+
def call(self, inputs):
|
| 474 |
+
metrics = []
|
| 475 |
+
target_index = inputs[0] - 8
|
| 476 |
+
predict = inputs[1]
|
| 477 |
+
answers = inputs[2][:, 8:]
|
| 478 |
+
shape = tf.shape(predict)
|
| 479 |
+
|
| 480 |
+
target = K.gather(answers, target_index[:, 0])
|
| 481 |
+
|
| 482 |
+
target_group = target[:, 0]
|
| 483 |
+
|
| 484 |
+
# comp_slice = np.
|
| 485 |
+
target_comp = target[:, 1:rv.target.END_INDEX]
|
| 486 |
+
predict_comp = predict[:, 1:rv.target.END_INDEX]
|
| 487 |
+
answers_comp = answers[:, :, 1:rv.target.END_INDEX]
|
| 488 |
+
|
| 489 |
+
full_properties_musks = self.mode(target)
|
| 490 |
+
fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
|
| 491 |
+
|
| 492 |
+
range_mask = self.range_mask(target_group)
|
| 493 |
+
full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
|
| 494 |
+
|
| 495 |
+
final_mask = fpm & full_range_mask
|
| 496 |
+
|
| 497 |
+
target_masked = target_comp * final_mask
|
| 498 |
+
predict_masked = predict_comp * final_mask
|
| 499 |
+
answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
|
| 500 |
+
|
| 501 |
+
acc_same = K.mean(K.all(target_masked == predict_masked))
|
| 502 |
+
self.add_metric(acc_same, ACC_SAME)
|
| 503 |
+
metrics.append(acc_same)
|
| 504 |
+
|
| 505 |
+
diff = tf.abs(predict_masked[:, None] - answers_masked)
|
| 506 |
+
diff_bool = diff != 0
|
| 507 |
+
|
| 508 |
+
matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
|
| 509 |
+
|
| 510 |
+
second_phase_mask = (more_matches & matches)
|
| 511 |
+
diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
|
| 512 |
+
target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
|
| 513 |
+
|
| 514 |
+
matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
|
| 515 |
+
matches_2_no = K.sum(matches_2)
|
| 516 |
+
|
| 517 |
+
acc_choose_upper = (once_matches + matches_2_no) / shape[0]
|
| 518 |
+
self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
|
| 519 |
+
metrics.append(acc_choose_upper)
|
| 520 |
+
|
| 521 |
+
acc_choose_lower = (once_matches + once_matches_2) / shape[0]
|
| 522 |
+
self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
|
| 523 |
+
metrics.append(acc_choose_lower)
|
| 524 |
+
|
| 525 |
+
return metrics
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class SimilarityRaven2(Model):
|
| 529 |
+
def __init__(self, mode=create_all_mask, number_loss=False):
|
| 530 |
+
super().__init__()
|
| 531 |
+
self.range_mask = RangeMask()
|
| 532 |
+
self.mode = mode
|
| 533 |
+
|
| 534 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 535 |
+
|
| 536 |
+
# INDEX, PREDICT, LABELS
|
| 537 |
+
def call(self, inputs):
|
| 538 |
+
metrics = []
|
| 539 |
+
target_index = inputs[0] - 8
|
| 540 |
+
predict = inputs[1]
|
| 541 |
+
answers = inputs[2][:, 8:]
|
| 542 |
+
shape = tf.shape(predict)
|
| 543 |
+
|
| 544 |
+
target = K.gather(answers, target_index[:, 0])
|
| 545 |
+
|
| 546 |
+
target_group = target[:, 0]
|
| 547 |
+
|
| 548 |
+
# comp_slice = np.
|
| 549 |
+
target_comp = target[:, 1:rv.target.END_INDEX]
|
| 550 |
+
predict_comp = predict[:, 1:rv.target.END_INDEX]
|
| 551 |
+
answers_comp = answers[:, :, 1:rv.target.END_INDEX]
|
| 552 |
+
|
| 553 |
+
full_properties_musks = self.mode(target)
|
| 554 |
+
fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
|
| 555 |
+
|
| 556 |
+
range_mask = self.range_mask(target_group)
|
| 557 |
+
full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
|
| 558 |
+
|
| 559 |
+
final_mask = fpm & full_range_mask
|
| 560 |
+
|
| 561 |
+
target_masked = target_comp * final_mask
|
| 562 |
+
predict_masked = predict_comp * final_mask
|
| 563 |
+
answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
|
| 564 |
+
|
| 565 |
+
acc_same = K.mean(K.all(target_masked == predict_masked))
|
| 566 |
+
self.add_metric(acc_same, ACC_SAME)
|
| 567 |
+
metrics.append(acc_same)
|
| 568 |
+
|
| 569 |
+
diff = tf.abs(predict_masked[:, None] - answers_masked)
|
| 570 |
+
diff_bool = diff != 0
|
| 571 |
+
|
| 572 |
+
matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
|
| 573 |
+
|
| 574 |
+
second_phase_mask = (more_matches & matches)
|
| 575 |
+
diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
|
| 576 |
+
target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
|
| 577 |
+
|
| 578 |
+
matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
|
| 579 |
+
matches_2_no = K.sum(matches_2)
|
| 580 |
+
|
| 581 |
+
acc_choose_upper = (once_matches + matches_2_no) / shape[0]
|
| 582 |
+
self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
|
| 583 |
+
metrics.append(acc_choose_upper)
|
| 584 |
+
|
| 585 |
+
acc_choose_lower = (once_matches + once_matches_2) / shape[0]
|
| 586 |
+
self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
|
| 587 |
+
metrics.append(acc_choose_lower)
|
| 588 |
+
|
| 589 |
+
metrics.append(K.sum(target_masked != predict_masked))
|
| 590 |
+
|
| 591 |
+
return metrics
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
class LatentLossModel(Model):
|
| 595 |
+
def __init__(self, dir_=HORIZONTAL):
|
| 596 |
+
super().__init__()
|
| 597 |
+
# self.sum_metrics = []
|
| 598 |
+
# for i in range(8):
|
| 599 |
+
# self.sum_metrics.append(Sum(name=f"no_{i}"))
|
| 600 |
+
self.metric_fn = Accuracy(name="acc_latent")
|
| 601 |
+
if dir_ == VERTICAL:
|
| 602 |
+
self.dir = (6, 7)
|
| 603 |
+
else:
|
| 604 |
+
self.dir = (2, 5)
|
| 605 |
+
|
| 606 |
+
def call(self, inputs):
|
| 607 |
+
target_image = tf.reshape(inputs[0][2], [-1])
|
| 608 |
+
output = inputs[1]
|
| 609 |
+
latents = tnp.asarray(inputs[2])
|
| 610 |
+
|
| 611 |
+
target_hor = tf.concat([
|
| 612 |
+
latents[:, self.dir],
|
| 613 |
+
latents[tf.range(latents.shape[0]), target_image + 8][:, None]
|
| 614 |
+
],
|
| 615 |
+
axis=1)
|
| 616 |
+
|
| 617 |
+
loss_hor = mse(K.stop_gradient(target_hor), output)
|
| 618 |
+
self.add_loss(loss_hor)
|
| 619 |
+
|
| 620 |
+
self.add_metric(self.metric_fn(inputs[3], target_image))
|
| 621 |
+
|
| 622 |
+
return loss_hor
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class PredRav(Model):
|
| 626 |
+
|
| 627 |
+
def call(self, inputs):
|
| 628 |
+
output = inputs[0][:, -1]
|
| 629 |
+
answers = inputs[1][:, 8:]
|
| 630 |
+
return tf.argmin(tf.reduce_sum(tf.abs(output[:, None] - answers), axis=-1), axis=-1)
|
raven_utils/models/loss_3.py
ADDED
|
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import tensorflow.experimental.numpy as tnp
|
| 5 |
+
from models_utils import OUTPUT, TARGET, PREDICT, DictModel, add_loss, LOSS, Predict
|
| 6 |
+
from models_utils import SubClassingModel
|
| 7 |
+
from models_utils.models.utils import interleave
|
| 8 |
+
from models_utils.op import reshape
|
| 9 |
+
from tensorflow.keras import Model
|
| 10 |
+
# from tensorflow.keras import backend as K
|
| 11 |
+
from tensorflow.keras.layers import Lambda
|
| 12 |
+
from tensorflow.keras.losses import SparseCategoricalCrossentropy, mse
|
| 13 |
+
from tensorflow.keras.metrics import SparseCategoricalAccuracy, Accuracy, BinaryAccuracy
|
| 14 |
+
import models_utils.ops as K
|
| 15 |
+
|
| 16 |
+
import raven_utils.decode
|
| 17 |
+
import raven_utils as rv
|
| 18 |
+
from raven_utils.config.constant import LABELS, INDEX, ACC_SAME, ACC_CHOOSE_LOWER, ACC_CHOOSE_UPPER, CLASSIFICATION, \
|
| 19 |
+
SLOT, \
|
| 20 |
+
PROPERTIES, ACC, GROUP, NUMBER, MASK
|
| 21 |
+
from raven_utils.models.uitls_ import RangeMask
|
| 22 |
+
from raven_utils.const import VERTICAL, HORIZONTAL
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_properties_mask(target):
|
| 26 |
+
return target[:, rv.target.END_INDEX:rv.target.UNIFORMITY_INDEX] > 0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_change_mask(target):
|
| 30 |
+
properties_mask = get_properties_mask(target)
|
| 31 |
+
return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_uniform_mask(target):
|
| 35 |
+
u_mask = lambda i: tf.tile(target[:, rv.target.UNIFORMITY_INDEX + i, None] == 3, [1, rv.rules.ATTRIBUTES_LEN])
|
| 36 |
+
properties_mask = tf.concat([u_mask(0), u_mask(1)], axis=-1) | get_properties_mask(target)
|
| 37 |
+
return [create_mask(properties_mask, i) for i, _ in enumerate(rv.rules.ATTRIBUTES)]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def create_all_mask(target):
|
| 41 |
+
return [
|
| 42 |
+
tf.cast(tf.ones(tf.stack([tf.shape(target)[0], rv.entity.SUM])), dtype=tf.bool) for i, _ in
|
| 43 |
+
enumerate(rv.rules.ATTRIBUTES)]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class BaselineClassificationLossModel(Model):
|
| 47 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.predict_fn = SubClassingModel([lambda x: x[0], PredictModel()])
|
| 50 |
+
self.loss_fn = ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 51 |
+
group_loss=group_loss)
|
| 52 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 53 |
+
|
| 54 |
+
def call(self, inputs):
|
| 55 |
+
losses = []
|
| 56 |
+
output = inputs[1]
|
| 57 |
+
losses.append(self.loss_fn([inputs[0][0], output]))
|
| 58 |
+
losses.append(self.metric_fn([inputs[0][2], inputs[3][0], inputs[0][1][:, 8:]]))
|
| 59 |
+
return losses
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class RavenLoss(Model):
|
| 63 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.3),
|
| 64 |
+
classification=False, trans=True, anneal=False):
|
| 65 |
+
super().__init__()
|
| 66 |
+
if anneal:
|
| 67 |
+
self.weight_scheduler
|
| 68 |
+
self.classification = classification
|
| 69 |
+
self.trans = trans
|
| 70 |
+
self.predict_fn = DictModel(SubClassingModel([lambda x: x[-1], PredictModel()]), in_=OUTPUT,
|
| 71 |
+
out=[PREDICT, MASK], name="pred")
|
| 72 |
+
if self.trans:
|
| 73 |
+
self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 74 |
+
group_loss=group_loss, enable_metrics=False, lw=lw[0]),
|
| 75 |
+
name="main_loss")
|
| 76 |
+
self.loss_fn_2 = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 77 |
+
group_loss=group_loss), name="add_loss")
|
| 78 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 79 |
+
if self.classification:
|
| 80 |
+
self.loss_fn_3 = add_loss(
|
| 81 |
+
ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
|
| 82 |
+
group_loss=group_loss, enable_metrics="c" if self.trans else True), lw=lw[1],
|
| 83 |
+
name="class_loss")
|
| 84 |
+
|
| 85 |
+
def call(self, inputs):
|
| 86 |
+
losses = []
|
| 87 |
+
output = inputs[OUTPUT]
|
| 88 |
+
target = inputs[TARGET]
|
| 89 |
+
labels = inputs[LABELS]
|
| 90 |
+
|
| 91 |
+
if self.trans:
|
| 92 |
+
losses.append(self.loss_fn([labels[:, 2], output[0]]))
|
| 93 |
+
losses.append(self.loss_fn([labels[:, 5], output[1]]))
|
| 94 |
+
losses.append(self.loss_fn_2([target, output[2]]))
|
| 95 |
+
losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
|
| 96 |
+
if self.classification:
|
| 97 |
+
for i in range(8):
|
| 98 |
+
losses.append(self.loss_fn_3([labels[:, i], inputs[CLASSIFICATION][i]]))
|
| 99 |
+
return {**inputs, LOSS: losses}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class VTRavenLoss(Model):
|
| 103 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(2.0, 1.0),
|
| 104 |
+
classification=False, trans=True, anneal=False, plw=None):
|
| 105 |
+
super().__init__()
|
| 106 |
+
if anneal:
|
| 107 |
+
self.weight_scheduler
|
| 108 |
+
self.classification = classification
|
| 109 |
+
self.trans = trans
|
| 110 |
+
self.predict_fn = DictModel(SubClassingModel([lambda x: x[:, -1], PredictModel()]), in_=OUTPUT,
|
| 111 |
+
out=[PREDICT, "predict_mask"], name="pred")
|
| 112 |
+
self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 113 |
+
group_loss=group_loss, plw=plw), lw=lw[0], name="add_loss")
|
| 114 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 115 |
+
if self.classification:
|
| 116 |
+
self.loss_fn_2 = add_loss(
|
| 117 |
+
ClassRavenModel(mode=create_all_mask, number_loss=number_loss, slot_loss=slot_loss,
|
| 118 |
+
group_loss=group_loss, enable_metrics="c", plw=plw), lw=lw[1], name="class_loss")
|
| 119 |
+
|
| 120 |
+
def call(self, inputs):
|
| 121 |
+
losses = []
|
| 122 |
+
output = inputs[OUTPUT]
|
| 123 |
+
target = inputs[TARGET]
|
| 124 |
+
labels = inputs[LABELS]
|
| 125 |
+
mask = inputs[MASK]
|
| 126 |
+
|
| 127 |
+
target_masked = target[mask]
|
| 128 |
+
output_masked = output[mask]
|
| 129 |
+
losses.append(self.loss_fn([target_masked, output_masked]))
|
| 130 |
+
|
| 131 |
+
target_unmasked = target[~mask]
|
| 132 |
+
output_unmasked = output[~mask]
|
| 133 |
+
losses.append(self.loss_fn_2([target_unmasked, output_unmasked]))
|
| 134 |
+
|
| 135 |
+
losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
|
| 136 |
+
return {**inputs, LOSS: losses}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class SingleVTRavenLoss(Model):
|
| 140 |
+
def __init__(self, mode=create_all_mask, number_loss=False, slot_loss=True, group_loss=True, lw=(1.0, 0.1),
|
| 141 |
+
classification=False, trans=True, anneal=False):
|
| 142 |
+
super().__init__()
|
| 143 |
+
if anneal:
|
| 144 |
+
self.weight_scheduler
|
| 145 |
+
self.classification = classification
|
| 146 |
+
self.trans = trans
|
| 147 |
+
self.predict_fn = DictModel(PredictModel(), in_=OUTPUT, out=[PREDICT, MASK], name="pred")
|
| 148 |
+
self.loss_fn = add_loss(ClassRavenModel(mode=mode, number_loss=number_loss, slot_loss=slot_loss,
|
| 149 |
+
group_loss=group_loss), lw=lw[0], name="add_loss")
|
| 150 |
+
self.metric_fn = SimilarityRaven(mode=mode)
|
| 151 |
+
|
| 152 |
+
def call(self, inputs):
|
| 153 |
+
losses = []
|
| 154 |
+
output = inputs[OUTPUT]
|
| 155 |
+
target = inputs[TARGET]
|
| 156 |
+
labels = inputs[LABELS]
|
| 157 |
+
|
| 158 |
+
losses.append(self.loss_fn([target, output]))
|
| 159 |
+
losses.append(self.metric_fn([inputs[INDEX], inputs[PREDICT], labels]))
|
| 160 |
+
return {**inputs, LOSS: losses}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class ClassRavenModel(Model):
|
| 164 |
+
def __init__(self, mode=create_all_mask, plw=None, number_loss=False, slot_loss=True, group_loss=True,
|
| 165 |
+
enable_metrics=True,
|
| 166 |
+
lw=1.0):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.number_loss = number_loss
|
| 169 |
+
self.group_loss = group_loss
|
| 170 |
+
self.enable_metrics = enable_metrics
|
| 171 |
+
self.slot_loss = slot_loss
|
| 172 |
+
self.predict_fn = PredictModel()
|
| 173 |
+
self.loss_fn = SparseCategoricalCrossentropy(from_logits=True)
|
| 174 |
+
if self.slot_loss:
|
| 175 |
+
self.loss_fn_2 = tf.nn.sigmoid_cross_entropy_with_logits
|
| 176 |
+
if self.enable_metrics:
|
| 177 |
+
self.enable_metrics = f"{self.enable_metrics}_" if isinstance(self.enable_metrics, str) else ""
|
| 178 |
+
self.metric_fn = [
|
| 179 |
+
SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{property_}") for property_ in
|
| 180 |
+
rv.properties.NAMES]
|
| 181 |
+
if self.group_loss:
|
| 182 |
+
self.metric_fn_group = SparseCategoricalAccuracy(name=f"{self.enable_metrics}{ACC}_{GROUP}")
|
| 183 |
+
if self.slot_loss:
|
| 184 |
+
self.metric_fn_2 = BinaryAccuracy(name=f"{self.enable_metrics}{ACC}_{SLOT}")
|
| 185 |
+
self.range_mask = RangeMask()
|
| 186 |
+
self.mode = mode
|
| 187 |
+
self.lw = lw
|
| 188 |
+
if not plw:
|
| 189 |
+
plw = [1., 95.37352927, 2.83426987, 0.85212836, 1.096005, 1.21943385]
|
| 190 |
+
elif isinstance(plw, int) or isinstance(plw, float):
|
| 191 |
+
plw = [1., plw, 2.83426987, 0.85212836, 1.096005, 1.21943385]
|
| 192 |
+
# plw = [plw] * 6
|
| 193 |
+
self.plw = plw
|
| 194 |
+
|
| 195 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 196 |
+
|
| 197 |
+
def call(self, inputs):
|
| 198 |
+
losses = []
|
| 199 |
+
metrics = {}
|
| 200 |
+
target = inputs[0]
|
| 201 |
+
output = inputs[1]
|
| 202 |
+
|
| 203 |
+
target_group, target_slot, target_all = raven_utils.decode.decode_target(target)
|
| 204 |
+
|
| 205 |
+
group_output, output_slot, outputs = raven_utils.decode.output_divide(output, split_fn=tf.split)
|
| 206 |
+
|
| 207 |
+
# group
|
| 208 |
+
if self.group_loss:
|
| 209 |
+
group_loss = self.lw * self.plw[0] * self.loss_fn(target_group, group_output)
|
| 210 |
+
losses.append(group_loss)
|
| 211 |
+
|
| 212 |
+
if isinstance(self.enable_metrics, str):
|
| 213 |
+
group_metric = self.metric_fn_group(target_group, group_output)
|
| 214 |
+
# metrics[GROUP] = group_metric
|
| 215 |
+
self.add_metric(group_metric)
|
| 216 |
+
self.add_metric(tf.reduce_sum(group_metric), f"{self.enable_metrics}{ACC}")
|
| 217 |
+
|
| 218 |
+
# setting uniformity mask
|
| 219 |
+
full_properties_musks = self.mode(target)
|
| 220 |
+
|
| 221 |
+
range_mask = self.range_mask(target_group)
|
| 222 |
+
|
| 223 |
+
if self.slot_loss:
|
| 224 |
+
# number
|
| 225 |
+
number_mask = range_mask & full_properties_musks[0]
|
| 226 |
+
number_mask = tf.cast(number_mask, tf.float32)
|
| 227 |
+
target_number = tf.reduce_sum(
|
| 228 |
+
tf.cast(target_slot, "float32") * number_mask, axis=-1)
|
| 229 |
+
output_number = tf.reduce_sum(
|
| 230 |
+
tf.cast(tf.sigmoid(output_slot) >= 0.5, "float32") * number_mask, axis=-1)
|
| 231 |
+
|
| 232 |
+
# output_number = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
|
| 233 |
+
if self.number_loss:
|
| 234 |
+
scale = 1 / 9
|
| 235 |
+
if self.number_loss == 2:
|
| 236 |
+
output_number_2 = tf.reduce_sum(tf.sigmoid(output_slot) * number_mask, axis=-1)
|
| 237 |
+
else:
|
| 238 |
+
output_number_2 = output_number
|
| 239 |
+
number_loss = self.lw * self.plw[1] * mse(tf.stop_gradient(target_number) * scale,
|
| 240 |
+
output_number_2 * scale)
|
| 241 |
+
losses.append(number_loss)
|
| 242 |
+
|
| 243 |
+
# metrics[NUMBER] = number_acc
|
| 244 |
+
|
| 245 |
+
if isinstance(self.enable_metrics, str):
|
| 246 |
+
number_acc = tf.reduce_mean(
|
| 247 |
+
tf.cast(tf.cast(target_number, "int8") == tf.cast(output_number, "int8"), "float32"))
|
| 248 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_{NUMBER}")
|
| 249 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}")
|
| 250 |
+
self.add_metric(tf.reduce_sum(number_acc), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 251 |
+
|
| 252 |
+
# position/slot
|
| 253 |
+
slot_mask = range_mask & full_properties_musks[1]
|
| 254 |
+
# tf.boolean_mask(target_slot,slot_mask)
|
| 255 |
+
|
| 256 |
+
if tf.reduce_any(slot_mask):
|
| 257 |
+
# if tf.reduce_mean(tf.cast(slot_mask, dtype=tf.int32)) > 0:
|
| 258 |
+
target_slot_masked = tf.boolean_mask(target_slot, slot_mask)[:, None]
|
| 259 |
+
output_slot_masked = tf.boolean_mask(output_slot, slot_mask)[:, None]
|
| 260 |
+
loss_slot = self.lw * self.plw[2] * tf.reduce_mean(
|
| 261 |
+
self.loss_fn_2(tf.cast(target_slot_masked, "float32"), output_slot_masked))
|
| 262 |
+
if isinstance(self.enable_metrics, str):
|
| 263 |
+
acc_slot = self.metric_fn_2(target_slot_masked, output_slot_masked)
|
| 264 |
+
self.add_metric(acc_slot)
|
| 265 |
+
self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}")
|
| 266 |
+
self.add_metric(tf.reduce_sum(acc_slot), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 267 |
+
else:
|
| 268 |
+
loss_slot = 0.0
|
| 269 |
+
acc_slot = -1.0
|
| 270 |
+
|
| 271 |
+
losses.append(loss_slot)
|
| 272 |
+
# metrics[SLOT] = acc_slot
|
| 273 |
+
# if loss_slot != 0:
|
| 274 |
+
|
| 275 |
+
# if tf.reduce_any(slot_mask):
|
| 276 |
+
|
| 277 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_{NUMBER}")
|
| 278 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}")
|
| 279 |
+
# self.add_metric(acc_slot, f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 280 |
+
|
| 281 |
+
# properties
|
| 282 |
+
for i, out in enumerate(outputs):
|
| 283 |
+
shape = (-1, rv.entity.SUM, rv.properties.RAW_SIZE[i])
|
| 284 |
+
out_reshaped = tf.reshape(out, shape)
|
| 285 |
+
properties_mask = tf.cast(target_slot, "bool") & full_properties_musks[i + 2]
|
| 286 |
+
|
| 287 |
+
if tf.reduce_any(properties_mask):
|
| 288 |
+
out_masked = tf.boolean_mask(out_reshaped, properties_mask)
|
| 289 |
+
out_target = tf.boolean_mask(target_all[i], properties_mask)
|
| 290 |
+
loss = self.lw * self.plw[3 + i] * self.loss_fn(out_target, out_masked)
|
| 291 |
+
if isinstance(self.enable_metrics, str):
|
| 292 |
+
metric = self.metric_fn[i](out_target, out_masked)
|
| 293 |
+
self.add_metric(metric)
|
| 294 |
+
# self.add_metric(metric, f"{self.enable_metrics}{ACC}")
|
| 295 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}")
|
| 296 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_{PROPERTIES}")
|
| 297 |
+
self.add_metric(tf.reduce_sum(metric), f"{self.enable_metrics}{ACC}_NO_{GROUP}")
|
| 298 |
+
else:
|
| 299 |
+
loss = 0.0
|
| 300 |
+
metric = -1.0
|
| 301 |
+
|
| 302 |
+
losses.append(loss)
|
| 303 |
+
return losses
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class FullMask(Model):
|
| 307 |
+
def __init__(self, mode=create_uniform_mask):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.range_mask = RangeMask()
|
| 310 |
+
self.mode = mode
|
| 311 |
+
|
| 312 |
+
def call(self, inputs):
|
| 313 |
+
target_group, target_slot, _ = raven_utils.decode.decode_target(inputs)
|
| 314 |
+
full_properties_musks = self.mode(inputs)
|
| 315 |
+
range_mask = self.range_mask(target_group)
|
| 316 |
+
|
| 317 |
+
number_mask = range_mask & full_properties_musks[0]
|
| 318 |
+
|
| 319 |
+
slot_mask = range_mask & full_properties_musks[1]
|
| 320 |
+
properties_mask = []
|
| 321 |
+
for property_mask in full_properties_musks[2:]:
|
| 322 |
+
properties_mask.append(tf.cast(target_slot, "bool") & property_mask)
|
| 323 |
+
return [slot_mask, properties_mask, number_mask]
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def create_mask(rules, i):
|
| 327 |
+
mask_1 = tf.tile(rules[:, i][None], [len(rv.target.FIRST_LAYOUT), 1])
|
| 328 |
+
mask_2 = tf.tile(rules[:, i + 5][None], [len(rv.target.SECOND_LAYOUT), 1])
|
| 329 |
+
shape = tf.shape(rules)
|
| 330 |
+
full_mask_1 = tf.scatter_nd(tnp.array(rv.target.FIRST_LAYOUT)[:, None], mask_1, shape=(rv.entity.SUM, shape[0]))
|
| 331 |
+
full_mask_2 = tf.tensor_scatter_nd_update(full_mask_1, tnp.array(rv.target.SECOND_LAYOUT)[:, None], mask_2)
|
| 332 |
+
return tf.transpose(full_mask_2)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
# class PredictModel(Model):
|
| 336 |
+
# def __init__(self):
|
| 337 |
+
# super().__init__()
|
| 338 |
+
# self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
| 339 |
+
# self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
| 340 |
+
# self.range_mask = RangeMask()
|
| 341 |
+
#
|
| 342 |
+
# # self.predict_fn = partial(tf.argmax, axis=-1)
|
| 343 |
+
#
|
| 344 |
+
# def call(self, inputs):
|
| 345 |
+
# group_output = inputs[rv.OUTPUT_GROUP_SLICE]
|
| 346 |
+
# group_loss = self.predict_fn(group_output)[:, None]
|
| 347 |
+
#
|
| 348 |
+
# output_slot = inputs[rv.OUTPUT_SLOT_SLICE]
|
| 349 |
+
# range_mask = self.range_mask(group_loss[:, 0])
|
| 350 |
+
# loss_slot = tf.cast(self.predict_fn_2(output_slot), dtype=tf.int64)
|
| 351 |
+
#
|
| 352 |
+
# properties_output = inputs[rv.OUTPUT_PROPERTIES_SLICE]
|
| 353 |
+
# properties = []
|
| 354 |
+
# outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
| 355 |
+
# for i, out in enumerate(outputs):
|
| 356 |
+
# shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
| 357 |
+
# out_reshaped = tf.reshape(out, shape)
|
| 358 |
+
# properties.append(self.predict_fn(out_reshaped))
|
| 359 |
+
# number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
| 360 |
+
#
|
| 361 |
+
# result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
| 362 |
+
#
|
| 363 |
+
# return [result, range_mask, range_mask, range_mask, range_mask]
|
| 364 |
+
|
| 365 |
+
class PredictModel(Model):
|
| 366 |
+
def __init__(self):
|
| 367 |
+
super().__init__()
|
| 368 |
+
self.predict_fn = Predict()
|
| 369 |
+
self.predict_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
| 370 |
+
self.range_mask = RangeMask()
|
| 371 |
+
|
| 372 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 373 |
+
|
| 374 |
+
def call(self, inputs):
|
| 375 |
+
group_output, output_slot, *properties = rv.decode.output(inputs, tf.split, self.predict_fn, self.predict_fn_2)
|
| 376 |
+
number_loss = K.int64(K.sum(output_slot))
|
| 377 |
+
result = tf.concat(
|
| 378 |
+
[group_output[:, None], tf.cast(output_slot, dtype=tf.int64), interleave(properties), number_loss[:, None]],
|
| 379 |
+
axis=-1)
|
| 380 |
+
|
| 381 |
+
range_mask = self.range_mask(group_output)
|
| 382 |
+
return [result, range_mask]
|
| 383 |
+
# return [result, range_mask, range_mask, range_mask, range_mask]
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# todo change slices
|
| 387 |
+
class PredictModelMasked(Model):
|
| 388 |
+
def __init__(self):
|
| 389 |
+
super().__init__()
|
| 390 |
+
self.predict_fn = Lambda(partial(tf.argmax, axis=-1))
|
| 391 |
+
self.loss_fn_2 = Lambda(lambda x: tf.sigmoid(x) > 0.5)
|
| 392 |
+
self.range_mask = RangeMask()
|
| 393 |
+
|
| 394 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 395 |
+
|
| 396 |
+
def call(self, inputs):
|
| 397 |
+
group_output = inputs[:, -rv.GROUPS_NO:]
|
| 398 |
+
group_loss = self.predict_fn(group_output)[:, None]
|
| 399 |
+
|
| 400 |
+
output_slot = inputs[:, :rv.ENTITY_SUM]
|
| 401 |
+
range_mask = self.range_mask(group_loss[:, 0])
|
| 402 |
+
loss_slot = tf.cast(self.predict_fn_2(output_slot * range_mask), dtype=tf.int64)
|
| 403 |
+
|
| 404 |
+
properties_output = inputs[:, rv.ENTITY_SUM:-rv.GROUPS_NO]
|
| 405 |
+
|
| 406 |
+
properties = []
|
| 407 |
+
outputs = tf.split(properties_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-1)
|
| 408 |
+
for i, out in enumerate(outputs):
|
| 409 |
+
shape = (-1, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
| 410 |
+
out_reshaped = tf.reshape(out, shape)
|
| 411 |
+
out_masked = out_reshaped * loss_slot[..., None]
|
| 412 |
+
properties.append(self.predict_fn(out_masked))
|
| 413 |
+
# out_masked[0].numpy()
|
| 414 |
+
number_loss = tf.reduce_sum(loss_slot, axis=-1, keepdims=True)
|
| 415 |
+
|
| 416 |
+
result = tf.concat([group_loss, loss_slot, interleave(properties), number_loss], axis=-1)
|
| 417 |
+
|
| 418 |
+
return result
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def final_predict_mask(x, mask):
|
| 422 |
+
r = reshape(x[0][:, rv.INDEX[0]:-1], [-1, 3])
|
| 423 |
+
return tf.ragged.boolean_mask(r, mask)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def final_predict(x, mode=False):
|
| 427 |
+
m = x[1] if mode else tf.cast(x[0][:, 1:rv.INDEX[0]], tf.bool)
|
| 428 |
+
return final_predict_mask(x[0], m)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def final_predict_2(x):
|
| 432 |
+
ones = tf.cast(tf.ones(tf.shape(x[0])[0]), tf.bool)[:, None]
|
| 433 |
+
mask = tf.concat([ones, tf.tile(x[1], [1, 4]), ones], axis=-1)
|
| 434 |
+
return tf.ragged.boolean_mask(x[0], mask)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class PredictModelOld(Model):
|
| 438 |
+
|
| 439 |
+
def call(self, inputs):
|
| 440 |
+
output = inputs[-2]
|
| 441 |
+
|
| 442 |
+
rest_output = output[:, :-rv.GROUPS_NO]
|
| 443 |
+
|
| 444 |
+
result_all = []
|
| 445 |
+
outputs = tf.split(rest_output, list(rv.ENTITY_PROPERTIES_INDEX.values()), axis=-3)
|
| 446 |
+
for i, out in enumerate(outputs):
|
| 447 |
+
shape = (-3, rv.ENTITY_SUM, rv.ENTITY_PROPERTIES_VALUES[i])
|
| 448 |
+
out_reshaped = tf.reshape(out, shape)
|
| 449 |
+
|
| 450 |
+
result = tf.cast(tf.argmax(out_reshaped, axis=-3), dtype="int8")
|
| 451 |
+
result_all.append(result)
|
| 452 |
+
|
| 453 |
+
result_all = interleave(result_all)
|
| 454 |
+
return result_all
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def get_matches(diff, target_index):
|
| 458 |
+
diff_sum = K.sum(diff)
|
| 459 |
+
db_argsort = tf.argsort(diff_sum, axis=-1)
|
| 460 |
+
db_sorted = tf.sort(diff_sum)
|
| 461 |
+
db_mask = db_sorted[:, 0, None] == db_sorted
|
| 462 |
+
db_same = tf.where(db_mask, db_argsort, -1 * tf.ones_like(db_argsort))
|
| 463 |
+
matched_index = db_same == target_index
|
| 464 |
+
# setting shape needed for TensorFlow graph
|
| 465 |
+
matched_index.set_shape(db_same.shape)
|
| 466 |
+
matches = K.any(matched_index)
|
| 467 |
+
more_matches = K.sum(db_mask) > 1
|
| 468 |
+
once_matches = K.sum(matches & tf.math.logical_not(more_matches))
|
| 469 |
+
return matches, more_matches, once_matches
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class SimilarityRaven(Model):
|
| 473 |
+
def __init__(self, mode=create_all_mask, number_loss=False):
|
| 474 |
+
super().__init__()
|
| 475 |
+
self.range_mask = RangeMask()
|
| 476 |
+
self.mode = mode
|
| 477 |
+
|
| 478 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 479 |
+
|
| 480 |
+
# INDEX, PREDICT, LABELS
|
| 481 |
+
def call(self, inputs):
|
| 482 |
+
metrics = []
|
| 483 |
+
target_index = inputs[0] - 8
|
| 484 |
+
predict = inputs[1]
|
| 485 |
+
answers = inputs[2][:, 8:]
|
| 486 |
+
shape = tf.shape(predict)
|
| 487 |
+
|
| 488 |
+
target = K.gather(answers, target_index[:, 0])
|
| 489 |
+
|
| 490 |
+
target_group = target[:, 0]
|
| 491 |
+
|
| 492 |
+
# comp_slice = np.
|
| 493 |
+
target_comp = target[:, 1:rv.target.END_INDEX]
|
| 494 |
+
predict_comp = predict[:, 1:rv.target.END_INDEX]
|
| 495 |
+
answers_comp = answers[:, :, 1:rv.target.END_INDEX]
|
| 496 |
+
|
| 497 |
+
full_properties_musks = self.mode(target)
|
| 498 |
+
fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
|
| 499 |
+
|
| 500 |
+
range_mask = self.range_mask(target_group)
|
| 501 |
+
full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
|
| 502 |
+
|
| 503 |
+
final_mask = fpm & full_range_mask
|
| 504 |
+
|
| 505 |
+
target_masked = target_comp * final_mask
|
| 506 |
+
predict_masked = predict_comp * final_mask
|
| 507 |
+
answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
|
| 508 |
+
|
| 509 |
+
acc_same = K.mean(K.all(target_masked == predict_masked))
|
| 510 |
+
self.add_metric(acc_same, ACC_SAME)
|
| 511 |
+
metrics.append(acc_same)
|
| 512 |
+
|
| 513 |
+
diff = tf.abs(predict_masked[:, None] - answers_masked)
|
| 514 |
+
diff_bool = diff != 0
|
| 515 |
+
|
| 516 |
+
matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
|
| 517 |
+
|
| 518 |
+
second_phase_mask = (more_matches & matches)
|
| 519 |
+
diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
|
| 520 |
+
target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
|
| 521 |
+
|
| 522 |
+
matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
|
| 523 |
+
matches_2_no = K.sum(matches_2)
|
| 524 |
+
|
| 525 |
+
acc_choose_upper = (once_matches + matches_2_no) / shape[0]
|
| 526 |
+
self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
|
| 527 |
+
metrics.append(acc_choose_upper)
|
| 528 |
+
|
| 529 |
+
acc_choose_lower = (once_matches + once_matches_2) / shape[0]
|
| 530 |
+
self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
|
| 531 |
+
metrics.append(acc_choose_lower)
|
| 532 |
+
|
| 533 |
+
return metrics
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class SimilarityRaven2(Model):
|
| 537 |
+
def __init__(self, mode=create_all_mask, number_loss=False):
|
| 538 |
+
super().__init__()
|
| 539 |
+
self.range_mask = RangeMask()
|
| 540 |
+
self.mode = mode
|
| 541 |
+
|
| 542 |
+
# self.predict_fn = partial(tf.argmax, axis=-1)
|
| 543 |
+
|
| 544 |
+
# INDEX, PREDICT, LABELS
|
| 545 |
+
def call(self, inputs):
|
| 546 |
+
metrics = []
|
| 547 |
+
target_index = inputs[0] - 8
|
| 548 |
+
predict = inputs[1]
|
| 549 |
+
answers = inputs[2][:, 8:]
|
| 550 |
+
shape = tf.shape(predict)
|
| 551 |
+
|
| 552 |
+
target = K.gather(answers, target_index[:, 0])
|
| 553 |
+
|
| 554 |
+
target_group = target[:, 0]
|
| 555 |
+
|
| 556 |
+
# comp_slice = np.
|
| 557 |
+
target_comp = target[:, 1:rv.target.END_INDEX]
|
| 558 |
+
predict_comp = predict[:, 1:rv.target.END_INDEX]
|
| 559 |
+
answers_comp = answers[:, :, 1:rv.target.END_INDEX]
|
| 560 |
+
|
| 561 |
+
full_properties_musks = self.mode(target)
|
| 562 |
+
fpm = K.cat([full_properties_musks[0], interleave(full_properties_musks[2:])])
|
| 563 |
+
|
| 564 |
+
range_mask = self.range_mask(target_group)
|
| 565 |
+
full_range_mask = K.cat([range_mask, tf.repeat(range_mask, 3, axis=-1)], axis=-1)
|
| 566 |
+
|
| 567 |
+
final_mask = fpm & full_range_mask
|
| 568 |
+
|
| 569 |
+
target_masked = target_comp * final_mask
|
| 570 |
+
predict_masked = predict_comp * final_mask
|
| 571 |
+
answers_masked = answers_comp * tf.tile(final_mask[:, None], [1, 8, 1])
|
| 572 |
+
|
| 573 |
+
acc_same = K.mean(K.all(target_masked == predict_masked))
|
| 574 |
+
self.add_metric(acc_same, ACC_SAME)
|
| 575 |
+
metrics.append(acc_same)
|
| 576 |
+
|
| 577 |
+
diff = tf.abs(predict_masked[:, None] - answers_masked)
|
| 578 |
+
diff_bool = diff != 0
|
| 579 |
+
|
| 580 |
+
matches, more_matches, once_matches = get_matches(tf.cast(diff_bool, dtype=tf.int32), target_index)
|
| 581 |
+
|
| 582 |
+
second_phase_mask = (more_matches & matches)
|
| 583 |
+
diff_second_phase = tf.boolean_mask(diff, second_phase_mask)
|
| 584 |
+
target_index_2 = tf.boolean_mask(target_index, second_phase_mask, axis=0)
|
| 585 |
+
|
| 586 |
+
matches_2, more_matches_2, once_matches_2 = get_matches(diff_second_phase, target_index_2)
|
| 587 |
+
matches_2_no = K.sum(matches_2)
|
| 588 |
+
|
| 589 |
+
acc_choose_upper = (once_matches + matches_2_no) / shape[0]
|
| 590 |
+
self.add_metric(acc_choose_upper, ACC_CHOOSE_UPPER)
|
| 591 |
+
metrics.append(acc_choose_upper)
|
| 592 |
+
|
| 593 |
+
acc_choose_lower = (once_matches + once_matches_2) / shape[0]
|
| 594 |
+
self.add_metric(acc_choose_lower, ACC_CHOOSE_LOWER)
|
| 595 |
+
metrics.append(acc_choose_lower)
|
| 596 |
+
|
| 597 |
+
metrics.append(K.sum(target_masked != predict_masked))
|
| 598 |
+
|
| 599 |
+
return metrics
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
class LatentLossModel(Model):
|
| 603 |
+
def __init__(self, dir_=HORIZONTAL):
|
| 604 |
+
super().__init__()
|
| 605 |
+
# self.sum_metrics = []
|
| 606 |
+
# for i in range(8):
|
| 607 |
+
# self.sum_metrics.append(Sum(name=f"no_{i}"))
|
| 608 |
+
self.metric_fn = Accuracy(name="acc_latent")
|
| 609 |
+
if dir_ == VERTICAL:
|
| 610 |
+
self.dir = (6, 7)
|
| 611 |
+
else:
|
| 612 |
+
self.dir = (2, 5)
|
| 613 |
+
|
| 614 |
+
def call(self, inputs):
|
| 615 |
+
target_image = tf.reshape(inputs[0][2], [-1])
|
| 616 |
+
output = inputs[1]
|
| 617 |
+
latents = tnp.asarray(inputs[2])
|
| 618 |
+
|
| 619 |
+
target_hor = tf.concat([
|
| 620 |
+
latents[:, self.dir],
|
| 621 |
+
latents[tf.range(latents.shape[0]), target_image + 8][:, None]
|
| 622 |
+
],
|
| 623 |
+
axis=1)
|
| 624 |
+
|
| 625 |
+
loss_hor = mse(K.stop_gradient(target_hor), output)
|
| 626 |
+
self.add_loss(loss_hor)
|
| 627 |
+
|
| 628 |
+
self.add_metric(self.metric_fn(inputs[3], target_image))
|
| 629 |
+
|
| 630 |
+
return loss_hor
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class PredRav(Model):
|
| 634 |
+
|
| 635 |
+
def call(self, inputs):
|
| 636 |
+
output = inputs[0][:, -1]
|
| 637 |
+
answers = inputs[1][:, 8:]
|
| 638 |
+
return tf.argmin(tf.reduce_sum(tf.abs(output[:, None] - answers), axis=-1), axis=-1)
|
raven_utils/models/multi_transformer.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from functools import partial
|
| 3 |
+
from tensorflow.keras.layers import Lambda
|
| 4 |
+
from tensorflow.keras.layers import Dense
|
| 5 |
+
from tensorflow.keras import Input, Model
|
| 6 |
+
from tensorflow.python.keras import Sequential
|
| 7 |
+
|
| 8 |
+
from config.constant import TRANS
|
| 9 |
+
from ml_utils import filter_init
|
| 10 |
+
from models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
| 11 |
+
from models_utils import pmodel, DictModel, bt, INPUTS, bm, OUTPUT, LATENTS, transformer, BatchModel, get_extractor, \
|
| 12 |
+
build_seq_model, BUILD, build_train_list, InitialWeight
|
| 13 |
+
from models_utils import SumPositionEmbedding, TransformerBlock, CatPositionEmbedding, transformer, BatchInitialWeight
|
| 14 |
+
import models_utils.ops as K
|
| 15 |
+
from models_utils.image import inverse_fn
|
| 16 |
+
from models_utils.ops_core import IndexReshape
|
| 17 |
+
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
| 18 |
+
from models_utils.step import StepDict
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def init_weights(shape, dtype=None):
|
| 22 |
+
return tf.cast(K.var.image(shape=shape, pre=True), dtype=tf.float32)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def conversion(x, max_=45):
|
| 26 |
+
shape = tf.shape(x)
|
| 27 |
+
return tf.reshape(x[:, :max_], tf.stack([shape[0], 9, -1]))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def take_left(x):
|
| 31 |
+
return x[..., 7:8]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def take_by_index(x, i=8):
|
| 35 |
+
return x[..., i:i + 1]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def mix(x):
|
| 39 |
+
return (x[..., 7:8] + x[..., 5:6]) / 2
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def empty_last(x):
|
| 43 |
+
return tf.zeros_like(x[..., 7:8])
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Conversion(Model):
|
| 47 |
+
def __init__(self):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.model = IndexReshape((0, "9", None))
|
| 50 |
+
|
| 51 |
+
def call(self, inputs):
|
| 52 |
+
return self.model(inputs[:, :45])
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class RandomImageMask(Model):
|
| 56 |
+
def __init__(self, last, last_index=9):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.get_last = last
|
| 59 |
+
self.last_index = last_index
|
| 60 |
+
|
| 61 |
+
def call(self, inputs):
|
| 62 |
+
shape = tf.shape(inputs)
|
| 63 |
+
indexes = tf.random.uniform(shape=shape[0:1], maxval=self.last_index, dtype=tf.int32)
|
| 64 |
+
mask = tf.one_hot(indexes, self.last_index)[:, None, None]
|
| 65 |
+
|
| 66 |
+
return (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
| 67 |
+
(1, 1, 1, self.last_index))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
| 71 |
+
# (1, 1, 1, self.last_index))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# from data_utils import ims
|
| 75 |
+
# for i in range(50):
|
| 76 |
+
# ims(res[i].numpy().swapaxes(0, 2))
|
| 77 |
+
# res[12].numpy()
|
| 78 |
+
# self.get_last(inputs).numpy()
|
| 79 |
+
# import tensorflow as tf
|
| 80 |
+
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
| 81 |
+
# from ml_utils import print_error
|
| 82 |
+
# ims(mask[0].numpy())
|
| 83 |
+
# print_error(lambda :ims(mask[0]))
|
| 84 |
+
# from models_utils import ops as K
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ImageMask(Model):
|
| 88 |
+
def __init__(self, last, index=8, last_index=9):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.get_last = last
|
| 91 |
+
self.index = index
|
| 92 |
+
self.last_index = last_index
|
| 93 |
+
|
| 94 |
+
def call(self, inputs):
|
| 95 |
+
return tf.concat([inputs[..., :8], self.get_last(inputs)], axis=-1)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class CreateGrid(Model):
|
| 99 |
+
def __init__(self,
|
| 100 |
+
no=4,
|
| 101 |
+
extractor="ef",
|
| 102 |
+
type_=3,
|
| 103 |
+
base="seq",
|
| 104 |
+
last=take_left,
|
| 105 |
+
epsilon=None,
|
| 106 |
+
pooling=None,
|
| 107 |
+
mask_fn=None,
|
| 108 |
+
model=None,
|
| 109 |
+
**kwargs
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.type_ = type_
|
| 113 |
+
if type_ == 9:
|
| 114 |
+
self.start_shape = 75
|
| 115 |
+
data = (224, 224, 3)
|
| 116 |
+
conv = lambda: Conversion()
|
| 117 |
+
else:
|
| 118 |
+
self.start_shape = 84
|
| 119 |
+
data = (84, 84, 3)
|
| 120 |
+
extractor = BUILD[base]([
|
| 121 |
+
BatchModel(get_extractor(data=data, model=extractor)),
|
| 122 |
+
lambda x: tf.transpose(x, (1, 0, 2, 3, 4))
|
| 123 |
+
# lambda x: tf.tile(x[:, :224, :224], (1, 1, 1, 3))
|
| 124 |
+
])
|
| 125 |
+
conv = lambda: conversion
|
| 126 |
+
|
| 127 |
+
self.epsilon = epsilon
|
| 128 |
+
if mask_fn == "random":
|
| 129 |
+
mask_fn = RandomImageMask(last=last)
|
| 130 |
+
elif mask_fn is None:
|
| 131 |
+
mask_fn = ImageMask(last=last)
|
| 132 |
+
|
| 133 |
+
self.mask_fn = mask_fn
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def call(self, inputs):
|
| 137 |
+
transposed = tf.image.resize(tf.transpose(inputs, (0, 2, 3, 1)), (self.start_shape, self.start_shape))
|
| 138 |
+
re = self.mask_fn(transposed)
|
| 139 |
+
|
| 140 |
+
# re = tf.concat([transposed[..., :8], self.get_last(transposed)], axis=-1)
|
| 141 |
+
if self.type_ == 9:
|
| 142 |
+
x = tf.transpose(re, [0, 3, 1, 2])[..., None]
|
| 143 |
+
x = K.create_image_grid(x, 3, 3)
|
| 144 |
+
x = x[:, :224, :224]
|
| 145 |
+
x = tf.tile(x, [1, 1, 1, 3])
|
| 146 |
+
else:
|
| 147 |
+
|
| 148 |
+
x = tf.stack([
|
| 149 |
+
re[..., :3],
|
| 150 |
+
re[..., 3:6],
|
| 151 |
+
re[..., 6:9],
|
| 152 |
+
])
|
| 153 |
+
return self.model(x)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# self.model.layers[0](x)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def grid_transformer(
|
| 160 |
+
*args,
|
| 161 |
+
type_=9,
|
| 162 |
+
no=4,
|
| 163 |
+
extractor="ef",
|
| 164 |
+
loss_mode=create_uniform_mask,
|
| 165 |
+
output_size=10,
|
| 166 |
+
loss_weight=1.0,
|
| 167 |
+
out_layers=(1000, 1000, 1000),
|
| 168 |
+
pos_emd="cat",
|
| 169 |
+
base="seq",
|
| 170 |
+
inverse_image=True,
|
| 171 |
+
last="left",
|
| 172 |
+
mask_fn=None,
|
| 173 |
+
model=None,
|
| 174 |
+
trans=None,
|
| 175 |
+
**kwargs):
|
| 176 |
+
|
| 177 |
+
if last == "left":
|
| 178 |
+
last = take_left
|
| 179 |
+
elif last == "mix":
|
| 180 |
+
last = mix
|
| 181 |
+
elif last == "empty":
|
| 182 |
+
last = empty_last
|
| 183 |
+
elif last == "start":
|
| 184 |
+
last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
|
| 185 |
+
|
| 186 |
+
create_grid = CreateGrid(
|
| 187 |
+
type_=type_,
|
| 188 |
+
no=no,
|
| 189 |
+
extractor=extractor,
|
| 190 |
+
model=model,
|
| 191 |
+
output_size=output_size,
|
| 192 |
+
out_layer=out_layers,
|
| 193 |
+
pos_emd=pos_emd,
|
| 194 |
+
base=base,
|
| 195 |
+
last=last,
|
| 196 |
+
mask_fn=mask_fn,
|
| 197 |
+
**kwargs
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
if model is None:
|
| 201 |
+
trans = transformer(
|
| 202 |
+
extractor=extractor,
|
| 203 |
+
pos_emd=pos_emd,
|
| 204 |
+
data=data,
|
| 205 |
+
output_size=output_size,
|
| 206 |
+
out_layers=out_layer,
|
| 207 |
+
pooling=conv,
|
| 208 |
+
no=no,
|
| 209 |
+
base=base,
|
| 210 |
+
**kwargs
|
| 211 |
+
# **as_dict(p.trans)
|
| 212 |
+
)
|
| 213 |
+
else:
|
| 214 |
+
trans = trans
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def get_rav_trans(
|
| 219 |
+
*args,
|
| 220 |
+
type_=9,
|
| 221 |
+
no=4,
|
| 222 |
+
extractor="ef",
|
| 223 |
+
loss_mode=create_uniform_mask,
|
| 224 |
+
output_size=10,
|
| 225 |
+
loss_weight=1.0,
|
| 226 |
+
out_layers=(1000, 1000, 1000),
|
| 227 |
+
pos_emd="cat",
|
| 228 |
+
base="seq",
|
| 229 |
+
inverse_image=True,
|
| 230 |
+
last="left",
|
| 231 |
+
epsilon="greedy",
|
| 232 |
+
epsilon_step=500,
|
| 233 |
+
mask_fn=None,
|
| 234 |
+
model=None,
|
| 235 |
+
loss="multi",
|
| 236 |
+
**kwargs):
|
| 237 |
+
if last == "left":
|
| 238 |
+
last = take_left
|
| 239 |
+
elif last == "mix":
|
| 240 |
+
last = mix
|
| 241 |
+
elif last == "empty":
|
| 242 |
+
last = empty_last
|
| 243 |
+
elif last == "start":
|
| 244 |
+
last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
|
| 245 |
+
|
| 246 |
+
trans_raven = CreateGrid(
|
| 247 |
+
type_=type_,
|
| 248 |
+
no=no,
|
| 249 |
+
extractor=extractor,
|
| 250 |
+
model=model,
|
| 251 |
+
output_size=output_size,
|
| 252 |
+
out_layer=out_layers,
|
| 253 |
+
pos_emd=pos_emd,
|
| 254 |
+
base=base,
|
| 255 |
+
last=last,
|
| 256 |
+
epsilon=epsilon,
|
| 257 |
+
mask_fn=mask_fn,
|
| 258 |
+
**kwargs
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
if loss == "single":
|
| 262 |
+
loss = SingleVTRavenLoss
|
| 263 |
+
else:
|
| 264 |
+
loss = VTRavenLoss
|
| 265 |
+
|
| 266 |
+
return bt(
|
| 267 |
+
DictModel(
|
| 268 |
+
Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
| 269 |
+
in_=INPUTS,
|
| 270 |
+
name="Body"
|
| 271 |
+
),
|
| 272 |
+
loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
| 273 |
+
loss_wrap=False
|
| 274 |
+
)
|
raven_utils/models/raven.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ml_utils import lw, lu
|
| 2 |
+
from models_utils import bm, Base, res, bt, DictModel, dense_drop, drop, build_encoder, MODEL_ARCH, ListModel, short, \
|
| 3 |
+
dense, Flatten, Cat, CatDenseBefore, \
|
| 4 |
+
CatDense, CatBefore, Drop, Flat2, down, Pass, conv, Flat, Get, bs, Res, SoftBlock
|
| 5 |
+
from models_utils import SubClassingModel
|
| 6 |
+
from models_utils.config.constants import *
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from config.constant import *
|
| 10 |
+
from tensorflow.keras.layers import Dense, Activation, BatchNormalization
|
| 11 |
+
import tensorflow as tf
|
| 12 |
+
|
| 13 |
+
import raven_utils as rv
|
| 14 |
+
|
| 15 |
+
from models.body import create_block
|
| 16 |
+
from models.class_ import Merge, RavenClass
|
| 17 |
+
from models.head import LatentHeadModel
|
| 18 |
+
|
| 19 |
+
from models.loss import RavenLoss
|
| 20 |
+
from models.trans import TransModel, FullTrans
|
| 21 |
+
from raven_utils.const import HORIZONTAL
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def raven_model(scales,
|
| 25 |
+
out_layers,
|
| 26 |
+
latent=(64, 128, 256),
|
| 27 |
+
output_size=None,
|
| 28 |
+
padding=SAME,
|
| 29 |
+
body_layers=1,
|
| 30 |
+
encoder=None,
|
| 31 |
+
loop=1,
|
| 32 |
+
model=None,
|
| 33 |
+
act=None,
|
| 34 |
+
simpler=0,
|
| 35 |
+
loss_mode=None,
|
| 36 |
+
loss_weight=0.3,
|
| 37 |
+
dir_=HORIZONTAL,
|
| 38 |
+
global_context=False,
|
| 39 |
+
images_no=8,
|
| 40 |
+
context_mul=2,
|
| 41 |
+
res_act="pass",
|
| 42 |
+
drop_latent=0,
|
| 43 |
+
drop_inference=0,
|
| 44 |
+
drop_end=0,
|
| 45 |
+
ga=False,
|
| 46 |
+
trans_norm=None,
|
| 47 |
+
trans_act="relu",
|
| 48 |
+
arch=HEAD3,
|
| 49 |
+
encoder_norm=False,
|
| 50 |
+
encoder_pool=False,
|
| 51 |
+
encoder_global="GM",
|
| 52 |
+
encoder_before=False,
|
| 53 |
+
tail_units=256,
|
| 54 |
+
tail_flatten=None,
|
| 55 |
+
# for now by default
|
| 56 |
+
tail_down="MP",
|
| 57 |
+
trans_no=1,
|
| 58 |
+
trans_score_activation=tf.nn.softmax,
|
| 59 |
+
block_=SoftBlock,
|
| 60 |
+
**kwargs):
|
| 61 |
+
if isinstance(latent, int):
|
| 62 |
+
latent = (latent, 128, 256)
|
| 63 |
+
scales = lw(scales)
|
| 64 |
+
|
| 65 |
+
context_size = np.array(latent) * context_mul
|
| 66 |
+
# context_size = latent[scales] * context_mul
|
| 67 |
+
|
| 68 |
+
# if scales == 2:
|
| 69 |
+
# arch = HEAD
|
| 70 |
+
# elif scales == 1:
|
| 71 |
+
# arch = HEAD2
|
| 72 |
+
# else:
|
| 73 |
+
# arch = VERY2
|
| 74 |
+
|
| 75 |
+
if encoder_pool:
|
| 76 |
+
strides = (1, 1)
|
| 77 |
+
else:
|
| 78 |
+
strides = (2, 2)
|
| 79 |
+
if not isinstance(encoder_before, tuple):
|
| 80 |
+
encoder_before = [encoder_before] * 3
|
| 81 |
+
|
| 82 |
+
# if trans == 1:
|
| 83 |
+
# trans_model = TransModel2
|
| 84 |
+
# else:
|
| 85 |
+
# trans_model = TransModel
|
| 86 |
+
|
| 87 |
+
# if scales == 3:
|
| 88 |
+
# head = MultiHeadModel(encoder=encoder)
|
| 89 |
+
arch = MODEL_ARCH[arch]
|
| 90 |
+
heads = []
|
| 91 |
+
for s in list(range(0, max(scales) + 1)):
|
| 92 |
+
if s in (0, 1):
|
| 93 |
+
if s == 0:
|
| 94 |
+
encoder = build_encoder(arch[:3], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
| 95 |
+
strides=strides)
|
| 96 |
+
else:
|
| 97 |
+
encoder = build_encoder(arch[3:4], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
| 98 |
+
strides=strides)
|
| 99 |
+
head = LatentHeadModel(
|
| 100 |
+
encoder=encoder,
|
| 101 |
+
inference_network=(
|
| 102 |
+
bm([
|
| 103 |
+
CatBefore(filters=int(context_size[s] / 8)) if encoder_before[s] else Cat(
|
| 104 |
+
filters=context_size[s]),
|
| 105 |
+
# todo activation?
|
| 106 |
+
Res(filters=context_size[s], padding=padding)
|
| 107 |
+
] + ([drop(drop_inference)] if drop_inference else []),
|
| 108 |
+
name="inference")
|
| 109 |
+
) if s in scales else Pass(),
|
| 110 |
+
stem=Base(
|
| 111 |
+
bm(
|
| 112 |
+
# ok we choose by parameters anyway
|
| 113 |
+
[res(filters=latent[s], padding=padding, act=act)] + (
|
| 114 |
+
[drop(drop_latent)] if drop_latent else [])
|
| 115 |
+
),
|
| 116 |
+
name="stem")
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
encoder = bm([
|
| 120 |
+
Res(),
|
| 121 |
+
build_encoder(arch[4:], add_norm=encoder_norm, add_pool=encoder_pool, kerner_size=(4, 4),
|
| 122 |
+
strides=strides),
|
| 123 |
+
short(encoder_global) if encoder_global else Flatten(),
|
| 124 |
+
dense(latent[s])
|
| 125 |
+
])
|
| 126 |
+
head = LatentHeadModel(
|
| 127 |
+
encoder=encoder,
|
| 128 |
+
inference_network=bm([
|
| 129 |
+
# todo Echeck Cat
|
| 130 |
+
CatDenseBefore(filters=int(context_size[s] / 8)) if encoder_before[
|
| 131 |
+
s] else CatDense(filters=context_size[s]),
|
| 132 |
+
# todo activation?
|
| 133 |
+
Res(model="dv2", filters=context_size[s], padding=padding)
|
| 134 |
+
] + ([dense_drop(drop_inference)] if drop_inference else []),
|
| 135 |
+
name="inference"),
|
| 136 |
+
stem=Base(
|
| 137 |
+
bm(
|
| 138 |
+
# ok we choose by parameters anyway
|
| 139 |
+
[res(model="dv2", units=latent[s], padding=padding, act=act)] + (
|
| 140 |
+
[dense_drop(drop_latent)] if drop_latent else [])
|
| 141 |
+
),
|
| 142 |
+
name="stem")
|
| 143 |
+
)
|
| 144 |
+
heads.append(head)
|
| 145 |
+
|
| 146 |
+
concat_input = [f"{LATENT}_{i}" for i, _ in enumerate(heads)] + [f"{INFERENCE}_{i}" for i, _ in enumerate(heads)]
|
| 147 |
+
concat_output = ["LATENTS", "INFERENCES"]
|
| 148 |
+
|
| 149 |
+
def head_concat(inputs):
|
| 150 |
+
latents = inputs[:len(heads)]
|
| 151 |
+
inferences = inputs[len(heads):]
|
| 152 |
+
return latents, inferences
|
| 153 |
+
|
| 154 |
+
head = ListModel([(h, (INPUTS if i == 0 else OUTPUT), [f"{LATENT}_{i}", f"{INFERENCE}_{i}", OUTPUT]) for i, h in
|
| 155 |
+
enumerate(heads)] + [
|
| 156 |
+
(head_concat, concat_input, concat_output)], out=concat_output)
|
| 157 |
+
# from rav_utils.raven import init_image
|
| 158 |
+
# a = init_image()
|
| 159 |
+
# head(a)
|
| 160 |
+
|
| 161 |
+
if model is None:
|
| 162 |
+
model = []
|
| 163 |
+
for i in scales:
|
| 164 |
+
trans_models = []
|
| 165 |
+
for t in range(trans_no):
|
| 166 |
+
trans_models.append(
|
| 167 |
+
bm(
|
| 168 |
+
[create_block(latent=latent[i], simpler=simpler, padding=padding, norm=trans_norm, act=res_act,
|
| 169 |
+
loop=loop, type_="dense" if i == 2 else "conv", block_=block_)] +
|
| 170 |
+
[Activation(trans_act)] + [
|
| 171 |
+
res(filters=latent[i],
|
| 172 |
+
padding=padding,
|
| 173 |
+
act=act,
|
| 174 |
+
name="body_out",
|
| 175 |
+
model="dv2" if i == 2 else "v2") for _ in
|
| 176 |
+
range(body_layers)] + ([Drop(drop_latent)] if drop_latent else []),
|
| 177 |
+
base_class=SubClassingModel)
|
| 178 |
+
)
|
| 179 |
+
trans_models = lu(trans_models)
|
| 180 |
+
if trans_no > 1:
|
| 181 |
+
trans_models = bm([
|
| 182 |
+
lambda x: [[x[0], x[1]], x[1]],
|
| 183 |
+
SoftBlock(
|
| 184 |
+
model=trans_models,
|
| 185 |
+
score_model=bm([
|
| 186 |
+
Flat2(filters=latent[i], units=256, res_no=2),
|
| 187 |
+
Dense(trans_no, trans_score_activation)
|
| 188 |
+
])
|
| 189 |
+
)
|
| 190 |
+
],
|
| 191 |
+
base_class=SubClassingModel
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
model.append(
|
| 195 |
+
TransModel(
|
| 196 |
+
body=trans_models,
|
| 197 |
+
dir_=dir_,
|
| 198 |
+
images_no=images_no
|
| 199 |
+
)
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
tail = []
|
| 203 |
+
for i, s in enumerate(scales):
|
| 204 |
+
flatting = lambda: Flat2(filters=latent[s + 1], base_class=tail_flatten, units=tail_units)
|
| 205 |
+
if s == 0:
|
| 206 |
+
if tail_flatten is None:
|
| 207 |
+
branch = bm([res(filters=latent[s], padding=padding),
|
| 208 |
+
conv(filters=latent[s], padding=padding),
|
| 209 |
+
BatchNormalization(),
|
| 210 |
+
conv(filters=latent[s], padding=padding),
|
| 211 |
+
Flatten()])
|
| 212 |
+
else:
|
| 213 |
+
branch = bm([down(base_class=tail_down), flatting()])
|
| 214 |
+
elif s == 1:
|
| 215 |
+
if tail_flatten is None:
|
| 216 |
+
branch = bm([res(filters=latent[s], padding=padding),
|
| 217 |
+
Flatten()])
|
| 218 |
+
else:
|
| 219 |
+
branch = flatting()
|
| 220 |
+
else:
|
| 221 |
+
branch = bm([tail_units] * 2, add_flatten=False)
|
| 222 |
+
tail.append(branch)
|
| 223 |
+
|
| 224 |
+
tail.append(
|
| 225 |
+
bm([dense(tail_units)] + ([dense_drop(drop_end)] if drop_end else []) + [Dense(output_size)], add_flatten=False,
|
| 226 |
+
name=TAIL))
|
| 227 |
+
class_input = []
|
| 228 |
+
|
| 229 |
+
return bt([
|
| 230 |
+
DictModel(head, in_=INPUTS, out=[LATENT, INFERENCE], name="Head"),
|
| 231 |
+
DictModel(FullTrans(model, scales=scales), in_=[LATENT, INFERENCE], out=TRANS, name="Body"),
|
| 232 |
+
DictModel(RavenClass(Merge(tail), scales=scales, no=8), in_=[LATENT] + class_input, out=CLASSIFICATION,
|
| 233 |
+
name="Classificator"),
|
| 234 |
+
DictModel(RavenClass(Merge(tail), scales=list(range(len(scales))), no=3), in_=[TRANS] + class_input,
|
| 235 |
+
out=OUTPUT, name="Classificator_trans"),
|
| 236 |
+
],
|
| 237 |
+
loss=RavenLoss(mode=loss_mode, classification=True, trans=True, lw=(1.0, loss_weight)),
|
| 238 |
+
loss_wrap=False
|
| 239 |
+
)
|
raven_utils/models/trans.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from ml_utils import lw
|
| 3 |
+
from models_utils import ops as K, SubClassingModel
|
| 4 |
+
from tensorflow.keras import Model
|
| 5 |
+
|
| 6 |
+
from models.body import create_dense_block
|
| 7 |
+
import raven_utils as rv
|
| 8 |
+
from raven_utils.const import HORIZONTAL, VERTICAL
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TransModel(Model):
|
| 12 |
+
def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.model = body or create_dense_block(latent=latent)
|
| 15 |
+
if dir_ == VERTICAL:
|
| 16 |
+
self.dir = (0, 3, 1, 4, 3, 5)
|
| 17 |
+
else:
|
| 18 |
+
self.dir = (0, 1, 3, 4, 6, 7)
|
| 19 |
+
self.images_no = images_no
|
| 20 |
+
self.latent = latent
|
| 21 |
+
|
| 22 |
+
def call(self, inputs):
|
| 23 |
+
# latents = tnp.asarray(inputs[0])
|
| 24 |
+
latents = inputs[0]
|
| 25 |
+
inference = inputs[1]
|
| 26 |
+
shape = tf.shape(latents)
|
| 27 |
+
new_shape = K.cat([[-1, 3, 2], shape[2:]])
|
| 28 |
+
horizontal = latents[:, self.dir].reshape(new_shape)
|
| 29 |
+
res = tf.TensorArray(tf.float32, size=3)
|
| 30 |
+
for i in range(3):
|
| 31 |
+
res = res.write(i, self.model([horizontal[:, i], inference]))
|
| 32 |
+
result = K.tran(res.stack())
|
| 33 |
+
return result
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TransModel2(Model):
|
| 37 |
+
def __init__(self, body=None, dir_=HORIZONTAL, images_no=8, latent=64):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.body = body or create_dense_block(latent=latent)
|
| 40 |
+
if dir_ == VERTICAL:
|
| 41 |
+
self.dir = (0, 3, 1, 4, 3, 5)
|
| 42 |
+
else:
|
| 43 |
+
self.dir = (0, 1, 3, 4, 6, 7)
|
| 44 |
+
self.images_no = images_no
|
| 45 |
+
self.latent = latent
|
| 46 |
+
|
| 47 |
+
def call(self, inputs):
|
| 48 |
+
# latents = tnp.asarray(inputs[0])
|
| 49 |
+
latents = inputs[0]
|
| 50 |
+
inference = inputs[1]
|
| 51 |
+
shape = tf.shape(latents)
|
| 52 |
+
new_shape = K.cat([[-1, 3, 2], shape[2:]])
|
| 53 |
+
horizontal = latents[:, self.dir].reshape(new_shape)
|
| 54 |
+
res = tf.TensorArray(tf.float32, size=3)
|
| 55 |
+
for i in tf.range(3):
|
| 56 |
+
res = res.write(i, self.body([horizontal[:, i], inference[:,i]]))
|
| 57 |
+
result = K.tran(res.stack())
|
| 58 |
+
return result
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class FullTrans(SubClassingModel):
|
| 62 |
+
def __init__(self, model,scales,name=None):
|
| 63 |
+
super().__init__(model=model,name=name)
|
| 64 |
+
self.scales = scales
|
| 65 |
+
|
| 66 |
+
def call(self, inputs):
|
| 67 |
+
latent = lw(inputs[0])
|
| 68 |
+
inference = lw(inputs[1])
|
| 69 |
+
results = []
|
| 70 |
+
# todo merging inference?
|
| 71 |
+
for i,s in enumerate(self.scales):
|
| 72 |
+
# results.append(model([latent[::-1][i], inference]))
|
| 73 |
+
results.append(self.model[i]([latent[s], inference[s]]))
|
| 74 |
+
return results,
|
raven_utils/models/transformer.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras.layers import Lambda
|
| 3 |
+
from tensorflow.python.keras import Sequential
|
| 4 |
+
|
| 5 |
+
# from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
| 6 |
+
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight
|
| 7 |
+
import models_utils.ops as K
|
| 8 |
+
from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
|
| 9 |
+
from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
|
| 10 |
+
from models_utils.ops_core import IndexReshape
|
| 11 |
+
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
| 12 |
+
from models_utils.step import StepDict
|
| 13 |
+
|
| 14 |
+
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
| 15 |
+
# (1, 1, 1, self.last_index))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# from data_utils import ims
|
| 19 |
+
# for i in range(50):
|
| 20 |
+
# ims(res[i].numpy().swapaxes(0, 2))
|
| 21 |
+
# res[12].numpy()
|
| 22 |
+
# self.get_last(inputs).numpy()
|
| 23 |
+
# import tensorflow as tf
|
| 24 |
+
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
| 25 |
+
# from ml_utils import print_error
|
| 26 |
+
# ims(mask[0].numpy())
|
| 27 |
+
# print_error(lambda :ims(mask[0]))
|
| 28 |
+
# from models_utils import ops as K
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# self.model.layers[0](x)
|
| 32 |
+
from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_rav_trans(
|
| 36 |
+
data,
|
| 37 |
+
type_=9,
|
| 38 |
+
no=4,
|
| 39 |
+
extractor="ef",
|
| 40 |
+
loss_mode=create_uniform_mask,
|
| 41 |
+
output_size=10,
|
| 42 |
+
loss_weight=1.0,
|
| 43 |
+
out_layers=(1000, 1000, 1000),
|
| 44 |
+
pos_emd="cat",
|
| 45 |
+
base="seq",
|
| 46 |
+
inverse_image=True,
|
| 47 |
+
last="left",
|
| 48 |
+
epsilon="greedy",
|
| 49 |
+
epsilon_step=500,
|
| 50 |
+
mask_fn=None,
|
| 51 |
+
model=None,
|
| 52 |
+
loss="multi",
|
| 53 |
+
**kwargs):
|
| 54 |
+
if last == "left":
|
| 55 |
+
last = take_left
|
| 56 |
+
elif last == "mix":
|
| 57 |
+
last = mix
|
| 58 |
+
elif last == "empty":
|
| 59 |
+
last = empty_last
|
| 60 |
+
elif last == "start":
|
| 61 |
+
last = Sequential([Lambda(empty_last), BatchInitialWeight(initializer=init_weights)])
|
| 62 |
+
|
| 63 |
+
if epsilon == "greedy":
|
| 64 |
+
epsilon = EpsilonGreedy(step=epsilon_step)
|
| 65 |
+
elif epsilon == "soft":
|
| 66 |
+
epsilon = EpsilonSoft(step=epsilon_step)
|
| 67 |
+
elif epsilon is False:
|
| 68 |
+
epsilon = None
|
| 69 |
+
|
| 70 |
+
if epsilon:
|
| 71 |
+
trans_raven = TransRavenwithStep(
|
| 72 |
+
type_=type_,
|
| 73 |
+
no=no,
|
| 74 |
+
extractor=extractor,
|
| 75 |
+
output_size=output_size,
|
| 76 |
+
out_layer=out_layers,
|
| 77 |
+
pos_emd=pos_emd,
|
| 78 |
+
base=base,
|
| 79 |
+
last=last,
|
| 80 |
+
epsilon=epsilon,
|
| 81 |
+
**kwargs
|
| 82 |
+
)
|
| 83 |
+
return StepDict(bt(
|
| 84 |
+
DictModel(
|
| 85 |
+
Sequential([Lambda(lambda x: (255 - x[0], x[1])), trans_raven]) if inverse_image else trans_raven,
|
| 86 |
+
in_=[INPUTS, "step"],
|
| 87 |
+
name="Body"
|
| 88 |
+
),
|
| 89 |
+
loss=VTRavenLoss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
| 90 |
+
loss_wrap=False),
|
| 91 |
+
add_step=epsilon_step,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
trans_raven = img_sec_trans(
|
| 95 |
+
type_=type_,
|
| 96 |
+
no=no,
|
| 97 |
+
extractor=extractor,
|
| 98 |
+
model=model,
|
| 99 |
+
output_size=output_size,
|
| 100 |
+
out_layer=out_layers,
|
| 101 |
+
pos_emd=pos_emd,
|
| 102 |
+
base=base,
|
| 103 |
+
last=last,
|
| 104 |
+
epsilon=epsilon,
|
| 105 |
+
mask_fn=mask_fn,
|
| 106 |
+
**kwargs
|
| 107 |
+
)
|
| 108 |
+
if loss == "single":
|
| 109 |
+
loss = SingleVTRavenLoss
|
| 110 |
+
else:
|
| 111 |
+
loss = VTRavenLoss
|
| 112 |
+
|
| 113 |
+
# return bt(
|
| 114 |
+
# DictModel(
|
| 115 |
+
# Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
| 116 |
+
# inputs=INPUTS,
|
| 117 |
+
# name="Body"
|
| 118 |
+
# ),
|
| 119 |
+
# loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
| 120 |
+
# loss_wrap=False
|
| 121 |
+
# )
|
| 122 |
+
|
| 123 |
+
return bt([
|
| 124 |
+
DictModel(
|
| 125 |
+
Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
| 126 |
+
in_=INPUTS,
|
| 127 |
+
name="Body"
|
| 128 |
+
),
|
| 129 |
+
|
| 130 |
+
],
|
| 131 |
+
loss=loss(mode=loss_mode, classification=True, lw=(loss_weight, 1.0)),
|
| 132 |
+
loss_wrap=False
|
| 133 |
+
)
|
raven_utils/models/transformer_2.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from tensorflow.keras.layers import Lambda
|
| 5 |
+
from tensorflow.python.keras import Sequential
|
| 6 |
+
from models_utils import ops as K, SubClassing
|
| 7 |
+
from models_utils.models.transformer import aug
|
| 8 |
+
|
| 9 |
+
# from models_utils.models.loss import VTRavenLoss, create_uniform_mask, SingleVTRavenLoss
|
| 10 |
+
from data_utils import DataGenerator, LOSS, TARGET, IMAGES
|
| 11 |
+
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional_model, get_input_layer
|
| 12 |
+
import models_utils.ops as K
|
| 13 |
+
from models_utils.models.transformer.img_seq import init_weights, take_left, mix, empty_last
|
| 14 |
+
from models_utils.models.transformer.img_seq2 import init_weights, take_left, mix, empty_last, img_sec_trans
|
| 15 |
+
from models_utils.ops_core import IndexReshape
|
| 16 |
+
from models_utils.random_ import EpsilonGreedy, EpsilonSoft
|
| 17 |
+
from models_utils.step import StepDict
|
| 18 |
+
|
| 19 |
+
from models_utils.models.transformer import aug
|
| 20 |
+
|
| 21 |
+
# res = (1 - mask) * inputs[..., :self.last_index] + mask * tf.tile(self.get_last(inputs),
|
| 22 |
+
# (1, 1, 1, self.last_index))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# from data_utils import ims
|
| 26 |
+
# for i in range(50):
|
| 27 |
+
# ims(res[i].numpy().swapaxes(0, 2))
|
| 28 |
+
# res[12].numpy()
|
| 29 |
+
# self.get_last(inputs).numpy()
|
| 30 |
+
# import tensorflow as tf
|
| 31 |
+
# tf.random.uniform(shape=shape[0:1], maxval=255, dtype=tf.int32)
|
| 32 |
+
# from ml_utils import print_error
|
| 33 |
+
# ims(mask[0].numpy())
|
| 34 |
+
# print_error(lambda :ims(mask[0]))
|
| 35 |
+
# from models_utils import ops as K
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# self.model.layers[0](x)
|
| 39 |
+
from raven_utils.constant import INDEX, LABELS
|
| 40 |
+
from raven_utils.models.loss import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_matrix(inputs, index):
|
| 44 |
+
return tf.concat([inputs[:, :8], K.gather(inputs, index[:, 0])[:, None]], axis=1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_images(inputs):
|
| 48 |
+
return get_matrix(inputs[0], inputs[1])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def random_last(inputs, max_=8):
|
| 52 |
+
index = K.init.label(max=max_, shape=[tf.shape(inputs[0])[0]])[..., None]
|
| 53 |
+
return get_matrix(inputs[0], index)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_images_no_answer(inputs):
|
| 57 |
+
return inputs[0][:, :9]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def repeat_last(inputs):
|
| 61 |
+
return inputs[0][:, list(range(8)) + [7]]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_rav_trans(
|
| 65 |
+
data,
|
| 66 |
+
inverse_image=True,
|
| 67 |
+
loss_mode=create_uniform_mask,
|
| 68 |
+
loss_weight=1.0,
|
| 69 |
+
loss="multi",
|
| 70 |
+
number_loss=False,
|
| 71 |
+
plw=None,
|
| 72 |
+
pre="auto",
|
| 73 |
+
augmentation=None,
|
| 74 |
+
**kwargs):
|
| 75 |
+
if isinstance(data, DataGenerator):
|
| 76 |
+
data = data[0]['inputs'], data[0]['index']
|
| 77 |
+
# u = img_sec_trans(**kwargs)(get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data))
|
| 78 |
+
# u.shape
|
| 79 |
+
from keras import Model
|
| 80 |
+
if pre == "auto":
|
| 81 |
+
pre = get_images if kwargs['mask'] == "random" else get_images_no_answer
|
| 82 |
+
elif pre == "no_answer":
|
| 83 |
+
pre = get_images_no_answer
|
| 84 |
+
elif pre == "last":
|
| 85 |
+
pre = repeat_last
|
| 86 |
+
elif pre == "images":
|
| 87 |
+
pre = get_images
|
| 88 |
+
elif pre == "random_last":
|
| 89 |
+
pre = random_last
|
| 90 |
+
elif pre == "noise":
|
| 91 |
+
pre = SubClassing([get_matrix, partial(aug.noise, max_=8)])
|
| 92 |
+
elif pre == "batch_noise":
|
| 93 |
+
pre = SubClassing([get_matrix, partial(aug.batch_noise, max_=8)])
|
| 94 |
+
|
| 95 |
+
if augmentation == "transpose":
|
| 96 |
+
augmentation = aug.Transpose(axis=(0, 2, 1))
|
| 97 |
+
augmentation_label = aug.Transpose(axis=(0, 2, 1))
|
| 98 |
+
elif augmentation == "shuffle_col":
|
| 99 |
+
augmentation = aug.shuffle_col
|
| 100 |
+
augmentation_label = aug.shuffle_col
|
| 101 |
+
elif augmentation == "shuffle":
|
| 102 |
+
augmentation = aug.shuffle
|
| 103 |
+
augmentation_label = aug.shuffle
|
| 104 |
+
if augmentation:
|
| 105 |
+
augmentation = [
|
| 106 |
+
# DictModel(augmentation, IMAGES, IMAGES),
|
| 107 |
+
# DictModel(aug.reshape_static(pre(data),augmentation), IMAGES, IMAGES),
|
| 108 |
+
DictModel(aug.ReshapeStatic(augmentation), IMAGES, IMAGES),
|
| 109 |
+
DictModel(
|
| 110 |
+
aug.PartialModel(
|
| 111 |
+
aug.ReshapeStatic(augmentation_label),
|
| 112 |
+
last_axis=9)
|
| 113 |
+
, LABELS, LABELS)
|
| 114 |
+
]
|
| 115 |
+
else:
|
| 116 |
+
augmentation = []
|
| 117 |
+
|
| 118 |
+
trans_raven = build_functional_model(
|
| 119 |
+
img_sec_trans(**kwargs),
|
| 120 |
+
# get_images(data) if kwargs['mask'] == "random" else get_images_no_answer(data)
|
| 121 |
+
pre(data)
|
| 122 |
+
# data[0]
|
| 123 |
+
)
|
| 124 |
+
if loss == "single":
|
| 125 |
+
loss = SingleVTRavenLoss
|
| 126 |
+
else:
|
| 127 |
+
loss = VTRavenLoss
|
| 128 |
+
if isinstance(loss_weight, float):
|
| 129 |
+
loss_weight = (loss_weight, 1.0)
|
| 130 |
+
|
| 131 |
+
return bt([
|
| 132 |
+
# DictModel(get_images if kwargs['mask'] == "random" else get_images_no_answer, [INPUTS, INDEX], IMAGES),
|
| 133 |
+
DictModel(pre, [INPUTS, INDEX], IMAGES),
|
| 134 |
+
*augmentation,
|
| 135 |
+
DictModel(
|
| 136 |
+
Sequential([Lambda(lambda x: 255 - x), trans_raven]) if inverse_image else trans_raven,
|
| 137 |
+
in_=IMAGES,
|
| 138 |
+
# inputs=INPUTS,
|
| 139 |
+
name="Body"
|
| 140 |
+
),
|
| 141 |
+
|
| 142 |
+
],
|
| 143 |
+
loss=loss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
|
| 144 |
+
predict=LOSS,
|
| 145 |
+
loss_wrap=False
|
| 146 |
+
)
|
raven_utils/models/transformer_3.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from loguru import logger
|
| 4 |
+
from tensorflow.keras.layers import Lambda
|
| 5 |
+
from tensorflow.keras.layers import Activation
|
| 6 |
+
|
| 7 |
+
from grid_transformer import aug_trans
|
| 8 |
+
from raven_utils.models.loss_3 import VTRavenLoss, SingleVTRavenLoss, create_uniform_mask
|
| 9 |
+
from data_utils import get_shape, TakeDict
|
| 10 |
+
|
| 11 |
+
from data_utils import DataGenerator, LOSS, TARGET, IMAGES
|
| 12 |
+
from models_utils import DictModel, bt, INPUTS, BatchInitialWeight, build_functional, get_input_layer, Last, bm, \
|
| 13 |
+
add_end, AUGMENTATION
|
| 14 |
+
# from report.select_ import SelectModel2, SelectModel, SelectModel9
|
| 15 |
+
from experiment_utils.keras_model import load_weights as model_load_weights
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_rav_trans(
|
| 19 |
+
data,
|
| 20 |
+
loss_mode=create_uniform_mask,
|
| 21 |
+
loss_weight=2.0,
|
| 22 |
+
number_loss=False,
|
| 23 |
+
dry_run="auto",
|
| 24 |
+
plw=None,
|
| 25 |
+
**kwargs):
|
| 26 |
+
if isinstance(loss_weight, float):
|
| 27 |
+
loss_weight = (loss_weight, 1.0)
|
| 28 |
+
|
| 29 |
+
# seq_trans(**kwargs)(data[0])
|
| 30 |
+
# trans_raven = build_functional_model2(
|
| 31 |
+
# seq_trans(**kwargs),
|
| 32 |
+
# data[0],
|
| 33 |
+
# batch=None
|
| 34 |
+
# )
|
| 35 |
+
trans_raven = build_functional(
|
| 36 |
+
model=aug_trans,
|
| 37 |
+
inputs_=data[0] if isinstance(data, DataGenerator) else data,
|
| 38 |
+
batch_=None,
|
| 39 |
+
dry_run=dry_run,
|
| 40 |
+
**kwargs
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
return bt(
|
| 44 |
+
model=trans_raven,
|
| 45 |
+
loss=VTRavenLoss(mode=loss_mode, classification=True, number_loss=number_loss, lw=loss_weight, plw=plw),
|
| 46 |
+
model_wrap=False,
|
| 47 |
+
predict=LOSS,
|
| 48 |
+
loss_wrap=False
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def rav_select_model(
|
| 53 |
+
data,
|
| 54 |
+
load_weights=None,
|
| 55 |
+
loss_weight=(0.01, 0.0),
|
| 56 |
+
plw=5.0,
|
| 57 |
+
result_metric="sparse_categorical_accuracy",
|
| 58 |
+
select_type=2,
|
| 59 |
+
select_out=0,
|
| 60 |
+
additional_out=0,
|
| 61 |
+
additional_copy=True,
|
| 62 |
+
tail_out=(1000, 1000),
|
| 63 |
+
**kwargs
|
| 64 |
+
):
|
| 65 |
+
out_layers = Last()
|
| 66 |
+
if additional_out > 0:
|
| 67 |
+
model3 = get_rav_trans(
|
| 68 |
+
data,
|
| 69 |
+
plw=plw,
|
| 70 |
+
loss_weight=loss_weight,
|
| 71 |
+
**kwargs
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
model_load_weights(
|
| 75 |
+
model3,
|
| 76 |
+
load_weights,
|
| 77 |
+
# sample_data,
|
| 78 |
+
None,
|
| 79 |
+
template="weights_{epoch:02d}-{val_loss:.2f}",
|
| 80 |
+
key=result_metric,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
|
| 84 |
+
index = -1
|
| 85 |
+
else:
|
| 86 |
+
index = -2
|
| 87 |
+
|
| 88 |
+
out = model3[0, index, :additional_out]
|
| 89 |
+
logger.info(f"Additional out from: {model3[0, index]}.")
|
| 90 |
+
|
| 91 |
+
if additional_out > 2:
|
| 92 |
+
out += [Activation("gelu")]
|
| 93 |
+
out_layers = bm([out_layers] + out, add_flatten=False)
|
| 94 |
+
model = get_rav_trans(
|
| 95 |
+
TakeDict(data[0])[:, 8:],
|
| 96 |
+
plw=plw,
|
| 97 |
+
loss_weight=loss_weight,
|
| 98 |
+
**{
|
| 99 |
+
**kwargs,
|
| 100 |
+
"out_layers": out_layers,
|
| 101 |
+
}
|
| 102 |
+
# **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
|
| 103 |
+
)
|
| 104 |
+
# from data_utils.ops import Equal
|
| 105 |
+
# o = []
|
| 106 |
+
# for i in range(1, 3):
|
| 107 |
+
# for j in range(2):
|
| 108 |
+
# o.append(
|
| 109 |
+
# Equal(
|
| 110 |
+
# # model[0,:,-2, i].variables[j],
|
| 111 |
+
# model2[0, :, -2, i].variables[j],
|
| 112 |
+
# # out_layers[i].variables[j]
|
| 113 |
+
# second_pooling[i].variables[j]
|
| 114 |
+
# ).equal
|
| 115 |
+
# )
|
| 116 |
+
# assert all(o)
|
| 117 |
+
# model = get_rav_trans(
|
| 118 |
+
# # TakeDict(val_generator[0])[:, 8:],
|
| 119 |
+
# # TakeDict(val_generator[0])[:, 8:],
|
| 120 |
+
# val_generator[0],
|
| 121 |
+
# plw=p.plw,
|
| 122 |
+
# loss_weight=p.loss_weight,
|
| 123 |
+
# **{**as_dict(p.mp),
|
| 124 |
+
# # "out_layers": out_layers,
|
| 125 |
+
# }
|
| 126 |
+
# # **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
|
| 127 |
+
# )
|
| 128 |
+
model_load_weights(model,
|
| 129 |
+
load_weights,
|
| 130 |
+
# sample_data,
|
| 131 |
+
None,
|
| 132 |
+
template="weights_{epoch:02d}-{val_loss:.2f}",
|
| 133 |
+
key=result_metric,
|
| 134 |
+
)
|
| 135 |
+
# model.compile()
|
| 136 |
+
# model.evaluate(val_generator.data[:1000])
|
| 137 |
+
# model(TakeDict(val_generator[0])[:, 8:])
|
| 138 |
+
trans_raven = model[0]
|
| 139 |
+
# s = trans_raven(TakeDict(val_generator[0])[:, 8:])
|
| 140 |
+
if select_type == 2:
|
| 141 |
+
second_pooling = Lambda(lambda x: x[:, :-1])
|
| 142 |
+
else:
|
| 143 |
+
second_pooling = Last()
|
| 144 |
+
if additional_out > 0:
|
| 145 |
+
if additional_copy:
|
| 146 |
+
model4 = get_rav_trans(
|
| 147 |
+
data,
|
| 148 |
+
plw=plw,
|
| 149 |
+
loss_weight=loss_weight,
|
| 150 |
+
**kwargs
|
| 151 |
+
)
|
| 152 |
+
model_load_weights(model4,
|
| 153 |
+
load_weights,
|
| 154 |
+
# sample_data,
|
| 155 |
+
None,
|
| 156 |
+
template="weights_{epoch:02d}-{val_loss:.2f}",
|
| 157 |
+
key=result_metric,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
if AUGMENTATION in kwargs and kwargs[AUGMENTATION] is not None:
|
| 161 |
+
index = -1
|
| 162 |
+
else:
|
| 163 |
+
index = -2
|
| 164 |
+
out2 = model4[0, index, :additional_out]
|
| 165 |
+
logger.info(f"Additional out from: {model4[0, index]}.")
|
| 166 |
+
|
| 167 |
+
if additional_out > 2:
|
| 168 |
+
out2 += [Activation("gelu")]
|
| 169 |
+
else:
|
| 170 |
+
out2 = out
|
| 171 |
+
|
| 172 |
+
second_pooling = bm([second_pooling] + out2, add_flatten=False)
|
| 173 |
+
|
| 174 |
+
model2 = get_rav_trans(
|
| 175 |
+
TakeDict(data[0])[:, 8:],
|
| 176 |
+
plw=plw,
|
| 177 |
+
loss_weight=loss_weight,
|
| 178 |
+
**{
|
| 179 |
+
**kwargs,
|
| 180 |
+
"out_layers": second_pooling,
|
| 181 |
+
}
|
| 182 |
+
# **{**as_dict(p.mp), "show_shape": True, "save_shape": f"output/shapes/type_{p.mp.type_}.json"},
|
| 183 |
+
)
|
| 184 |
+
model_load_weights(
|
| 185 |
+
model2,
|
| 186 |
+
load_weights,
|
| 187 |
+
# sample_data,
|
| 188 |
+
None,
|
| 189 |
+
template="weights_{epoch:02d}-{val_loss:.2f}",
|
| 190 |
+
key=result_metric,
|
| 191 |
+
)
|
| 192 |
+
if select_type == 0:
|
| 193 |
+
# not working
|
| 194 |
+
trans_raven2 = model2[0]
|
| 195 |
+
else:
|
| 196 |
+
trans_raven2 = model2[0]
|
| 197 |
+
tail = add_end(out_layers=tail_out, output_size=8 if select_out else 1)
|
| 198 |
+
# trans_raven2.mask_fn = ImageMask(last=take_by_index)
|
| 199 |
+
if select_type == 2:
|
| 200 |
+
select_model_class = SelectModel2
|
| 201 |
+
elif select_type == 1:
|
| 202 |
+
select_model_class = SelectModel
|
| 203 |
+
else:
|
| 204 |
+
select_model_class = SelectModel9
|
| 205 |
+
select_model = select_model_class(trans_raven, model2=trans_raven2, tail=tail, select_out=select_out)
|
| 206 |
+
return select_model
|
raven_utils/models/uitls_.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import tensorflow.experimental.numpy as tnp
|
| 3 |
+
from tensorflow.keras import Model
|
| 4 |
+
import raven_utils as rv
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RangeMask(Model):
|
| 8 |
+
def __init__(self):
|
| 9 |
+
super().__init__()
|
| 10 |
+
ranges = tf.tile(tf.range(rv.entity.INDEX[-1])[None], [rv.group.NO, 1])
|
| 11 |
+
start_index = rv.entity.INDEX[:-1][:, None]
|
| 12 |
+
end_index = rv.entity.INDEX[1:][:, None]
|
| 13 |
+
self.mask = tnp.array((start_index <= ranges) & (ranges < end_index))
|
| 14 |
+
|
| 15 |
+
def call(self, inputs):
|
| 16 |
+
return self.mask[inputs]
|
raven_utils/output.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import raven_utils.entity as entity
|
| 3 |
+
import raven_utils.properties as properties
|
| 4 |
+
import raven_utils.group as group
|
| 5 |
+
|
| 6 |
+
SIZE = entity.SUM * properties.SUM + group.NO + entity.SUM
|
| 7 |
+
|
| 8 |
+
SLOT_AND_GROUP = group.NO + entity.SUM
|
| 9 |
+
|
| 10 |
+
PROPERTIES_SLICE = np.s_[:, :-SLOT_AND_GROUP]
|
| 11 |
+
SLOT_SLICE = np.s_[:, -SLOT_AND_GROUP:-group.NO]
|
| 12 |
+
GROUP_SLICE = np.s_[:, -group.NO:]
|
| 13 |
+
|
| 14 |
+
GROUP_SLICE_END = np.s_[-group.NO:]
|
| 15 |
+
SLOT_SLICE_END = np.s_[-SLOT_AND_GROUP:-group.NO]
|
| 16 |
+
PROPERTIES_SLICE_END = np.s_[:-SLOT_AND_GROUP]
|
raven_utils/params.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, Tuple
|
| 3 |
+
|
| 4 |
+
from ml_utils import get_str_name
|
| 5 |
+
from grid_transformer.params import ImgSeqTransformerParameters
|
| 6 |
+
from raven_utils import output
|
| 7 |
+
|
| 8 |
+
from experiment_utils.parameters.nn_default import TP, EP
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class SudokuParameters(ImgSeqTransformerParameters):
|
| 13 |
+
mask: str = "input"
|
| 14 |
+
col: int = 3
|
| 15 |
+
row: int = 3
|
| 16 |
+
pooling: int = 81
|
| 17 |
+
output_size: int = 9
|
| 18 |
+
size: int = 384
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class RavenTransParameters(ImgSeqTransformerParameters):
|
| 23 |
+
mask: str = "last"
|
| 24 |
+
last_index: int = 8
|
| 25 |
+
col: int = 1
|
| 26 |
+
row: int = 1
|
| 27 |
+
output_size: int = output.SIZE
|
| 28 |
+
number_loss: bool = 0
|
| 29 |
+
pre: str = "images"
|
| 30 |
+
num_heads: int = 8
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
MP = RavenTransParameters
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class RavenSelectTransParameters(RavenTransParameters):
|
| 38 |
+
select_type: int = 2
|
| 39 |
+
select_out: int = 0
|
| 40 |
+
additional_out: int = 0
|
| 41 |
+
additional_copy: bool = True
|
| 42 |
+
tail_out: Tuple = (1000, 1000)
|
| 43 |
+
pre: str = "index"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
SMP = RavenSelectTransParameters
|
| 47 |
+
|
| 48 |
+
from raven_utils.config.models import AVAILABLE_MODELS
|
| 49 |
+
from experiment_utils.parameters.nn_clean import Parameters as BaseParameters
|
| 50 |
+
from raven_utils.config.constant import RAVEN, LABELS, INDEX, FEATURES, RAV_METRICS, IMP_RAV_METRICS, ACC_NO_GROUP, \
|
| 51 |
+
ACC_SAME
|
| 52 |
+
|
| 53 |
+
MODEL_NO = -1
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
|
| 57 |
+
class RavenParameters(BaseParameters):
|
| 58 |
+
dataset_name: str = RAVEN
|
| 59 |
+
data: Any = (
|
| 60 |
+
f"{dataset_name}/train.npy",
|
| 61 |
+
f"{dataset_name}/val.npy",
|
| 62 |
+
f"{dataset_name}/train_labels.npy",
|
| 63 |
+
f"{dataset_name}/val_labels.npy",
|
| 64 |
+
f"{dataset_name}/train_target.npy",
|
| 65 |
+
f"{dataset_name}/val_target.npy",
|
| 66 |
+
f"arr/train_features_{AVAILABLE_MODELS[MODEL_NO]}.npy",
|
| 67 |
+
f"arr/val_features_{AVAILABLE_MODELS[MODEL_NO]}.npy",
|
| 68 |
+
f"{dataset_name}/val_index.npy"
|
| 69 |
+
# DataParameters2()
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# core_metrics: tuple = tuple(RAV_METRICS)
|
| 73 |
+
filter_metrics: tuple = tuple(IMP_RAV_METRICS)
|
| 74 |
+
# result_metric: str = ACC_NO_GROUP
|
| 75 |
+
result_metric: str = ACC_SAME
|
| 76 |
+
|
| 77 |
+
lw: float = 0.0001 # Autoencoder
|
| 78 |
+
loss_weight: float = 2.0
|
| 79 |
+
plw: int = 5.0
|
| 80 |
+
mp: RavenTransParameters = RavenTransParameters()
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def experiment(self):
|
| 84 |
+
# return "rav/trans"
|
| 85 |
+
return "rav/best_test3"
|
| 86 |
+
# return "rav/trans_weight"
|
| 87 |
+
|
| 88 |
+
# @property
|
| 89 |
+
# def name(self):
|
| 90 |
+
# # return f"i{self.extractor}_{len(self.tail)}{self.tail[0]}_{self.type_}_{self.epsilon}_{self.last}_{self.epsilon_step}"
|
| 91 |
+
# return f"{get_str_name(self.mp.pre)[0]}_{str(self.plw)[0]}_{str(self.mp.number_loss)[0]}_{self.mp.extractor}_{self.mp.noise if self.mp.noise else ''}_{self.mp.augmentation if self.mp.augmentation else ''}_{self.mp.extractor_shape}_{self.mp.no}_{self.mp.num_heads}_{self.mp.size}_{self.mp.pos_emd}_{self.mp.ff_mul}_{self.tp.batch}"
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
|
| 95 |
+
class BaselineRavenParameters(RavenParameters):
|
| 96 |
+
|
| 97 |
+
@property
|
| 98 |
+
def experiment(self):
|
| 99 |
+
# return "rav/best_test3"
|
| 100 |
+
return "rav/baseline"
|
| 101 |
+
# return "rav/trans_weight"
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def name(self):
|
| 105 |
+
# return f"i{self.extractor}_{len(self.tail)}{self.tail[0]}_{self.type_}_{self.epsilon}_{self.last}_{self.epsilon_step}"
|
| 106 |
+
return f"{get_str_name(self.mp.pre)[0]}_{str(self.plw)[0]}_{str(self.mp.number_loss)[0]}_{self.mp.extractor}_{self.mp.noise if self.mp.noise else ''}_{self.mp.augmentation if self.mp.augmentation else ''}_{self.mp.extractor_shape}_{self.mp.no}_{self.mp.num_heads}_{self.mp.size}_{self.mp.pos_emd}_{self.mp.ff_mul}_{self.tp.batch}"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == '__main__':
|
| 110 |
+
params = PreRavenTransParameters()
|
raven_utils/properties.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import raven_utils as rv
|
| 2 |
+
from ml_utils import dict_from_list2, CalcDict
|
| 3 |
+
import raven_utils.entity as entity
|
| 4 |
+
|
| 5 |
+
NAMES = [
|
| 6 |
+
'Color',
|
| 7 |
+
'Size',
|
| 8 |
+
'Type',
|
| 9 |
+
]
|
| 10 |
+
RAW_SIZE = [10, 6, 5]
|
| 11 |
+
SIZE = dict_from_list2(NAMES, RAW_SIZE)
|
| 12 |
+
ANGLE_SIZE = 7
|
| 13 |
+
NO = len(NAMES)
|
| 14 |
+
|
| 15 |
+
INDEX = (CalcDict(SIZE) * entity.SUM).to_dict()
|
| 16 |
+
SUM = sum(list(SIZE.values()))
|
raven_utils/range_mask.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import tensorflow.experimental.numpy as tnp
|
| 3 |
+
from tensorflow.keras import Model
|
| 4 |
+
import raven_utils as rv
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RangeMask(Model):
|
| 8 |
+
def __init__(self):
|
| 9 |
+
super().__init__()
|
| 10 |
+
ranges = tf.tile(tf.range(rv.entity.INDEX[-1])[None], [rv.group.NO, 1])
|
| 11 |
+
start_index = rv.entity.INDEX[:-1][:, None]
|
| 12 |
+
end_index = rv.entity.INDEX[1:][:, None]
|
| 13 |
+
self.mask = tnp.array((start_index <= ranges) & (ranges < end_index))
|
| 14 |
+
|
| 15 |
+
def call(self, inputs):
|
| 16 |
+
return self.mask[inputs]
|
raven_utils/render/__init__.py
ADDED
|
File without changes
|
raven_utils/render/const.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# Maximum number of components in a RPM
|
| 5 |
+
MAX_COMPONENTS = 2
|
| 6 |
+
|
| 7 |
+
# Canvas parameters
|
| 8 |
+
IMAGE_SIZE = 160
|
| 9 |
+
CENTER = (IMAGE_SIZE / 2, IMAGE_SIZE / 2)
|
| 10 |
+
DEFAULT_RADIUS = IMAGE_SIZE / 4
|
| 11 |
+
DEFAULT_WIDTH = 2
|
| 12 |
+
|
| 13 |
+
# Attribute parameters
|
| 14 |
+
# Number
|
| 15 |
+
NUM_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
| 16 |
+
NUM_MIN = 0
|
| 17 |
+
NUM_MAX = len(NUM_VALUES) - 1
|
| 18 |
+
|
| 19 |
+
# Uniformity
|
| 20 |
+
UNI_VALUES = [False, False, False, True]
|
| 21 |
+
UNI_MIN = 0
|
| 22 |
+
UNI_MAX = len(UNI_VALUES) - 1
|
| 23 |
+
|
| 24 |
+
# Type
|
| 25 |
+
TYPE_VALUES = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
|
| 26 |
+
TYPE_MIN = 0
|
| 27 |
+
TYPE_MAX = len(TYPE_VALUES) - 1
|
| 28 |
+
|
| 29 |
+
# Size
|
| 30 |
+
SIZE_VALUES = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
| 31 |
+
SIZE_MIN = 0
|
| 32 |
+
SIZE_MAX = len(SIZE_VALUES) - 1
|
| 33 |
+
|
| 34 |
+
# Color
|
| 35 |
+
COLOR_VALUES = [255, 224, 196, 168, 140, 112, 84, 56, 28, 0]
|
| 36 |
+
COLOR_MIN = 0
|
| 37 |
+
COLOR_MAX = len(COLOR_VALUES) - 1
|
| 38 |
+
|
| 39 |
+
# Angle: self-rotation
|
| 40 |
+
ANGLE_VALUES = [-135, -90, -45, 0, 45, 90, 135, 180]
|
| 41 |
+
ANGLE_MIN = 0
|
| 42 |
+
ANGLE_MAX = len(ANGLE_VALUES) - 1
|
| 43 |
+
|
| 44 |
+
META_TARGET_FORMAT = ["Constant", "Progression", "Arithmetic", "Distribute_Three", "Number", "Position", "Type", "Size", "Color"]
|
| 45 |
+
META_STRUCTURE_FORMAT = ["Singleton", "Left_Right", "Up_Down", "Out_In", "Left", "Right", "Up", "Down", "Out", "In", "Grid", "Center_Single", "Distribute_Four", "Distribute_Nine", "Left_Center_Single", "Right_Center_Single", "Up_Center_Single", "Down_Center_Single", "Out_Center_Single", "In_Center_Single", "In_Distribute_Four"]
|
| 46 |
+
|
| 47 |
+
# Rule, Attr, Param
|
| 48 |
+
# The design encodes rule priority order: Number/Position always comes first
|
| 49 |
+
# Number and Position could not both be sampled
|
| 50 |
+
# Progression on Number: Number on each Panel +1/2 or -1/2
|
| 51 |
+
# Progression on Position: Entities on each Panel roll over the layout
|
| 52 |
+
# Arithmetic on Number: Numeber on the third Panel = Number on first +/- Number on second (1 for + and -1 for -)
|
| 53 |
+
# Arithmetic on Position: 1 for SET_UNION and -1 for SET_DIFF
|
| 54 |
+
# Distribute_Three on Number: Three numbers through each row
|
| 55 |
+
# Distribute_Three on Position: Three positions (same number) through each row
|
| 56 |
+
# Constant on Number/Position: Nothing changes
|
| 57 |
+
# Progression on Type: Type progression defined as the number of edges on each entity (Triangle, Square, Pentagon, Hexagon, Circle)
|
| 58 |
+
# Distribute_Three on Type: Three types through each row
|
| 59 |
+
# Constant on Type: Nothing changes
|
| 60 |
+
# Progression on Size: Size on each entity +1/2 or -1/2
|
| 61 |
+
# Arithmetic on Size: Size on the third Panel = Size on the first +/- Size on the second (1 for + and -1 for -)
|
| 62 |
+
# Distribute_Three on Size: Three sizes through each row
|
| 63 |
+
# Constant on Size: Nothing changes
|
| 64 |
+
# Progression on Color: Color +1/2 or -1/2
|
| 65 |
+
# Arithmetic on Color: Color on the third Panel = Color on the first +/- Color on the second (1 for + and -1 for -)
|
| 66 |
+
# Distribute_Three on Color: Three colors through each row
|
| 67 |
+
# Constant on Color: Nothing changes
|
| 68 |
+
# Note that all rules on Type, Size and Color enforce value consistency in a panel
|
| 69 |
+
RULE_ATTR = [[["Progression", "Number", [-2, -1, 1, 2]],
|
| 70 |
+
["Progression", "Position", [-2, -1, 1, 2]],
|
| 71 |
+
["Arithmetic", "Number", [1, -1]],
|
| 72 |
+
["Arithmetic", "Position", [1, -1]],
|
| 73 |
+
["Distribute_Three", "Number", None],
|
| 74 |
+
["Distribute_Three", "Position", None],
|
| 75 |
+
["Constant", "Number/Position", None]],
|
| 76 |
+
[["Progression", "Type", [-2, -1, 1, 2]],
|
| 77 |
+
["Distribute_Three", "Type", None],
|
| 78 |
+
["Constant", "Type", None]],
|
| 79 |
+
[["Progression", "Size", [-2, -1, 1, 2]],
|
| 80 |
+
["Arithmetic", "Size", [1, -1]],
|
| 81 |
+
["Distribute_Three", "Size", None],
|
| 82 |
+
["Constant", "Size", None]],
|
| 83 |
+
[["Progression", "Color", [-2, -1, 1, 2]],
|
| 84 |
+
["Arithmetic", "Color", [1, -1]],
|
| 85 |
+
["Distribute_Three", "Color", None],
|
| 86 |
+
["Constant", "Color", None]]]
|
raven_utils/render/rendering.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
#
|
| 8 |
+
# from AoT import Root
|
| 9 |
+
import raven_utils.decode
|
| 10 |
+
from raven_utils.render.const import CENTER, DEFAULT_WIDTH, IMAGE_SIZE
|
| 11 |
+
|
| 12 |
+
from data_utils import Bag
|
| 13 |
+
|
| 14 |
+
from raven_utils.render_ import COLOR_VALUES, SIZE_VALUES, TYPE_VALUES, ANGLE_VALUES, RENDER_POSITIONS
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def imshow(array):
|
| 18 |
+
image = Image.fromarray(array)
|
| 19 |
+
image.show()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def imsave(array, filepath):
|
| 23 |
+
image = Image.fromarray(array)
|
| 24 |
+
image.save(filepath)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def generate_matrix(array_list):
|
| 28 |
+
# row-major array_list
|
| 29 |
+
assert len(array_list) <= 9
|
| 30 |
+
img_grid = np.zeros((IMAGE_SIZE * 3, IMAGE_SIZE * 3), np.uint8)
|
| 31 |
+
for idx in range(len(array_list)):
|
| 32 |
+
i, j = divmod(idx, 3)
|
| 33 |
+
img_grid[i * IMAGE_SIZE:(i + 1) * IMAGE_SIZE, j * IMAGE_SIZE:(j + 1) * IMAGE_SIZE] = array_list[idx]
|
| 34 |
+
# draw grid
|
| 35 |
+
for x in [0.33, 0.67]:
|
| 36 |
+
img_grid[int(x * IMAGE_SIZE * 3) - 1:int(x * IMAGE_SIZE * 3) + 1, :] = 0
|
| 37 |
+
for y in [0.33, 0.67]:
|
| 38 |
+
img_grid[:, int(y * IMAGE_SIZE * 3) - 1:int(y * IMAGE_SIZE * 3) + 1] = 0
|
| 39 |
+
return img_grid
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def generate_answers(array_list):
|
| 43 |
+
assert len(array_list) <= 8
|
| 44 |
+
img_grid = np.zeros((IMAGE_SIZE * 2, IMAGE_SIZE * 4), np.uint8)
|
| 45 |
+
for idx in range(len(array_list)):
|
| 46 |
+
i, j = divmod(idx, 4)
|
| 47 |
+
img_grid[i * IMAGE_SIZE:(i + 1) * IMAGE_SIZE, j * IMAGE_SIZE:(j + 1) * IMAGE_SIZE] = array_list[idx]
|
| 48 |
+
# draw grid
|
| 49 |
+
for x in [0.5]:
|
| 50 |
+
img_grid[int(x * IMAGE_SIZE * 2) - 1:int(x * IMAGE_SIZE * 2) + 1, :] = 0
|
| 51 |
+
for y in [0.25, 0.5, 0.75]:
|
| 52 |
+
img_grid[:, int(y * IMAGE_SIZE * 4) - 1:int(y * IMAGE_SIZE * 4) + 1] = 0
|
| 53 |
+
return img_grid
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def generate_matrix_answer(array_list):
|
| 57 |
+
# row-major array_list
|
| 58 |
+
assert len(array_list) <= 18
|
| 59 |
+
img_grid = np.zeros((IMAGE_SIZE * 6, IMAGE_SIZE * 3), np.uint8)
|
| 60 |
+
for idx in range(len(array_list)):
|
| 61 |
+
i, j = divmod(idx, 3)
|
| 62 |
+
img_grid[i * IMAGE_SIZE:(i + 1) * IMAGE_SIZE, j * IMAGE_SIZE:(j + 1) * IMAGE_SIZE] = array_list[idx]
|
| 63 |
+
# draw grid
|
| 64 |
+
for x in [0.33, 0.67, 1.00, 1.33, 1.67]:
|
| 65 |
+
img_grid[int(x * IMAGE_SIZE * 3), :] = 0
|
| 66 |
+
for y in [0.33, 0.67]:
|
| 67 |
+
img_grid[:, int(y * IMAGE_SIZE * 3)] = 0
|
| 68 |
+
return img_grid
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def merge_matrix_answer(matrix, answer):
|
| 72 |
+
matrix_image = generate_matrix(matrix)
|
| 73 |
+
answer_image = generate_answers(answer)
|
| 74 |
+
img_grid = np.ones((IMAGE_SIZE * 5 + 20, IMAGE_SIZE * 4), np.uint8) * 255
|
| 75 |
+
img_grid[:IMAGE_SIZE * 3, int(0.5 * IMAGE_SIZE):int(3.5 * IMAGE_SIZE)] = matrix_image
|
| 76 |
+
img_grid[-(IMAGE_SIZE * 2):, :] = answer_image
|
| 77 |
+
return img_grid
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def render_panels(feature, target=True,angle=None):
|
| 81 |
+
# Decompose the panel into a structure and its entities
|
| 82 |
+
# root
|
| 83 |
+
# rv.decode_output(root)
|
| 84 |
+
# rv.decode_output_reshape(root)
|
| 85 |
+
# decoded =
|
| 86 |
+
# panel = decoded[0]
|
| 87 |
+
panels = []
|
| 88 |
+
for group, exist, color, size, type_ in Bag(raven_utils.decode.decode_target_flat(feature)):
|
| 89 |
+
canvas = np.ones((IMAGE_SIZE, IMAGE_SIZE), np.uint8) * 255
|
| 90 |
+
structure_img = render_structure(group)
|
| 91 |
+
background = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
|
| 92 |
+
# note left components entities are in the lower layer
|
| 93 |
+
for i, entity in enumerate(exist):
|
| 94 |
+
if entity:
|
| 95 |
+
entity_img = render_entity(RENDER_POSITIONS[i], color[i], size[i], type_[i] + 1, angle=angle)
|
| 96 |
+
background = layer_add(background, entity_img)
|
| 97 |
+
background = layer_add(background, structure_img)
|
| 98 |
+
panels.append(canvas - background)
|
| 99 |
+
return np.stack(panels)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def render_structure(structure):
|
| 103 |
+
if structure == 5:
|
| 104 |
+
ret = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
|
| 105 |
+
ret[:, int(0.5 * IMAGE_SIZE)] = 255.0
|
| 106 |
+
elif structure == 6:
|
| 107 |
+
ret = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
|
| 108 |
+
ret[int(0.5 * IMAGE_SIZE), :] = 255.0
|
| 109 |
+
else:
|
| 110 |
+
ret = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
|
| 111 |
+
return ret
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def render_entity(bbox, color, size, type_, angle=None):
|
| 115 |
+
color = COLOR_VALUES[color]
|
| 116 |
+
size = SIZE_VALUES[size]
|
| 117 |
+
type_ = TYPE_VALUES[type_]
|
| 118 |
+
if angle is None:
|
| 119 |
+
angle = np.random.randint(0, 7, 1)[0]
|
| 120 |
+
angle = ANGLE_VALUES[angle]
|
| 121 |
+
img = np.zeros((IMAGE_SIZE, IMAGE_SIZE), np.uint8)
|
| 122 |
+
|
| 123 |
+
# planar position: [x, y, w, h]
|
| 124 |
+
# angular position: [x, y, w, h, x_c, y_c, omega]
|
| 125 |
+
# center: (columns, rows)
|
| 126 |
+
center = (int(bbox[1] * IMAGE_SIZE), int(bbox[0] * IMAGE_SIZE))
|
| 127 |
+
if type_ == "triangle":
|
| 128 |
+
unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
|
| 129 |
+
dl = int(unit * size)
|
| 130 |
+
pts = np.array([[center[0], center[1] - dl],
|
| 131 |
+
[center[0] + int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)],
|
| 132 |
+
[center[0] - int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)]],
|
| 133 |
+
np.int32)
|
| 134 |
+
pts = pts.reshape((-1, 1, 2))
|
| 135 |
+
color = 255 - color
|
| 136 |
+
width = DEFAULT_WIDTH
|
| 137 |
+
draw_triangle(img, pts, color, width)
|
| 138 |
+
elif type_ == "square":
|
| 139 |
+
unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
|
| 140 |
+
dl = int(unit / 2 * np.sqrt(2) * size)
|
| 141 |
+
pt1 = (center[0] - dl, center[1] - dl)
|
| 142 |
+
pt2 = (center[0] + dl, center[1] + dl)
|
| 143 |
+
color = 255 - color
|
| 144 |
+
width = DEFAULT_WIDTH
|
| 145 |
+
draw_square(img, pt1, pt2, color, width)
|
| 146 |
+
elif type_ == "pentagon":
|
| 147 |
+
unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
|
| 148 |
+
dl = int(unit * size)
|
| 149 |
+
pts = np.array([[center[0], center[1] - dl],
|
| 150 |
+
[center[0] - int(dl * np.cos(np.pi / 10)), center[1] - int(dl * np.sin(np.pi / 10))],
|
| 151 |
+
[center[0] - int(dl * np.sin(np.pi / 5)), center[1] + int(dl * np.cos(np.pi / 5))],
|
| 152 |
+
[center[0] + int(dl * np.sin(np.pi / 5)), center[1] + int(dl * np.cos(np.pi / 5))],
|
| 153 |
+
[center[0] + int(dl * np.cos(np.pi / 10)), center[1] - int(dl * np.sin(np.pi / 10))]],
|
| 154 |
+
np.int32)
|
| 155 |
+
pts = pts.reshape((-1, 1, 2))
|
| 156 |
+
color = 255 - color
|
| 157 |
+
width = DEFAULT_WIDTH
|
| 158 |
+
draw_pentagon(img, pts, color, width)
|
| 159 |
+
elif type_ == "hexagon":
|
| 160 |
+
unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
|
| 161 |
+
dl = int(unit * size)
|
| 162 |
+
pts = np.array([[center[0], center[1] - dl],
|
| 163 |
+
[center[0] - int(dl / 2.0 * np.sqrt(3)), center[1] - int(dl / 2.0)],
|
| 164 |
+
[center[0] - int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)],
|
| 165 |
+
[center[0], center[1] + dl],
|
| 166 |
+
[center[0] + int(dl / 2.0 * np.sqrt(3)), center[1] + int(dl / 2.0)],
|
| 167 |
+
[center[0] + int(dl / 2.0 * np.sqrt(3)), center[1] - int(dl / 2.0)]],
|
| 168 |
+
np.int32)
|
| 169 |
+
pts = pts.reshape((-1, 1, 2))
|
| 170 |
+
color = 255 - color
|
| 171 |
+
width = DEFAULT_WIDTH
|
| 172 |
+
draw_hexagon(img, pts, color, width)
|
| 173 |
+
elif type_ == "circle":
|
| 174 |
+
# Minus because of the way we show the image. See: render_panel's return
|
| 175 |
+
color = 255 - color
|
| 176 |
+
unit = min(bbox[2], bbox[3]) * IMAGE_SIZE / 2
|
| 177 |
+
radius = int(unit * size)
|
| 178 |
+
width = DEFAULT_WIDTH
|
| 179 |
+
draw_circle(img, center, radius, color, width)
|
| 180 |
+
elif type_ == "none":
|
| 181 |
+
pass
|
| 182 |
+
# angular
|
| 183 |
+
if len(bbox) > 4:
|
| 184 |
+
# [x, y, w, h, x_c, y_c, omega]
|
| 185 |
+
angle = bbox[6]
|
| 186 |
+
center = (int(bbox[5] * IMAGE_SIZE), int(bbox[4] * IMAGE_SIZE))
|
| 187 |
+
img = rotate(img, angle, center=center)
|
| 188 |
+
# planar
|
| 189 |
+
else:
|
| 190 |
+
img = rotate(img, angle, center=center)
|
| 191 |
+
# img = shift(img, *entity_position)
|
| 192 |
+
|
| 193 |
+
return img
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def shift(img, dx, dy):
|
| 197 |
+
M = np.array([[1, 0, dx], [0, 1, dy]], np.float32)
|
| 198 |
+
img = cv2.warpAffine(img, M, (IMAGE_SIZE, IMAGE_SIZE), flags=cv2.INTER_LINEAR)
|
| 199 |
+
return img
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def rotate(img, angle, center=CENTER):
|
| 203 |
+
M = cv2.getRotationMatrix2D(center, angle, 1)
|
| 204 |
+
img = cv2.warpAffine(img, M, (IMAGE_SIZE, IMAGE_SIZE), flags=cv2.INTER_LINEAR)
|
| 205 |
+
return img
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def scale(img, tx, ty, center=CENTER):
|
| 209 |
+
M = np.array([[tx, 0, center[0] * (1 - tx)], [0, ty, center[1] * (1 - ty)]], np.float32)
|
| 210 |
+
img = cv2.warpAffine(img, M, (IMAGE_SIZE, IMAGE_SIZE), flags=cv2.INTER_LINEAR)
|
| 211 |
+
return img
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def layer_add(lower_layer_np, higher_layer_np):
|
| 215 |
+
# higher_layer_np is superimposed on lower_layer_np
|
| 216 |
+
# new_np = lower_layer_np.copy()
|
| 217 |
+
# lower_layer_np is modified
|
| 218 |
+
lower_layer_np[higher_layer_np > 0] = 0
|
| 219 |
+
return lower_layer_np + higher_layer_np
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Draw primitives
|
| 223 |
+
def draw_triangle(img, pts, color, width):
|
| 224 |
+
# if filled
|
| 225 |
+
if color != 0:
|
| 226 |
+
# fill the interior
|
| 227 |
+
cv2.fillConvexPoly(img, pts, color)
|
| 228 |
+
# draw the edge
|
| 229 |
+
cv2.polylines(img, [pts], True, 255, width)
|
| 230 |
+
# if not filled
|
| 231 |
+
else:
|
| 232 |
+
cv2.polylines(img, [pts], True, 255, width)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def draw_square(img, pt1, pt2, color, width):
|
| 236 |
+
# if filled
|
| 237 |
+
if color != 0:
|
| 238 |
+
# fill the interior
|
| 239 |
+
cv2.rectangle(img,
|
| 240 |
+
pt1,
|
| 241 |
+
pt2,
|
| 242 |
+
color,
|
| 243 |
+
-1)
|
| 244 |
+
# draw the edge
|
| 245 |
+
cv2.rectangle(img,
|
| 246 |
+
pt1,
|
| 247 |
+
pt2,
|
| 248 |
+
255,
|
| 249 |
+
width)
|
| 250 |
+
# if not filled
|
| 251 |
+
else:
|
| 252 |
+
cv2.rectangle(img,
|
| 253 |
+
pt1,
|
| 254 |
+
pt2,
|
| 255 |
+
255,
|
| 256 |
+
width)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def draw_pentagon(img, pts, color, width):
|
| 260 |
+
# if filled
|
| 261 |
+
if color != 0:
|
| 262 |
+
# fill the interior
|
| 263 |
+
cv2.fillConvexPoly(img, pts, color)
|
| 264 |
+
# draw the edge
|
| 265 |
+
cv2.polylines(img, [pts], True, 255, width)
|
| 266 |
+
# if not filled
|
| 267 |
+
else:
|
| 268 |
+
cv2.polylines(img, [pts], True, 255, width)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def draw_hexagon(img, pts, color, width):
|
| 272 |
+
# if filled
|
| 273 |
+
if color != 0:
|
| 274 |
+
# fill the interior
|
| 275 |
+
cv2.fillConvexPoly(img, pts, color)
|
| 276 |
+
# draw the edge
|
| 277 |
+
cv2.polylines(img, [pts], True, 255, width)
|
| 278 |
+
# if not filled
|
| 279 |
+
else:
|
| 280 |
+
cv2.polylines(img, [pts], True, 255, width)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def draw_circle(img, center, radius, color, width):
|
| 284 |
+
# if filled
|
| 285 |
+
if color != 0:
|
| 286 |
+
# fill the interior
|
| 287 |
+
cv2.circle(img,
|
| 288 |
+
center,
|
| 289 |
+
radius,
|
| 290 |
+
color,
|
| 291 |
+
-1)
|
| 292 |
+
# draw the edge
|
| 293 |
+
cv2.circle(img,
|
| 294 |
+
center,
|
| 295 |
+
radius,
|
| 296 |
+
255,
|
| 297 |
+
width)
|
| 298 |
+
# if not filled
|
| 299 |
+
else:
|
| 300 |
+
cv2.circle(img,
|
| 301 |
+
center,
|
| 302 |
+
radius,
|
| 303 |
+
255,
|
| 304 |
+
width)
|
raven_utils/render_.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
COLOR_VALUES = [255, 224, 196, 168, 140, 112, 84, 56, 28, 0]
|
| 2 |
+
SIZE_VALUES = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
| 3 |
+
TYPE_VALUES = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
|
| 4 |
+
ANGLE_VALUES = [-135, -90, -45, 0, 45, 90, 135, 180]
|
| 5 |
+
RENDER_POSITIONS_GROUPED = [
|
| 6 |
+
[(0.5, 0.5, 1, 1)],
|
| 7 |
+
# ...
|
| 8 |
+
[(0.25, 0.25, 0.5, 0.5),
|
| 9 |
+
(0.25, 0.75, 0.5, 0.5),
|
| 10 |
+
(0.75, 0.25, 0.5, 0.5),
|
| 11 |
+
(0.75, 0.75, 0.5, 0.5)],
|
| 12 |
+
# ...
|
| 13 |
+
[(0.16, 0.16, 0.33, 0.33),
|
| 14 |
+
(0.16, 0.5, 0.33, 0.33),
|
| 15 |
+
(0.16, 0.83, 0.33, 0.33),
|
| 16 |
+
(0.5, 0.16, 0.33, 0.33),
|
| 17 |
+
(0.5, 0.5, 0.33, 0.33),
|
| 18 |
+
(0.5, 0.83, 0.33, 0.33),
|
| 19 |
+
(0.83, 0.16, 0.33, 0.33),
|
| 20 |
+
(0.83, 0.5, 0.33, 0.33),
|
| 21 |
+
(0.83, 0.83, 0.33, 0.33)],
|
| 22 |
+
# ...
|
| 23 |
+
[(0.5, 0.5, 1, 1)],
|
| 24 |
+
[(0.5, 0.5, 0.33, 0.33)],
|
| 25 |
+
# ...
|
| 26 |
+
[(0.5, 0.5, 1, 1)],
|
| 27 |
+
[(0.42, 0.42, 0.15, 0.15),
|
| 28 |
+
(0.42, 0.58, 0.15, 0.15),
|
| 29 |
+
(0.58, 0.42, 0.15, 0.15),
|
| 30 |
+
(0.58, 0.58, 0.15, 0.15)],
|
| 31 |
+
# ....
|
| 32 |
+
[(0.5, 0.25, 0.5, 0.5)],
|
| 33 |
+
[(0.5, 0.75, 0.5, 0.5)],
|
| 34 |
+
# ...
|
| 35 |
+
[(0.25, 0.5, 0.5, 0.5)],
|
| 36 |
+
[(0.75, 0.5, 0.5, 0.5)],
|
| 37 |
+
# ...
|
| 38 |
+
]
|
| 39 |
+
RENDER_POSITIONS = [pos_ for group_pos_ in RENDER_POSITIONS_GROUPED for pos_ in group_pos_]
|
| 40 |
+
MAPPING = {
|
| 41 |
+
"distribute_nine":
|
| 42 |
+
{0.16: 0,
|
| 43 |
+
0.5: 1,
|
| 44 |
+
0.83: 2},
|
| 45 |
+
"distribute_four":
|
| 46 |
+
{0.25: 0,
|
| 47 |
+
0.75: 1},
|
| 48 |
+
'in_distribute_four_out_center_single':
|
| 49 |
+
{0.42: 0,
|
| 50 |
+
0.58: 1}
|
| 51 |
+
}
|
| 52 |
+
MUL = {
|
| 53 |
+
"distribute_nine": 3,
|
| 54 |
+
"distribute_four": 2,
|
| 55 |
+
'in_distribute_four_out_center_single': 2
|
| 56 |
+
}
|
| 57 |
+
TYPES = ["triangle", "square", "pentagon", "hexagon", "circle"]
|
| 58 |
+
TYPES_NONE = ["none", "triangle", "square", "pentagon", "hexagon", "circle"]
|
| 59 |
+
SIZES = ["vs", "s", "m", "h", "vh", "e"]
|
| 60 |
+
SIZES_NAME = ["Very Small", "Small", "Medium", "High", "Very High", "Enormous"]
|
| 61 |
+
COLORS = ["vs", "s", "m", "h", "vh", "e"]
|
| 62 |
+
|
| 63 |
+
SAMPLE_TARGET = [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 64 |
+
0, 0, 0, 0, 9, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 65 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 66 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 67 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0,
|
| 68 |
+
0, 1, 3],
|
| 69 |
+
[1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 70 |
+
0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 71 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 72 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 73 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 2, 3, 0, 0, 0, 0,
|
| 74 |
+
0, 3, 3],
|
| 75 |
+
[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 76 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 77 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0,
|
| 78 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 79 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 2, 0, 0, 0, 0,
|
| 80 |
+
0, 0, 3],
|
| 81 |
+
[3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
|
| 82 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 83 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 84 |
+
0, 0, 0, 5, 2, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 85 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
|
| 86 |
+
3, 2, 1],
|
| 87 |
+
[4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1,
|
| 88 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 89 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 90 |
+
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3,
|
| 91 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 2, 1,
|
| 92 |
+
3, 0, 1],
|
| 93 |
+
[5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 94 |
+
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 95 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 96 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 97 |
+
0, 2, 0, 0, 7, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 3, 0, 0, 3, 3,
|
| 98 |
+
3, 1, 0],
|
| 99 |
+
[6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 100 |
+
0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 101 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 102 |
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
| 103 |
+
0, 0, 0, 0, 0, 0, 0, 6, 5, 0, 8, 5, 1, 0, 0, 0, 3, 2, 0, 0, 1, 0,
|
| 104 |
+
3, 3, 3]]
|
raven_utils/rules.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ml_utils import dict_from_list
|
| 2 |
+
COMBINE = "Number/Position"
|
| 3 |
+
|
| 4 |
+
ATTRIBUTES = [
|
| 5 |
+
"Number",
|
| 6 |
+
"Position",
|
| 7 |
+
"Color",
|
| 8 |
+
"Size",
|
| 9 |
+
"Type"
|
| 10 |
+
]
|
| 11 |
+
ATTRIBUTES_LEN = len(ATTRIBUTES)
|
| 12 |
+
ATTRIBUTES_INDEX = dict_from_list(ATTRIBUTES)
|
| 13 |
+
|
| 14 |
+
TYPES = [
|
| 15 |
+
"Constant",
|
| 16 |
+
"Arithmetic",
|
| 17 |
+
"Progression",
|
| 18 |
+
"Distribute_Three"
|
| 19 |
+
]
|
| 20 |
+
TYPES_INDEX = dict_from_list(TYPES)
|
| 21 |
+
TYPES_LEN = len(TYPES)
|
raven_utils/target.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import raven_utils.group as group
|
| 3 |
+
import raven_utils.entity as entity
|
| 4 |
+
import raven_utils.rules as rules
|
| 5 |
+
import raven_utils.properties as properties
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
ENTITY_INDEX = entity.INDEX + 1
|
| 9 |
+
ENTITY_DICT = dict(zip(group.NAMES, ENTITY_INDEX[:-1]))
|
| 10 |
+
NAMES_ORDER = dict(zip(group.NAMES, np.arange(len(group.NAMES))))
|
| 11 |
+
PROPERTIES_INDEXES = np.cumsum(np.array(list(entity.NO.values())) * properties.NO)
|
| 12 |
+
INDEX = np.concatenate([[0], PROPERTIES_INDEXES]) + entity.SUM + 1 # +2 type and uniformity
|
| 13 |
+
|
| 14 |
+
SECOND_LAYOUT = [i - 1 for i in [
|
| 15 |
+
ENTITY_DICT["in_center_single_out_center_single"] + 1,
|
| 16 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 1,
|
| 17 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 2,
|
| 18 |
+
ENTITY_DICT["in_distribute_four_out_center_single"] + 3,
|
| 19 |
+
ENTITY_DICT["left_center_single_right_center_single"] + 1,
|
| 20 |
+
ENTITY_DICT["up_center_single_down_center_single"] + 1
|
| 21 |
+
]]
|
| 22 |
+
|
| 23 |
+
FIRST_LAYOUT = list(set(range(entity.SUM)) - set(SECOND_LAYOUT))
|
| 24 |
+
LAYOUT_NO = 2
|
| 25 |
+
|
| 26 |
+
START_INDEX = dict(zip(group.NAMES, INDEX[:-1]))
|
| 27 |
+
END_INDEX = INDEX[-1]
|
| 28 |
+
|
| 29 |
+
RULES_ATTRIBUTES_ALL_LEN = rules.ATTRIBUTES_LEN * LAYOUT_NO
|
| 30 |
+
UNIFORMITY_NO = 2
|
| 31 |
+
UNIFORMITY_INDEX = END_INDEX + RULES_ATTRIBUTES_ALL_LEN
|
| 32 |
+
|
| 33 |
+
SIZE = UNIFORMITY_INDEX + UNIFORMITY_NO
|
| 34 |
+
|
| 35 |
+
def take(target):
|
| 36 |
+
return target[1], target[2]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def create(images, index, pattern_index=(2, 5), full_index=False, arrange=np.arange, shape=lambda x: x.shape):
|
| 40 |
+
return [images[:, pattern_index[0]], images[:, pattern_index[1]],
|
| 41 |
+
images[arrange(shape(index)[0]), (0 if full_index else 8) + index[:, 0]]]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def take_simple(target):
|
| 46 |
+
return target[1], target[0]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def create_simple(images, target, index=slice(None), pattern_index=(2, 5)):
|
| 50 |
+
return [images[:, pattern_index[0]], images[:, pattern_index[1]], target][index]
|
raven_utils/uitls.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from itertools import product
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from funcy import identity
|
| 6 |
+
|
| 7 |
+
from data_utils import gather, DataGenerator, Data
|
| 8 |
+
from data_utils.sampling import DataSampler
|
| 9 |
+
from models_utils import init_image as def_init_image, INPUTS, TARGET
|
| 10 |
+
|
| 11 |
+
import raven_utils.group as group
|
| 12 |
+
|
| 13 |
+
from data_utils import ops as D
|
| 14 |
+
|
| 15 |
+
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_val_index(no=group.NO, base=3,add_end=False):
|
| 19 |
+
indexes = np.arange(no) * 2000 + base
|
| 20 |
+
if add_end:
|
| 21 |
+
indexes = np.concatenate([indexes, no*2000])
|
| 22 |
+
return indexes
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_matrix(inputs, index):
|
| 26 |
+
return np.concatenate([inputs[:, :8], gather(inputs, index[:, 0])[:, None]], axis=1)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_matrix_from_data(x):
|
| 30 |
+
inputs = x["inputs"]
|
| 31 |
+
index = x["index"]
|
| 32 |
+
return get_matrix(inputs, index)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_data_class(data, batch_size=128):
|
| 36 |
+
fn = identity
|
| 37 |
+
shape = data[0].shape
|
| 38 |
+
train_generator = DataGenerator(
|
| 39 |
+
{
|
| 40 |
+
INPUTS: Data(data[0], fn),
|
| 41 |
+
TARGET: Data(data[2], fn),
|
| 42 |
+
},
|
| 43 |
+
sampler=DataSampler(np.array(list(product(np.arange(shape[0]), np.arange(shape[1]))))),
|
| 44 |
+
batch=batch_size
|
| 45 |
+
)
|
| 46 |
+
shape = data[1].shape
|
| 47 |
+
val_generator = DataGenerator(
|
| 48 |
+
{
|
| 49 |
+
INPUTS: Data(data[1], fn),
|
| 50 |
+
TARGET: Data(data[3], fn),
|
| 51 |
+
},
|
| 52 |
+
sampler=DataSampler(np.array(list(product(np.arange(shape[0]), np.arange(shape[1])))), shuffle=False),
|
| 53 |
+
batch=batch_size
|
| 54 |
+
)
|
| 55 |
+
return train_generator, val_generator
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def compare_from_result(result, data):
|
| 59 |
+
data = data.data.data
|
| 60 |
+
answer = D.gather(data['target'].data, data['index'].data[:, 0])
|
| 61 |
+
import raven_utils as rv
|
| 62 |
+
predict = result['predict']
|
| 63 |
+
predict_mask = result['predict_mask']
|
| 64 |
+
return np.all(rv.decode.compare(answer[:len(predict)], predict, predict_mask), axis=-1)
|
saved_model/1/keras_metadata.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3065f1580247f096711cd61201a17a730a1e5a3d719f2c2778030dea78bb17b4
|
| 3 |
+
size 730275
|
saved_model/1/saved_model.pb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:48d10d74324a5e993ceacc0a4bffc1fcb232d7e2f708a2ebbeabd864650baeeb
|
| 3 |
+
size 12159312
|
saved_model/1/variables/variables.data-00000-of-00001
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3e3c44b273228c834166b40b8a062a53dce76cc21d4cce42f65df2edc53533a7
|
| 3 |
+
size 43002413
|
saved_model/1/variables/variables.index
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e69bbffa5d415625538b629762f1aaeeb355a83d676242110af3d633e31017dd
|
| 3 |
+
size 24958
|
utils.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from data_utils.image import draw_images
|
| 3 |
+
from ml_utils import il
|
| 4 |
+
|
| 5 |
+
import raven_utils as rv
|
| 6 |
+
from raven_utils.uitls import get_matrix
|
| 7 |
+
from tensorflow.keras.models import load_model
|
| 8 |
+
from raven_utils.draw import render_from_model
|
| 9 |
+
import models
|
| 10 |
+
import ast
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_example(index=0):
|
| 14 |
+
index = ast.literal_eval(str(index))
|
| 15 |
+
if il(index):
|
| 16 |
+
example = rv.draw.render_panels(np.array(index))
|
| 17 |
+
desc = "Custom matrix"
|
| 18 |
+
else:
|
| 19 |
+
if not index:
|
| 20 |
+
index = 0
|
| 21 |
+
index = int(index)
|
| 22 |
+
|
| 23 |
+
desc = rv.draw.extract_rules(models.properties[index])
|
| 24 |
+
desc = "<br /><br />".join(["<br />".join(d) for d in desc])
|
| 25 |
+
|
| 26 |
+
example = get_matrix(models.data[index:index + 1], models.indexes[index:index + 1, None] + 8)
|
| 27 |
+
result = np.tile(draw_images(example[:9], row=3), reps=(1, 1, 3))
|
| 28 |
+
return result, desc
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_model_(name):
|
| 32 |
+
if name == "Transformer":
|
| 33 |
+
path = "/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09"
|
| 34 |
+
else:
|
| 35 |
+
path = name
|
| 36 |
+
models.model = load_model(path)
|
| 37 |
+
return f"Success loading: {name}"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def run_nn(index=0):
|
| 41 |
+
index = ast.literal_eval(str(index))
|
| 42 |
+
if il(index):
|
| 43 |
+
data = rv.draw.render_panels(np.array(index))
|
| 44 |
+
data = np.concatenate([data, data[:7]])[None]
|
| 45 |
+
else:
|
| 46 |
+
if not index:
|
| 47 |
+
index = models.START_IMAGE
|
| 48 |
+
index = int(index)
|
| 49 |
+
data = models.data[index:index + 1]
|
| 50 |
+
|
| 51 |
+
# model = load_model("/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09")
|
| 52 |
+
data = {
|
| 53 |
+
'inputs': data,
|
| 54 |
+
'index': np.zeros(shape=(1, 1), dtype="uint8"),
|
| 55 |
+
'labels': np.zeros(shape=(1, 16, 113), dtype="int8"),
|
| 56 |
+
'target': np.zeros(shape=(1, 16, 113), dtype="int8"),
|
| 57 |
+
# 'features': np.zeros(shape=(1, 16, 64), dtype="float32")
|
| 58 |
+
}
|
| 59 |
+
res = np.tile(render_from_model(data, models.model)[0, ..., None], reps=(1, 1, 3))
|
| 60 |
+
|
| 61 |
+
# res = model({'inputs': data[0:1]})
|
| 62 |
+
|
| 63 |
+
return res
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def next_(index=0):
|
| 67 |
+
index = ast.literal_eval(str(index))
|
| 68 |
+
if not isinstance(index, int):
|
| 69 |
+
index = models.START_IMAGE
|
| 70 |
+
index = int(index) + 1
|
| 71 |
+
return (index,) + load_example(index)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def prev_(index=0):
|
| 75 |
+
index = ast.literal_eval(str(index))
|
| 76 |
+
if not isinstance(index, int):
|
| 77 |
+
index = models.START_IMAGE
|
| 78 |
+
index = int(index) - 1
|
| 79 |
+
return (index,) + load_example(index)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == '__main__':
|
| 83 |
+
image, _ = load_example(5)
|
| 84 |
+
run_nn(image)
|