dorsar commited on
Commit
a571fce
·
verified ·
1 Parent(s): d19cc32

Upload diffusion_model.py

Browse files
Files changed (1) hide show
  1. diffusion_model.py +406 -0
diffusion_model.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset
4
+ import zarr
5
+ import numpy as np
6
+ from diffusers import UNet1DModel, DDPMScheduler
7
+ from diffusers.optimization import get_scheduler
8
+ from diffusers.training_utils import EMAModel
9
+ import torch.nn.functional as F
10
+ from tqdm.auto import tqdm
11
+ from dataclasses import dataclass
12
+ from typing import Tuple, Optional
13
+ import os
14
+ import gdown
15
+ import collections
16
+ from imageio import get_writer
17
+
18
+ # Data Configurations
19
+ @dataclass
20
+ class DataConfig:
21
+ """Configuration for dataset"""
22
+ # Dataset paths and download info
23
+ dataset_path: str = "pusht_cchi_v7_replay.zarr.zip"
24
+ dataset_gdrive_id: str = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"
25
+
26
+ # Sequence parameters
27
+ pred_horizon: int = 16 # Number of steps to predict
28
+ obs_horizon: int = 2 # Number of observations to condition on
29
+ action_horizon: int = 8 # Number of actions to execute
30
+
31
+ # Data dimensions
32
+ image_size: Tuple[int, int] = (96, 96)
33
+ image_channels: int = 3
34
+ action_dim: int = 2
35
+ state_dim: int = 5 # [agent_x, agent_y, block_x, block_y, block_angle]
36
+
37
+ @dataclass
38
+ class ModelConfig:
39
+ """Configuration for neural networks"""
40
+ # Observation encoding
41
+ obs_embed_dim: int = 256
42
+
43
+ # UNet configuration
44
+ sample_size: int = 16 # pred_horizon length
45
+ in_channels: int = 2 # action dimension
46
+ out_channels: int = 2 # action dimension
47
+ layers_per_block: int = 2
48
+ block_out_channels: Tuple[int, ...] = (128,)
49
+ norm_num_groups: int = 8
50
+ down_block_types: Tuple[str, ...] = ("DownBlock1D",) * 1
51
+ up_block_types: Tuple[str, ...] = ("UpBlock1D",) * 1
52
+
53
+ def __post_init__(self):
54
+ # For conditioning through input channels
55
+ self.total_in_channels = self.in_channels + self.obs_embed_dim //8 # actions + conditioning
56
+
57
+
58
+ """
59
+ Helper Functions
60
+ """
61
+ def create_sample_indices(
62
+ episode_ends:np.ndarray, sequence_length:int,
63
+ pad_before: int=0, pad_after: int=0):
64
+ indices = list()
65
+ for i in range(len(episode_ends)):
66
+ start_idx = 0
67
+ if i > 0:
68
+ start_idx = episode_ends[i-1]
69
+ end_idx = episode_ends[i]
70
+ episode_length = end_idx - start_idx
71
+
72
+ min_start = -pad_before
73
+ max_start = episode_length - sequence_length + pad_after
74
+
75
+ # range stops one idx before end
76
+ for idx in range(min_start, max_start+1):
77
+ buffer_start_idx = max(idx, 0) + start_idx
78
+ buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
79
+ start_offset = buffer_start_idx - (idx+start_idx)
80
+ end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
81
+ sample_start_idx = 0 + start_offset
82
+ sample_end_idx = sequence_length - end_offset
83
+ indices.append([
84
+ buffer_start_idx, buffer_end_idx,
85
+ sample_start_idx, sample_end_idx])
86
+ indices = np.array(indices)
87
+ return indices
88
+
89
+ def sample_sequence(train_data, sequence_length,
90
+ buffer_start_idx, buffer_end_idx,
91
+ sample_start_idx, sample_end_idx):
92
+ result = dict()
93
+ for key, input_arr in train_data.items():
94
+ sample = input_arr[buffer_start_idx:buffer_end_idx]
95
+ data = sample
96
+ if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
97
+ data = np.zeros(
98
+ shape=(sequence_length,) + input_arr.shape[1:],
99
+ dtype=input_arr.dtype)
100
+ if sample_start_idx > 0:
101
+ data[:sample_start_idx] = sample[0]
102
+ if sample_end_idx < sequence_length:
103
+ data[sample_end_idx:] = sample[-1]
104
+ data[sample_start_idx:sample_end_idx] = sample
105
+ result[key] = data
106
+ return result
107
+
108
+ def get_data_stats(data):
109
+ data = data.reshape(-1,data.shape[-1])
110
+ stats = {
111
+ 'min': np.min(data, axis=0),
112
+ 'max': np.max(data, axis=0)
113
+ }
114
+ return stats
115
+
116
+ def normalize_data(data, stats):
117
+ # Normalize to [0,1]
118
+ ndata = (data - stats['min']) / (stats['max'] - stats['min'])
119
+ # Normalize to [-1, 1]
120
+ ndata = ndata * 2 - 1
121
+ return ndata
122
+
123
+ def unnormalize_data(ndata, stats):
124
+ ndata = (ndata + 1) / 2
125
+ data = ndata * (stats['max'] - stats['min']) + stats['min']
126
+ return data
127
+
128
+ # Dataset Class
129
+ class PushTStateDataset(torch.utils.data.Dataset):
130
+ def __init__(self, dataset_path, pred_horizon, obs_horizon, action_horizon):
131
+ # Read from zarr dataset
132
+ dataset_root = zarr.open(dataset_path, 'r')
133
+ # All demonstration episodes are concatenated in the first dimension N
134
+ train_data = {
135
+ # (N, action_dim)
136
+ 'action': dataset_root['data']['action'][:],
137
+ # (N, obs_dim)
138
+ 'obs': dataset_root['data']['state'][:]
139
+ }
140
+ # Marks one-past the last index for each episode
141
+ episode_ends = dataset_root['meta']['episode_ends'][:]
142
+
143
+ # Compute start and end of each state-action sequence
144
+ # Also handles padding
145
+ indices = create_sample_indices(
146
+ episode_ends=episode_ends,
147
+ sequence_length=pred_horizon,
148
+ # Add padding such that each timestep in the dataset are seen
149
+ pad_before=obs_horizon-1,
150
+ pad_after=action_horizon-1)
151
+
152
+ # Compute statistics and normalize data to [-1,1]
153
+ stats = dict()
154
+ normalized_train_data = dict()
155
+ for key, data in train_data.items():
156
+ stats[key] = get_data_stats(data)
157
+ normalized_train_data[key] = normalize_data(data, stats[key])
158
+
159
+ self.indices = indices
160
+ self.stats = stats
161
+ self.normalized_train_data = normalized_train_data
162
+ self.pred_horizon = pred_horizon
163
+ self.action_horizon = action_horizon
164
+ self.obs_horizon = obs_horizon
165
+
166
+ def __len__(self):
167
+ return len(self.indices)
168
+
169
+ def __getitem__(self, idx):
170
+ # Get the start/end indices for this datapoint
171
+ buffer_start_idx, buffer_end_idx, \
172
+ sample_start_idx, sample_end_idx = self.indices[idx]
173
+
174
+ # Get normalized data using these indices
175
+ nsample = sample_sequence(
176
+ train_data=self.normalized_train_data,
177
+ sequence_length=self.pred_horizon,
178
+ buffer_start_idx=buffer_start_idx,
179
+ buffer_end_idx=buffer_end_idx,
180
+ sample_start_idx=sample_start_idx,
181
+ sample_end_idx=sample_end_idx
182
+ )
183
+
184
+ # Discard unused observations
185
+ nsample['obs'] = nsample['obs'][:self.obs_horizon,:]
186
+ return nsample
187
+
188
+ # Model Classes
189
+ class ObservationEncoder(nn.Module):
190
+ """Encodes observations for conditioning"""
191
+ def __init__(self, obs_dim: int, embed_dim: int):
192
+ super().__init__()
193
+ self.net = nn.Sequential(
194
+ nn.Linear(obs_dim, embed_dim * 2),
195
+ nn.Mish(),
196
+ nn.Linear(embed_dim * 2, embed_dim)
197
+ )
198
+
199
+ def forward(self, x):
200
+ # x: [batch, timesteps, obs_dim]
201
+ batch_size, timesteps, obs_dim = x.shape
202
+ x = x.reshape(-1, obs_dim)
203
+ x = self.net(x)
204
+ x = x.reshape(batch_size, timesteps * self.net[-1].out_features)
205
+ return x
206
+
207
+ def train_diffusion():
208
+ """Train diffusion model using HuggingFace diffusers"""
209
+ # Configs
210
+ data_config = DataConfig()
211
+ model_config = ModelConfig()
212
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
213
+ num_epochs = 100
214
+ print(f"Using device: {device}")
215
+
216
+ # Create dataset (from zarr file)
217
+ dataset = PushTStateDataset(
218
+ dataset_path=data_config.dataset_path,
219
+ pred_horizon=data_config.pred_horizon,
220
+ obs_horizon=data_config.obs_horizon,
221
+ action_horizon=data_config.action_horizon
222
+ )
223
+
224
+ # Assign stats and define save directory
225
+ stats = dataset.stats
226
+ save_dir = "checkpoints"
227
+ os.makedirs(save_dir, exist_ok=True)
228
+
229
+ dataloader = torch.utils.data.DataLoader(
230
+ dataset,
231
+ batch_size=256,
232
+ num_workers=4,
233
+ shuffle=True,
234
+ pin_memory=True,
235
+ persistent_workers=True
236
+ )
237
+
238
+ # Create observation encoder
239
+ obs_encoder = ObservationEncoder(
240
+ obs_dim=data_config.state_dim,
241
+ embed_dim=model_config.obs_embed_dim
242
+ ).to(device)
243
+
244
+ # Create UNet1D model from diffusers
245
+ model = UNet1DModel(
246
+ sample_size=model_config.sample_size,
247
+ in_channels=model_config.total_in_channels, # actions + conditioning
248
+ out_channels=model_config.out_channels,
249
+ layers_per_block=model_config.layers_per_block,
250
+ block_out_channels=model_config.block_out_channels,
251
+ norm_num_groups=model_config.norm_num_groups,
252
+ down_block_types=model_config.down_block_types,
253
+ up_block_types=model_config.up_block_types,
254
+ ).to(device)
255
+
256
+ # Create noise scheduler
257
+ noise_scheduler = DDPMScheduler(
258
+ num_train_timesteps=100,
259
+ beta_schedule="squaredcos_cap_v2",
260
+ clip_sample=True,
261
+ prediction_type="epsilon"
262
+ )
263
+
264
+ # Create projection layer OUTSIDE the training loop
265
+ obs_projection = nn.Linear(model_config.obs_embed_dim * data_config.obs_horizon,
266
+ model_config.obs_embed_dim // 8).to(device)
267
+
268
+ # Update optimizer to include projection layer
269
+ optimizer = torch.optim.AdamW([
270
+ {'params': model.parameters()},
271
+ {'params': obs_encoder.parameters()},
272
+ {'params': obs_projection.parameters()}
273
+ ], lr=1e-4)
274
+
275
+ # Update EMA to include projection layer
276
+ ema = EMAModel(
277
+ parameters=list(model.parameters()) +
278
+ list(obs_encoder.parameters()) +
279
+ list(obs_projection.parameters()),
280
+ power=0.75
281
+ )
282
+
283
+ for epoch in range(num_epochs):
284
+ progress_bar = tqdm(total=len(dataloader), desc=f'Epoch {epoch}')
285
+ epoch_loss = []
286
+
287
+ for batch in dataloader:
288
+ # Get batch data
289
+ obs = batch['obs'].to(device) # [batch, obs_horizon, obs_dim]
290
+ actions = batch['action'].to(device) # [batch, pred_horizon, action_dim]
291
+ batch_size = obs.shape[0]
292
+
293
+ # Encode observations for conditioning
294
+ obs_embedding = obs_encoder(obs) # [batch, obs_embed_dim * obs_horizon]
295
+
296
+ # Sample noise and timesteps
297
+ noise = torch.randn_like(actions)
298
+ timesteps = torch.randint(
299
+ 0, noise_scheduler.config.num_train_timesteps,
300
+ (batch_size,), device=device
301
+ ).long()
302
+
303
+ # Add noise to actions according to noise schedule
304
+ noisy_actions = noise_scheduler.add_noise(actions, noise, timesteps)
305
+
306
+ # Reshape to channels format for UNet
307
+ # [batch, pred_horizon, channels] -> [batch, channels, pred_horizon]
308
+ noisy_actions = noisy_actions.transpose(1, 2)
309
+ noise = noise.transpose(1, 2)
310
+
311
+ # Project the observation embedding
312
+ obs_cond = obs_projection(obs_embedding) # [batch, obs_embed_dim//8]
313
+
314
+ # Reshape to match sequence length
315
+ obs_cond = obs_cond.unsqueeze(-1).expand(-1, -1, noisy_actions.shape[-1])
316
+
317
+ # Concatenate along channel dimension
318
+ model_input = torch.cat([noisy_actions, obs_cond], dim=1)
319
+
320
+ noise_pred = model(
321
+ model_input,
322
+ timesteps,
323
+ ).sample # Removed slicing [:, :data_config.action_dim]
324
+
325
+ # Calculate loss
326
+ loss = F.mse_loss(noise_pred, noise)
327
+ epoch_loss.append(loss.item())
328
+
329
+ # Optimize
330
+ optimizer.zero_grad()
331
+ loss.backward()
332
+ optimizer.step()
333
+
334
+ # Update EMA parameters
335
+ ema.step(list(model.parameters()) +
336
+ list(obs_encoder.parameters()) +
337
+ list(obs_projection.parameters()))
338
+
339
+ # Update progress
340
+ progress_bar.update(1)
341
+ progress_bar.set_postfix(loss=loss.item())
342
+
343
+ progress_bar.close()
344
+
345
+ # Print epoch stats
346
+ avg_loss = sum(epoch_loss) / len(epoch_loss)
347
+ print(f"\nEpoch {epoch} average loss: {avg_loss:.6f}")
348
+
349
+ # Save checkpoint every 10 epochs
350
+ if (epoch + 1) % 10 == 0:
351
+ torch.save({
352
+ 'epoch': epoch,
353
+ 'model_state_dict': model.state_dict(),
354
+ 'encoder_state_dict': obs_encoder.state_dict(),
355
+ 'projection_state_dict': obs_projection.state_dict(),
356
+ 'ema_state_dict': ema.state_dict(),
357
+ 'optimizer_state_dict': optimizer.state_dict(),
358
+ # 'noise_scheduler_state_dict': noise_scheduler.state_dict(), # Removed
359
+ 'stats': stats,
360
+ 'loss': avg_loss,
361
+ }, os.path.join(save_dir, f'diffusion_checkpoint_{epoch}.pt'))
362
+
363
+ return model, obs_encoder, obs_projection, ema, noise_scheduler, optimizer, stats
364
+
365
+ def main():
366
+ # Download dataset if needed
367
+ config = DataConfig()
368
+ if not os.path.isfile(config.dataset_path):
369
+ print("Downloading dataset...")
370
+ gdown.download(id=config.dataset_gdrive_id, output=config.dataset_path, quiet=False)
371
+
372
+ # Create dataset
373
+ dataset = PushTStateDataset(
374
+ dataset_path=config.dataset_path,
375
+ pred_horizon=config.pred_horizon,
376
+ obs_horizon=config.obs_horizon,
377
+ action_horizon=config.action_horizon
378
+ )
379
+
380
+ # Test batch
381
+ dataloader = torch.utils.data.DataLoader(
382
+ dataset, batch_size=256, num_workers=1,
383
+ shuffle=True, pin_memory=True, persistent_workers=True
384
+ )
385
+ batch = next(iter(dataloader))
386
+ print("batch['obs'].shape:", batch['obs'].shape)
387
+ print("batch['action'].shape:", batch['action'].shape)
388
+
389
+ if __name__ == "__main__":
390
+ main()
391
+
392
+ print("\nStarting diffusion model training...")
393
+ model, obs_encoder, obs_projection, ema, noise_scheduler, optimizer, stats = train_diffusion()
394
+
395
+ # Save final model
396
+ save_dir = "checkpoints"
397
+ os.makedirs(save_dir, exist_ok=True)
398
+ torch.save({
399
+ 'model_state_dict': model.state_dict(),
400
+ 'encoder_state_dict': obs_encoder.state_dict(),
401
+ 'projection_state_dict': obs_projection.state_dict(),
402
+ 'ema_state_dict': ema.state_dict(),
403
+ 'optimizer_state_dict': optimizer.state_dict(),
404
+ # 'noise_scheduler_state_dict': noise_scheduler.state_dict(),
405
+ 'stats': stats
406
+ }, os.path.join(save_dir, 'diffusion_final.pt'))