dorsar committed on
Commit
07e6695
·
verified ·
1 Parent(s): c4038fa

Delete diffusion_model.py

Browse files
Files changed (1) hide show
  1. diffusion_model.py +0 -406
diffusion_model.py DELETED
@@ -1,406 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.data import Dataset
4
- import zarr
5
- import numpy as np
6
- from diffusers import UNet1DModel, DDPMScheduler
7
- from diffusers.optimization import get_scheduler
8
- from diffusers.training_utils import EMAModel
9
- import torch.nn.functional as F
10
- from tqdm.auto import tqdm
11
- from dataclasses import dataclass
12
- from typing import Tuple, Optional
13
- import os
14
- import gdown
15
- import collections
16
- from imageio import get_writer
17
-
18
# Data Configurations
@dataclass
class DataConfig:
    """Configuration for the Push-T dataset and sampling horizons.

    Horizon semantics (diffusion-policy convention): the model conditions on
    the last `obs_horizon` observations, predicts `pred_horizon` actions, and
    the caller would execute only the first `action_horizon` of them.
    """
    # Dataset paths and download info
    # Local path the zarr archive is read from (and downloaded to if missing).
    dataset_path: str = "pusht_cchi_v7_replay.zarr.zip"
    # Google Drive file id passed to gdown; the trailing "&confirm=t" is
    # embedded in the id string — NOTE(review): looks intentional (bypasses
    # the large-file confirm page) but verify gdown still accepts it.
    dataset_gdrive_id: str = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t"

    # Sequence parameters
    pred_horizon: int = 16  # Number of steps to predict
    obs_horizon: int = 2    # Number of observations to condition on
    action_horizon: int = 8 # Number of actions to execute

    # Data dimensions
    # image_size/image_channels are declared but unused by the state-only
    # dataset below — presumably kept for a vision variant.
    image_size: Tuple[int, int] = (96, 96)
    image_channels: int = 3
    action_dim: int = 2  # (x, y) pusher target
    state_dim: int = 5   # [agent_x, agent_y, block_x, block_y, block_angle]
37
@dataclass
class ModelConfig:
    """Configuration for neural networks (observation encoder + UNet1D)."""
    # Observation encoding
    # Output width of ObservationEncoder per timestep.
    obs_embed_dim: int = 256

    # UNet configuration
    sample_size: int = 16   # pred_horizon length (temporal axis of the 1-D UNet)
    in_channels: int = 2    # action dimension
    out_channels: int = 2   # action dimension
    layers_per_block: int = 2
    block_out_channels: Tuple[int, ...] = (128,)
    norm_num_groups: int = 8
    # Single down/up stage — the tuples are length 1 by construction.
    down_block_types: Tuple[str, ...] = ("DownBlock1D",) * 1
    up_block_types: Tuple[str, ...] = ("UpBlock1D",) * 1

    def __post_init__(self):
        # For conditioning through input channels: the UNet input is the
        # noisy action channels concatenated with a projected observation
        # embedding. The //8 factor must stay in sync with the
        # obs_projection layer built in train_diffusion (out dim
        # obs_embed_dim // 8) — they are coupled only by this convention.
        self.total_in_channels = self.in_channels + self.obs_embed_dim //8 # actions + conditioning
58
- """
59
- Helper Functions
60
- """
61
def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    """Enumerate every window of length `sequence_length` in each episode.

    Episodes are stored concatenated; `episode_ends[i]` is one-past the last
    index of episode i. Windows may start up to `pad_before` steps before an
    episode and end up to `pad_after` steps after it; such windows are later
    edge-padded by `sample_sequence`.

    Returns an (N, 4) int array of rows
    [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx]:
    the clipped slice into the concatenated buffer, and where that slice
    lands inside the fixed-length output window.
    """
    rows = []
    prev_end = 0
    for ep_end in episode_ends:
        ep_len = ep_end - prev_end
        first = -pad_before
        last = ep_len - sequence_length + pad_after  # inclusive
        for offset in range(first, last + 1):
            buf_lo = prev_end + max(offset, 0)
            buf_hi = prev_end + min(offset + sequence_length, ep_len)
            # How much of the window falls outside the episode on each side.
            lead = buf_lo - (prev_end + offset)
            trail = (prev_end + offset + sequence_length) - buf_hi
            rows.append([buf_lo, buf_hi, lead, sequence_length - trail])
        prev_end = ep_end
    return np.array(rows)
88
-
89
def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    """Cut one fixed-length window out of every array in `train_data`.

    The slice [buffer_start_idx:buffer_end_idx] is placed at
    [sample_start_idx:sample_end_idx] inside a `sequence_length`-long window;
    any uncovered leading/trailing region is filled by repeating the first/
    last frame of the slice (edge padding). Returns {key: (sequence_length,
    ...) array}.
    """
    result = dict()
    needs_padding = (sample_start_idx > 0) or (sample_end_idx < sequence_length)
    for key, source in train_data.items():
        window = source[buffer_start_idx:buffer_end_idx]
        if not needs_padding:
            result[key] = window
            continue
        padded = np.zeros(
            shape=(sequence_length,) + source.shape[1:],
            dtype=source.dtype)
        if sample_start_idx > 0:
            padded[:sample_start_idx] = window[0]   # repeat first frame
        if sample_end_idx < sequence_length:
            padded[sample_end_idx:] = window[-1]    # repeat last frame
        padded[sample_start_idx:sample_end_idx] = window
        result[key] = padded
    return result
107
-
108
def get_data_stats(data):
    """Return per-feature 'min'/'max' over all leading dimensions.

    The last axis is treated as the feature dimension; everything before it
    is flattened together.
    """
    flat = data.reshape(-1, data.shape[-1])
    return {
        'min': np.min(flat, axis=0),
        'max': np.max(flat, axis=0),
    }
115
-
116
def normalize_data(data, stats):
    """Map `data` into [-1, 1] using the per-feature min/max in `stats`."""
    # First rescale to the unit interval [0, 1] ...
    unit = (data - stats['min']) / (stats['max'] - stats['min'])
    # ... then shift/scale to [-1, 1].
    return unit * 2 - 1
122
-
123
def unnormalize_data(ndata, stats):
    """Inverse of `normalize_data`: map [-1, 1] values back to data units."""
    unit = (ndata + 1) / 2          # [-1, 1] -> [0, 1]
    span = stats['max'] - stats['min']
    return unit * span + stats['min']
127
-
128
# Dataset Class
class PushTStateDataset(torch.utils.data.Dataset):
    """State-based Push-T dataset backed by a zarr archive.

    Yields dicts with:
      'obs'    — (obs_horizon, state_dim) normalized observations
      'action' — (pred_horizon, action_dim) normalized actions
    Both are normalized to [-1, 1] with per-feature min/max stats, which are
    kept on `self.stats` for unnormalizing model outputs later.
    """

    def __init__(self, dataset_path, pred_horizon, obs_horizon, action_horizon):
        # All demonstration episodes are concatenated along the first axis N.
        root = zarr.open(dataset_path, 'r')
        raw = {
            'action': root['data']['action'][:],  # (N, action_dim)
            'obs': root['data']['state'][:],      # (N, obs_dim)
        }
        # One-past-the-end index of each episode.
        episode_ends = root['meta']['episode_ends'][:]

        # Precompute every (padded) window so __getitem__ is a cheap lookup.
        # Padding ensures every timestep appears in at least one sample.
        window_index = create_sample_indices(
            episode_ends=episode_ends,
            sequence_length=pred_horizon,
            pad_before=obs_horizon - 1,
            pad_after=action_horizon - 1)

        # Per-key min/max stats, then normalize everything to [-1, 1].
        stats = dict()
        normalized = dict()
        for key, arr in raw.items():
            stats[key] = get_data_stats(arr)
            normalized[key] = normalize_data(arr, stats[key])

        self.indices = window_index
        self.stats = stats
        self.normalized_train_data = normalized
        self.pred_horizon = pred_horizon
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Precomputed slice bounds for this sample.
        buf_lo, buf_hi, pad_lo, pad_hi = self.indices[idx]

        # Edge-padded, fixed-length window of normalized data.
        nsample = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buf_lo,
            buffer_end_idx=buf_hi,
            sample_start_idx=pad_lo,
            sample_end_idx=pad_hi,
        )

        # Keep only the observations the policy conditions on.
        nsample['obs'] = nsample['obs'][:self.obs_horizon, :]
        return nsample
187
-
188
- # Model Classes
189
class ObservationEncoder(nn.Module):
    """MLP that embeds each observation timestep for conditioning.

    forward maps (batch, timesteps, obs_dim) -> (batch, timesteps * embed_dim)
    by encoding each timestep independently and flattening the results.
    """

    def __init__(self, obs_dim: int, embed_dim: int):
        super().__init__()
        # Keep the exact Sequential layout: checkpoints index it as net.0/net.2.
        self.net = nn.Sequential(
            nn.Linear(obs_dim, embed_dim * 2),
            nn.Mish(),
            nn.Linear(embed_dim * 2, embed_dim)
        )

    def forward(self, x):
        # x: [batch, timesteps, obs_dim]
        batch, steps, dim = x.shape
        # Fold time into the batch so the MLP sees one timestep per row.
        encoded = self.net(x.reshape(batch * steps, dim))
        # Unfold: concatenate the per-timestep embeddings.
        return encoded.reshape(batch, steps * self.net[-1].out_features)
206
-
207
def train_diffusion():
    """Train diffusion model using HuggingFace diffusers.

    Trains a 1-D UNet to predict the noise added to normalized action
    sequences, conditioned on encoded observations concatenated as extra
    input channels. Checkpoints every 10 epochs to ./checkpoints.

    Returns:
        (model, obs_encoder, obs_projection, ema, noise_scheduler,
         optimizer, stats) — the raw (non-EMA) trained modules plus the
        EMA tracker, scheduler, optimizer, and normalization stats.
    """
    # Configs
    data_config = DataConfig()
    model_config = ModelConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_epochs = 100
    print(f"Using device: {device}")

    # Create dataset (from zarr file)
    dataset = PushTStateDataset(
        dataset_path=data_config.dataset_path,
        pred_horizon=data_config.pred_horizon,
        obs_horizon=data_config.obs_horizon,
        action_horizon=data_config.action_horizon
    )

    # Keep normalization stats for checkpointing; define save directory.
    stats = dataset.stats
    save_dir = "checkpoints"
    os.makedirs(save_dir, exist_ok=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        num_workers=4,
        shuffle=True,
        pin_memory=True,
        persistent_workers=True
    )

    # Create observation encoder
    obs_encoder = ObservationEncoder(
        obs_dim=data_config.state_dim,
        embed_dim=model_config.obs_embed_dim
    ).to(device)

    # Create UNet1D model from diffusers. Input channels =
    # action channels + projected conditioning channels (total_in_channels).
    model = UNet1DModel(
        sample_size=model_config.sample_size,
        in_channels=model_config.total_in_channels,  # actions + conditioning
        out_channels=model_config.out_channels,
        layers_per_block=model_config.layers_per_block,
        block_out_channels=model_config.block_out_channels,
        norm_num_groups=model_config.norm_num_groups,
        down_block_types=model_config.down_block_types,
        up_block_types=model_config.up_block_types,
    ).to(device)

    # Create noise scheduler (epsilon-prediction DDPM, cosine betas).
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=100,
        beta_schedule="squaredcos_cap_v2",
        clip_sample=True,
        prediction_type="epsilon"
    )

    # Projection layer created OUTSIDE the training loop so its parameters
    # persist and are trained/EMA-tracked. Its output width
    # (obs_embed_dim // 8) must match ModelConfig.total_in_channels'
    # conditioning term — they are coupled by convention only.
    obs_projection = nn.Linear(model_config.obs_embed_dim * data_config.obs_horizon,
                               model_config.obs_embed_dim // 8).to(device)

    # One optimizer over all three trainable modules.
    optimizer = torch.optim.AdamW([
        {'params': model.parameters()},
        {'params': obs_encoder.parameters()},
        {'params': obs_projection.parameters()}
    ], lr=1e-4)

    # EMA tracks the same flat parameter list; ema.step below must pass the
    # parameters in this exact order.
    ema = EMAModel(
        parameters=list(model.parameters()) +
                   list(obs_encoder.parameters()) +
                   list(obs_projection.parameters()),
        power=0.75
    )

    for epoch in range(num_epochs):
        progress_bar = tqdm(total=len(dataloader), desc=f'Epoch {epoch}')
        epoch_loss = []

        for batch in dataloader:
            # Get batch data (already normalized to [-1, 1] by the dataset).
            obs = batch['obs'].to(device)        # [batch, obs_horizon, obs_dim]
            actions = batch['action'].to(device) # [batch, pred_horizon, action_dim]
            batch_size = obs.shape[0]

            # Encode observations for conditioning
            obs_embedding = obs_encoder(obs)  # [batch, obs_embed_dim * obs_horizon]

            # Sample per-example noise and diffusion timesteps.
            noise = torch.randn_like(actions)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (batch_size,), device=device
            ).long()

            # Forward-diffuse: add noise to actions per the schedule.
            noisy_actions = noise_scheduler.add_noise(actions, noise, timesteps)

            # Reshape to channels-first for the 1-D UNet:
            # [batch, pred_horizon, channels] -> [batch, channels, pred_horizon]
            noisy_actions = noisy_actions.transpose(1, 2)
            noise = noise.transpose(1, 2)

            # Project the observation embedding down to the conditioning width.
            obs_cond = obs_projection(obs_embedding)  # [batch, obs_embed_dim//8]

            # Broadcast the conditioning across the temporal axis.
            obs_cond = obs_cond.unsqueeze(-1).expand(-1, -1, noisy_actions.shape[-1])

            # Concatenate along channel dimension: conditioning-by-inpainting.
            model_input = torch.cat([noisy_actions, obs_cond], dim=1)

            # UNet out_channels == action_dim, so no output slicing is needed.
            noise_pred = model(
                model_input,
                timesteps,
            ).sample

            # Epsilon-prediction objective: predict the injected noise.
            loss = F.mse_loss(noise_pred, noise)
            epoch_loss.append(loss.item())

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update EMA parameters (same order as EMAModel construction).
            ema.step(list(model.parameters()) +
                     list(obs_encoder.parameters()) +
                     list(obs_projection.parameters()))

            # Update progress
            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.item())

        progress_bar.close()

        # Print epoch stats
        avg_loss = sum(epoch_loss) / len(epoch_loss)
        print(f"\nEpoch {epoch} average loss: {avg_loss:.6f}")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'encoder_state_dict': obs_encoder.state_dict(),
                'projection_state_dict': obs_projection.state_dict(),
                'ema_state_dict': ema.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # DDPMScheduler carries no learned state; it is rebuilt from
                # config at load time.
                'stats': stats,
                'loss': avg_loss,
            }, os.path.join(save_dir, f'diffusion_checkpoint_{epoch}.pt'))

    return model, obs_encoder, obs_projection, ema, noise_scheduler, optimizer, stats
364
-
365
def main():
    """Fetch the Push-T dataset if missing, then sanity-check one batch."""
    cfg = DataConfig()

    # Download the zarr archive from Google Drive on first run.
    if not os.path.isfile(cfg.dataset_path):
        print("Downloading dataset...")
        gdown.download(id=cfg.dataset_gdrive_id, output=cfg.dataset_path, quiet=False)

    dataset = PushTStateDataset(
        dataset_path=cfg.dataset_path,
        pred_horizon=cfg.pred_horizon,
        obs_horizon=cfg.obs_horizon,
        action_horizon=cfg.action_horizon,
    )

    # Pull a single batch and report its shapes as a smoke test.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=256, num_workers=1,
        shuffle=True, pin_memory=True, persistent_workers=True,
    )
    batch = next(iter(loader))
    print("batch['obs'].shape:", batch['obs'].shape)
    print("batch['action'].shape:", batch['action'].shape)
388
-
389
- if __name__ == "__main__":
390
- main()
391
-
392
- print("\nStarting diffusion model training...")
393
- model, obs_encoder, obs_projection, ema, noise_scheduler, optimizer, stats = train_diffusion()
394
-
395
- # Save final model
396
- save_dir = "checkpoints"
397
- os.makedirs(save_dir, exist_ok=True)
398
- torch.save({
399
- 'model_state_dict': model.state_dict(),
400
- 'encoder_state_dict': obs_encoder.state_dict(),
401
- 'projection_state_dict': obs_projection.state_dict(),
402
- 'ema_state_dict': ema.state_dict(),
403
- 'optimizer_state_dict': optimizer.state_dict(),
404
- # 'noise_scheduler_state_dict': noise_scheduler.state_dict(),
405
- 'stats': stats
406
- }, os.path.join(save_dir, 'diffusion_final.pt'))