File size: 4,575 Bytes
fdde807
7274c87
457061b
b7ed086
457061b
 
f31db2c
457061b
f31db2c
 
 
 
fdde807
7274c87
f31db2c
7274c87
457061b
 
f31db2c
 
457061b
b7ed086
f31db2c
5bfb237
 
f31db2c
 
 
 
 
5bfb237
f31db2c
457061b
 
f31db2c
 
 
b7ed086
7274c87
f31db2c
 
 
 
 
 
 
7274c87
 
f31db2c
7274c87
f31db2c
 
b7ed086
7274c87
f31db2c
457061b
 
5bfb237
b7ed086
 
f31db2c
5bfb237
b7ed086
 
 
 
5bfb237
f31db2c
 
5bfb237
b7ed086
5bfb237
f31db2c
457061b
f31db2c
 
5bfb237
b7ed086
f31db2c
 
 
5bfb237
f31db2c
5bfb237
b7ed086
 
 
f31db2c
 
b7ed086
 
7274c87
 
 
f31db2c
7274c87
f31db2c
7274c87
f31db2c
 
7274c87
f31db2c
 
 
 
 
 
 
5bfb237
 
5ae73bc
 
f31db2c
5ae73bc
 
 
 
 
ecabf8f
5ae73bc
 
f31db2c
 
5ae73bc
f31db2c
 
5b6ce5a
 
f31db2c
 
 
 
 
 
 
 
 
 
 
 
5ae73bc
5bfb237
 
5b6ce5a
 
5bfb237
f31db2c
5bfb237
f31db2c
 
5bfb237
f31db2c
5bfb237
7274c87
b7ed086
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import numpy as np
import json
import plotly.graph_objects as go

# -----------------------------
# DATA PATHS (the two "000" example outcome archives)
# -----------------------------
# Relative paths: open() resolves them against the current working directory.
DATA_PATH_1 = (
    "000_hinge_0/"
    "000_hinge_0_Gen175_Pop10000_baseline_MAP-Elite-Explore_APGen_initState_max_v0/"
    "Archive_space_Object_000_JointNature_hinge_Item0_Generation_175_PopSize_10000_"
    "baseline_MAP-Elite-Explore_APGen_initState_max_v0.jsonl"
)

DATA_PATH_2 = (
    "000_slider_0/"
    "000_slider_0_Gen175_Pop10000_baseline_MAP-Elite-Explore_APGen_initState_max_v0/"
    "Archive_space_Object_000_JointNature_slider_Item0_Generation_175_PopSize_10000_"
    "baseline_MAP-Elite-Explore_APGen_initState_max_v0.jsonl"
)


# -----------------------------
# LOAD
# -----------------------------
def load_all_trajectories(path, stride=50):
    """Load a subsample of trajectories from a JSON Lines archive.

    Only every ``stride``-th line of the file (lines 0, stride, 2*stride, ...)
    is parsed, which keeps the number of plotted trajectories manageable.
    For each sampled record that has a ``"traj_list"`` entry, every keyframe
    is truncated to its first three components (plotted as X, Y, Z).

    Args:
        path: Path to a ``.jsonl`` file with one JSON object per line.
        stride: Keep one line out of every ``stride`` lines. Defaults to 50,
            the sampling rate this loader has always used. Must be >= 1.

    Returns:
        A list of trajectories; each trajectory is a list of keyframes and
        each keyframe a list of at most three values.

    Raises:
        ValueError: If ``stride`` is smaller than 1, or (as
            ``json.JSONDecodeError``) if a sampled non-blank line is not
            valid JSON.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride!r}")

    traj_list = []
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        # Indices count every physical line (blank ones included) so the
        # sampling grid is the same as before.
        for index, line in enumerate(f):
            if index % stride:
                continue
            # A blank sampled line (e.g. a trailing empty line) holds no
            # record and would make json.loads raise; skip it.
            if not line.strip():
                continue
            traj = json.loads(line).get("traj_list")
            if traj is not None:
                traj_list.append([x[:3] for x in traj])

    return traj_list


def build_dict(path):
    """Label every trajectory loaded from ``path`` for the dropdown.

    Keys are "Trajectory 0", "Trajectory 1", ... assigned in the order in
    which ``load_all_trajectories`` returns the trajectories; values are the
    trajectories themselves.
    """
    labelled = {}
    for idx, trajectory in enumerate(load_all_trajectories(path)):
        labelled[f"Trajectory {idx}"] = trajectory
    return labelled


# Load both archives once, at import time; each maps "Trajectory <i>" to a
# trajectory (see build_dict).
TRAJ_1, TRAJ_2 = build_dict(DATA_PATH_1), build_dict(DATA_PATH_2)

# Dropdown options, in the same insertion order as the dicts above.
choices_1, choices_2 = list(TRAJ_1), list(TRAJ_2)


# -----------------------------
# PLOT FUNCTION (generic)
# -----------------------------
def plot_traj(trajs_dict, choice, progress, title):

    fig = go.Figure()

    for name, traj in trajs_dict.items():
        traj = np.array(traj)

        if len(traj) == 0:
            continue

        # background
        if name != choice:
            fig.add_trace(go.Scatter3d(
                x=traj[:, 0],
                y=traj[:, 1],
                z=traj[:, 2],
                mode="lines",
                line=dict(width=4, color="green"),
                opacity=0.05,
                showlegend=False
            ))

        # selected (progressive)
        else:
            k = max(2, int(len(traj) * progress))
            traj_part = traj[:k]

            fig.add_trace(go.Scatter3d(
                x=traj_part[:, 0],
                y=traj_part[:, 1],
                z=traj_part[:, 2],
                mode="lines",
                line=dict(width=6, color="red"),
                name=name
            ))

    fig.update_layout(
        title=f"{title} | {choice} | progress={progress:.2f}",
        scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"),
        margin=dict(l=0, r=0, b=0, t=40)
    )

    return fig


# -----------------------------
# WRAPPERS
# -----------------------------
def update_1(choice, progress):
    """Redraw the 000_hinge_0 archive figure for the current dropdown/slider."""
    return plot_traj(
        trajs_dict=TRAJ_1,
        choice=choice,
        progress=progress,
        title="Action primitive 000_hinge_0",
    )


def update_2(choice, progress):
    """Redraw the 000_slider_0 archive figure for the current dropdown/slider."""
    return plot_traj(
        trajs_dict=TRAJ_2,
        choice=choice,
        progress=progress,
        title="Action primitive 000_slider_0",
    )


# -----------------------------
# UI
# -----------------------------
with gr.Blocks() as demo:

    gr.Markdown("""
# Trajectory Explorer

This page is an interactive visualization of two example outcome archives.
The first outcome archive refers to the closing of a door on our experimental object (000_hinge_0), the second to the closing of a drawer (000_slider_0).
Choose the trajectory you want to explore among those displayed and use the slider to see the keyframe progression.

To download the full dataset, visit:  
https://huggingface.co/datasets/mathildekappel/trajectory_primitive (Copy and paste the link into another browser tab)

""")

    # -------------------------
    # DOCUMENT 1
    # -------------------------
    gr.Markdown("## 000_hinge_0 action primitive outcome archive ")

    with gr.Row():
        # Preselect the first trajectory (if any) so the initial figure drawn
        # by demo.load highlights one instead of being called with None.
        dd1 = gr.Dropdown(
            choices_1,
            value=choices_1[0] if choices_1 else None,
            label="Trajectory selected",
        )
        sl1 = gr.Slider(0, 1, value=1, step=0.01, label="Keyframe Progress")

    plot1 = gr.Plot()

    # Redraw whenever either control changes, and once when the page loads.
    dd1.change(update_1, [dd1, sl1], plot1)
    sl1.change(update_1, [dd1, sl1], plot1)

    demo.load(update_1, [dd1, sl1], plot1)


    # -------------------------
    # DOCUMENT 2
    # -------------------------
    gr.Markdown("## 000_slider_0 action primitive outcome archive ")

    with gr.Row():
        # Same preselection as for document 1.
        dd2 = gr.Dropdown(
            choices_2,
            value=choices_2[0] if choices_2 else None,
            label="Trajectory selected",
        )
        sl2 = gr.Slider(0, 1, value=1, step=0.01, label="Keyframe Progress")

    plot2 = gr.Plot()

    # Redraw whenever either control changes, and once when the page loads.
    dd2.change(update_2, [dd2, sl2], plot2)
    sl2.change(update_2, [dd2, sl2], plot2)

    demo.load(update_2, [dd2, sl2], plot2)


if __name__ == "__main__":
    # Launch the Gradio app only when this file is run as a script; importing
    # the module still builds `demo` (above) without launching it.
    demo.launch()