github-actions[bot] committed on
Commit
7e62de3
·
0 Parent(s):

Sync from GitHub: f981b5b319b104ae2ede1bb5a4b163e5d21c968a

Browse files
Files changed (3) hide show
  1. README.md +33 -0
  2. app.py +222 -0
  3. requirements.txt +6 -0
README.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: WorldFlux Demo
3
+ emoji: 🌐
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: "5.9.1"
8
+ python_version: "3.12"
9
+ app_file: app.py
10
+ pinned: false
11
+ ---
12
+
13
+ # WorldFlux Demo
14
+
15
+ Experience world models in action.
16
+
17
+ This Space demonstrates imagination rollouts from WorldFlux world models.
18
+
19
+ ## Features
20
+
21
+ - Interactive imagination rollout visualization
22
+ - Multiple model presets
23
+ - Real-time reward and continuation prediction
24
+
25
+ ## Try it out
26
+
27
+ Select a model type, set the imagination horizon, and click "Run Imagination".
28
+
29
+ ## Links
30
+
31
+ - [GitHub](https://github.com/worldflux/WorldFlux)
32
+ - [PyPI](https://pypi.org/project/worldflux/)
33
+ - [Documentation](https://github.com/worldflux/WorldFlux/tree/main/docs)
app.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """WorldFlux imagination demo powered by actual WorldFlux model inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import torch
12
+
13
+ from worldflux import create_world_model
14
+
15
# Preset registry: UI label -> WorldFlux model id plus the observation shape
# and action dimensionality used to instantiate the model.
MODEL_SPECS = {
    "DreamerV3": {"model_id": "dreamerv3:size12m", "obs_shape": (3, 64, 64), "action_dim": 6},
    "TD-MPC2": {"model_id": "tdmpc2:5m", "obs_shape": (39,), "action_dim": 6},
}
27
+
28
+
29
+ def _to_numpy_frame(tensor: torch.Tensor) -> np.ndarray | None:
30
+ value = tensor.detach().cpu()
31
+ if value.ndim == 4:
32
+ value = value[0]
33
+ if value.ndim == 3:
34
+ frame = value.numpy()
35
+ if frame.shape[0] in {1, 3}:
36
+ frame = np.transpose(frame, (1, 2, 0))
37
+ frame = np.nan_to_num(frame.astype(np.float32))
38
+ if frame.shape[-1] == 1:
39
+ frame = np.repeat(frame, 3, axis=-1)
40
+ if frame.max() > frame.min():
41
+ frame = (frame - frame.min()) / (frame.max() - frame.min())
42
+ return frame
43
+ return None
44
+
45
+
46
+ def _extract_predictions(decoded: Any) -> dict[str, torch.Tensor]:
47
+ if isinstance(decoded, dict):
48
+ return decoded
49
+ predictions = getattr(decoded, "predictions", {})
50
+ if isinstance(predictions, dict):
51
+ return predictions
52
+ return {}
53
+
54
+
55
@lru_cache(maxsize=8)
def _load_model(model_type: str, checkpoint_path: str) -> Any:
    """Build (and memoize) a WorldFlux model for the given UI preset.

    When ``checkpoint_path`` names an existing location, the freshly built
    model is replaced by one restored from that checkpoint. Results are
    cached by (model_type, checkpoint_path) so repeated clicks reuse the
    already-loaded model.
    """
    spec = MODEL_SPECS[model_type]
    model = create_world_model(
        spec["model_id"],
        obs_shape=spec["obs_shape"],
        action_dim=spec["action_dim"],
    )

    # Optional checkpoint restore: only when a path was supplied AND exists.
    if checkpoint_path:
        candidate = Path(checkpoint_path).expanduser()
        if candidate.exists():
            model = model.__class__.from_pretrained(str(candidate))

    model.eval()
    return model
70
+
71
+
72
def _build_initial_obs(model_type: str) -> torch.Tensor:
    """Create a deterministic synthetic observation with batch size 1.

    Image-shaped specs (C, H, W) get per-channel phase-shifted gradients;
    1-D specs get a linear ramp over [-1, 1].

    Raises:
        ValueError: if the preset's observation shape is neither 3-D nor 1-D.
    """
    obs_shape = MODEL_SPECS[model_type]["obs_shape"]

    if len(obs_shape) == 3:  # image observation
        channels, height, width = obs_shape
        base = np.linspace(0.0, 1.0, height * width, dtype=np.float32).reshape(height, width)
        # Shift each channel's gradient so channels are distinguishable.
        planes = [(base + c / max(1, channels)) % 1.0 for c in range(channels)]
        return torch.from_numpy(np.stack(planes, axis=0)).unsqueeze(0)

    if len(obs_shape) == 1:  # flat vector observation
        ramp = np.linspace(-1.0, 1.0, obs_shape[0], dtype=np.float32)
        return torch.from_numpy(ramp).unsqueeze(0)

    raise ValueError(f"Unsupported observation shape: {obs_shape}")
84
+
85
+
86
def _build_action_sequence(model_type: str, horizon: int, device: torch.device) -> torch.Tensor:
    """Build a (horizon, 1, action_dim) one-hot action sequence.

    Step t activates action index ``t % action_dim``, cycling through every
    discrete action deterministically.
    """
    action_dim = MODEL_SPECS[model_type]["action_dim"]
    one_hot = torch.zeros(horizon, 1, action_dim, device=device)
    steps = torch.arange(horizon, device=device)
    one_hot[steps, 0, steps % action_dim] = 1.0
    return one_hot
92
+
93
+
94
def _plot_rewards(rewards: list[float]):
    """Render the imagined reward sequence as a marker line chart."""
    # Imported lazily so the module can load before matplotlib is needed.
    import matplotlib.pyplot as plt

    figure, axis = plt.subplots(figsize=(8, 4))
    axis.plot(rewards, marker="o")
    axis.set_xlabel("Time Step")
    axis.set_ylabel("Predicted Reward")
    axis.set_title("Imagined Rewards")
    axis.grid(True, alpha=0.3)
    return figure
104
+
105
+
106
def _plot_continues(continues: list[float]):
    """Render per-step continuation probabilities on a fixed [0, 1] axis."""
    # Imported lazily so the module can load before matplotlib is needed.
    import matplotlib.pyplot as plt

    figure, axis = plt.subplots(figsize=(8, 4))
    axis.plot(continues, marker="s", color="green")
    axis.set_xlabel("Time Step")
    axis.set_ylabel("Continue Probability")
    axis.set_title("Episode Continuation")
    axis.grid(True, alpha=0.3)
    axis.set_ylim([0, 1])
    return figure
117
+
118
+
119
def _plot_frames(frames: list[np.ndarray]):
    """Render up to five imagined frames side by side, or a placeholder note.

    Vector-only models decode no image observations; in that case an
    explanatory message is drawn instead of an image strip.
    """
    # Imported lazily so the module can load before matplotlib is needed.
    import matplotlib.pyplot as plt

    figure, axis = plt.subplots(figsize=(8, 4))
    if frames:
        # Concatenate the first few frames horizontally into one preview strip.
        strip = np.concatenate(frames[: min(5, len(frames))], axis=1)
        axis.imshow(strip)
        axis.set_title("Imagined Frame Preview")
        axis.axis("off")
    else:
        axis.text(
            0.5,
            0.5,
            "This model does not decode image observations.\n(Reward/continue are real model outputs)",
            ha="center",
            va="center",
            fontsize=10,
        )
        axis.axis("off")
    return figure
140
+
141
+
142
def run_imagination(model_type: str, horizon: int, checkpoint_path: str):
    """Run encode -> rollout -> decode and return plots plus a status line.

    Args:
        model_type: Key into MODEL_SPECS ("DreamerV3" or "TD-MPC2").
        horizon: Number of imagination steps to roll out.
        checkpoint_path: Optional filesystem path to restore weights from;
            an empty/blank string uses the preset's fresh weights.

    Returns:
        Tuple of (rewards figure, continues figure, frames figure, status str).
    """
    model = _load_model(model_type, checkpoint_path.strip())
    device = next(model.parameters()).device
    initial_obs = _build_initial_obs(model_type).to(device=device)
    action_sequence = _build_action_sequence(model_type, int(horizon), device)

    rewards: list[float] = []
    continues: list[float] = []
    frames: list[np.ndarray] = []

    with torch.no_grad():
        state = model.encode({"obs": initial_obs})
        trajectory = model.rollout(state, action_sequence)

        # Prefer reward/continue tensors attached to the trajectory itself.
        rollout_rewards = getattr(trajectory, "rewards", None)
        if isinstance(rollout_rewards, torch.Tensor):
            rewards = rollout_rewards.detach().cpu().view(-1).tolist()
        rollout_continues = getattr(trajectory, "continues", None)
        if isinstance(rollout_continues, torch.Tensor):
            continues = torch.sigmoid(rollout_continues).detach().cpu().view(-1).tolist()

        # Fall back to per-step decoded heads only when the rollout did not
        # supply them. Decide this ONCE before the loop: the previous code
        # re-checked `if not rewards` each iteration, so after the first
        # append the list became truthy and every later step was silently
        # dropped, leaving a single-point plot.
        need_rewards = not rewards
        need_continues = not continues

        for state_t in trajectory.states[1:]:
            decoded = model.decode(state_t)
            predictions = _extract_predictions(decoded)

            if need_rewards and isinstance(predictions.get("reward"), torch.Tensor):
                rewards.append(float(predictions["reward"].detach().cpu().view(-1)[0]))
            if need_continues and isinstance(predictions.get("continue"), torch.Tensor):
                continues.append(
                    float(torch.sigmoid(predictions["continue"]).detach().cpu().view(-1)[0])
                )

            obs_pred = predictions.get("obs")
            if isinstance(obs_pred, torch.Tensor):
                frame = _to_numpy_frame(obs_pred)
                if frame is not None:
                    frames.append(frame)

    # Last-resort placeholders so every plot still renders something.
    if not rewards:
        rewards = [0.0] * int(horizon)
    if not continues:
        continues = [1.0] * int(horizon)

    rewards_plot = _plot_rewards(rewards)
    continues_plot = _plot_continues(continues)
    frames_plot = _plot_frames(frames)
    status = (
        f"Ran {model_type} inference for {int(horizon)} steps "
        f"(checkpoint={checkpoint_path.strip() or 'model preset'})"
    )
    return rewards_plot, continues_plot, frames_plot, status
193
+
194
+
195
# --- Gradio UI -------------------------------------------------------------
# `demo` is the top-level Blocks app (the name Spaces/Gradio conventionally
# look for); components are created in render order.
with gr.Blocks() as demo:
    gr.Markdown("# WorldFlux Demo")
    gr.Markdown("Actual WorldFlux encode → rollout → decode inference (no random mock outputs)")

    # Inputs: preset selector, optional checkpoint path, rollout length.
    model_choice = gr.Dropdown(
        choices=["DreamerV3", "TD-MPC2"],
        value="DreamerV3",
        label="Model Type",
    )
    checkpoint_input = gr.Textbox(
        label="Checkpoint Path (optional)",
        placeholder="/data/checkpoints/dreamer_final",
    )
    horizon_slider = gr.Slider(5, 50, value=15, step=1, label="Imagination Horizon")
    run_button = gr.Button("Run Imagination")

    # Outputs: three plots in a row plus a status line below.
    with gr.Row():
        reward_panel = gr.Plot(label="Rewards")
        continue_panel = gr.Plot(label="Continues")
        frame_panel = gr.Plot(label="Imagined Frames")
    status_box = gr.Textbox(label="Status")

    run_button.click(
        run_imagination,
        inputs=[model_choice, horizon_slider, checkpoint_input],
        outputs=[reward_panel, continue_panel, frame_panel, status_box],
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=5.9.0
2
+ huggingface_hub>=0.25.0
3
+ matplotlib>=3.5.0
4
+ numpy>=1.24.0
5
+ torch>=2.0.0
6
+ git+https://github.com/worldflux/WorldFlux.git