Spaces:
Build error
Build error
| import numpy as np | |
| from funcy import identity | |
| import raven_utils as rv | |
| from raven_utils.constant import PROPERTY | |
| from raven_utils.decode import target_mask | |
| from raven_utils.image import draw_images | |
| from raven_utils.render.rendering import render_panels | |
| from raven_utils.tools import filter_keys, is_model, il | |
| from raven_utils.uitls import get_matrix | |
| from tensorflow.keras.models import load_model | |
| from raven_utils.draw import render_from_model | |
| import models | |
| import ast | |
def render_from_model(data, predict, pre_fn=identity):
    """Render the answer panels produced by a model (or its precomputed output).

    Note: this local definition shadows the ``render_from_model`` imported
    from ``raven_utils.draw`` at the top of the file.

    Args:
        data: Model input mapping. It is passed through
            ``filter_keys(data, PROPERTY, reverse=True)`` before being fed to
            the model (presumably dropping the ``PROPERTY`` keys — confirm
            against ``raven_utils.tools.filter_keys``).
        predict: Either a model, detected via ``is_model`` or by the type
            string of a loaded TF SavedModel object, which is called on the
            filtered inputs; or an already-computed output mapping.
        pre_fn: Applied to the rendered panels after a leading axis is added.

    Returns:
        ``pre_fn(rendered[None])[0]``, where ``rendered`` is
        ``render_panels(..., target=False)`` of the masked predictions cast
        to int8.
    """
    model_inputs = filter_keys(data, PROPERTY, reverse=True)
    saved_model_type = "<class 'tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject'>"
    should_call = is_model(predict) or str(type(predict)) == saved_model_type
    outputs = predict(model_inputs) if should_call else predict

    # Zero out everything outside the predicted target region before rendering.
    answer_mask = target_mask(outputs['predict_mask'].numpy())
    answer_values = outputs["predict"].numpy()
    panels = np.array(answer_mask * answer_values, dtype=np.int8)

    rendered = render_panels(panels, target=False)
    return pre_fn(rendered[None])[0]
def load_example(index=0):
    """Return the rendered puzzle image and a text description for ``index``.

    Args:
        index: A dataset position (int, or its string form), or the string
            form of a nested list literal describing a custom matrix. Any
            falsy value, including an empty or whitespace-only string such as
            a cleared textbox, selects position 0.

    Returns:
        tuple: ``(image, description)``. ``image`` is the first nine panels
        drawn with ``row=3`` and tiled with ``reps=(1, 1, 3)``; presumably
        ``draw_images`` returns ``(H, W, 1)`` so this yields 3 channels —
        TODO confirm. ``description`` is ``"Custom matrix"`` for list input,
        else the dataset's ``'Description'`` property.
    """
    text = str(index).strip()
    # ast.literal_eval("") raises SyntaxError, so without this guard an empty
    # textbox would never reach the "falsy -> 0" fallback below.
    index = ast.literal_eval(text) if text else 0
    if il(index):
        example = rv.draw.render_panels(np.array(index))
        desc = "Custom matrix"
    else:
        if not index:
            index = 0
        index = int(index)
        desc = models.properties[index]['Description']
        # index:index + 1 is an empty slice when index == -1 (stop == 0), which
        # prev_ can produce from 0; use an open-ended stop for that case so the
        # same row as models.properties[-1] is selected.
        stop = index + 1 if index != -1 else None
        batch = models.data[index:stop]
        example = get_matrix(
            np.array(batch['inputs'], dtype="uint8"),
            np.array(batch['index'], dtype="uint8")[..., None],
        )
    result = np.tile(draw_images(example[:9], row=3), reps=(1, 1, 3))
    return result, desc
def load_model_(name, model_paths=None):
    """Load a Keras model into ``models.model`` and report success.

    Args:
        name: Either a key of ``model_paths`` (by default only
            ``"Transformer"``) or a path passed directly to ``load_model``.
        model_paths: Optional mapping from display names to model paths.
            When omitted, ``"Transformer"`` resolves to the
            ``RAVEN_TRANSFORMER_PATH`` environment variable if set, otherwise
            to the original hard-coded location. That location is a
            developer's home directory and will not exist on other machines.

    Returns:
        str: ``"Success loading: <name>"``. Errors raised by ``load_model``
        propagate unchanged.
    """
    import os

    if model_paths is None:
        model_paths = {
            "Transformer": os.environ.get(
                "RAVEN_TRANSFORMER_PATH",
                "/home/jkwiatkowski/all/best/rav/full_trans/6e8e6bad403e4171ad10daa1a518ba09",
            ),
        }
    # Unknown names are treated as literal paths, as before.
    path = model_paths.get(name, name)
    models.model = load_model(path)
    return f"Success loading: {name}"
def run_nn(index=None):
    """Run ``models.model`` on one puzzle and return its rendered prediction.

    Args:
        index: A dataset position (int, or its string form), or the string
            form of a nested list literal describing a custom matrix.
            ``None``, an empty or whitespace-only string, or a literal
            ``None`` selects ``models.START_IMAGE``. Position 0 is a real
            example, consistent with ``load_example``; previously any falsy
            value, including 0, was redirected to ``START_IMAGE``.

    Returns:
        ``render_from_model(...)[0, ..., None]`` tiled with ``reps=(1, 1, 3)``.
    """
    text = "" if index is None else str(index).strip()
    # ast.literal_eval("") raises SyntaxError; treat empty input as "missing".
    parsed = ast.literal_eval(text) if text else None

    if il(parsed):
        panels = rv.draw.render_panels(np.array(parsed))
        # Append the first seven panels again (as the original code did),
        # presumably to pad to the model's expected panel count — TODO confirm.
        inputs = np.concatenate([panels, panels[:7]])[None]
    else:
        position = int(models.START_IMAGE if parsed is None else parsed)
        # index:index + 1 is empty for index == -1; use an open-ended stop.
        stop = position + 1 if position != -1 else None
        inputs = models.data[position:stop]['inputs']

    data = {
        'inputs': np.asarray(inputs, dtype="uint8"),
        'index': np.zeros(shape=(1, 1), dtype="uint8"),
        'target': np.zeros(shape=(1, 16, 113), dtype="int8"),
    }
    rendered = render_from_model(data, models.model)
    return np.tile(rendered[0, ..., None], reps=(1, 1, 3))
def _step_index(index, delta):
    """Shift the textbox index by ``delta`` and load the resulting example.

    Anything that does not parse as an int literal resets the starting point
    to ``models.START_IMAGE`` before stepping. This includes an empty string,
    which ``ast.literal_eval`` rejects with SyntaxError, and free text, which
    it rejects with ValueError.
    Returns ``(new_index, image, description)``.
    """
    text = str(index).strip()
    try:
        current = ast.literal_eval(text) if text else None
    except (ValueError, SyntaxError):
        current = None
    if not isinstance(current, int):
        current = models.START_IMAGE
    new_index = int(current) + delta
    return (new_index,) + load_example(new_index)


def next_(index=0):
    """Advance to the next example; return ``(new_index, image, description)``."""
    return _step_index(index, 1)


def prev_(index=0):
    """Go back to the previous example; return ``(new_index, image, description)``."""
    return _step_index(index, -1)
if __name__ == '__main__':
    # Smoke run: render dataset example 5, then run the model on the same
    # example. run_nn expects an index (or a list-literal string), not the
    # rendered image; the image's numpy repr is not a Python literal, so
    # passing it made ast.literal_eval fail.
    # NOTE(review): assumes models.model is already set (e.g. on import of
    # `models` or via load_model_) — confirm.
    load_example(5)
    run_nn(5)