Yuanhan Mo
Add dummy datasets for XPU testing, XPU contrastive training script, and CLAUDE.md
be5d479 | import torch | |
| import torch.nn.functional as F | |
| from torch.optim import Adam | |
| from torch.utils.data import DataLoader | |
| from Diffusion.networks import get_net | |
| from Dataloader.dataLoader import DummyOMDataset_indiv | |
| import argparse | |
| import yaml | |
| import os | |
| import time | |
# Command-line interface: which YAML config to load and how many synthetic
# samples the dummy dataset should contain.
parser = argparse.ArgumentParser()
parser.add_argument("--config", "-C", type=str, default="Config/config_om_contrastive.yaml")
parser.add_argument("--dummy-samples", type=int, default=100, help="Number of dummy samples")
args = parser.parse_args()

# Decode the config explicitly as UTF-8 so parsing does not depend on the
# platform's locale encoding.
with open(args.config, 'r', encoding='utf-8') as config_file:
    hyp = yaml.safe_load(config_file)

# Everything below indexes hyp by string key. An empty or non-mapping YAML
# document would otherwise only surface later as an opaque TypeError, so
# reject it here with a usage-style error message.
if not isinstance(hyp, dict):
    parser.error(f"Config file {args.config!r} must contain a YAML mapping of hyperparameters")
# Pick the compute device. Priority: XPU first, then the device named in the
# config when CUDA is available, and plain CPU as the last resort.
xpu_ready = hasattr(torch, 'xpu') and torch.xpu.is_available()
if xpu_ready:
    device = torch.device('xpu')
    xpu_name = torch.xpu.get_device_name(0)
    print(f"Using XPU device: {xpu_name}")
elif torch.cuda.is_available():
    # On the CUDA path the device string is taken from the config as-is.
    device = torch.device(hyp['device'])
    print("Using CUDA device")
else:
    device = torch.device('cpu')
    print("Using CPU device")
# Pull the run settings out of the loaded config.
data_name, net_name = hyp['data_name'], hyp['net_name']
ndims, img_size = hyp['ndims'], hyp['img_size']

# Output directory lives under Models/ and is named after the dataset and
# network; create it up front (no error if it already exists).
run_dir_name = f'{data_name}_{net_name}/'
model_save_path = os.path.join('Models', run_dir_name)
os.makedirs(model_save_path, exist_ok=True)
# Model: get_net returns the callable used to build the network; the built
# model is moved to the selected device and optimised with Adam.
Net = get_net(net_name)
model_kwargs = {
    'n_steps': hyp['timesteps'],
    'ndims': ndims,
    'num_input_chn': hyp['num_input_chn'],
    'res': img_size,
}
model = Net(**model_kwargs).to(device)
optimizer = Adam(model.parameters(), lr=hyp['lr'])

# Data: dummy dataset for XPU testing. Batches are shuffled and a trailing
# partial batch is dropped.
dataset = DummyOMDataset_indiv(out_sz=img_size, num_samples=args.dummy_samples)
train_loader = DataLoader(
    dataset,
    batch_size=hyp['batchsize'],
    shuffle=True,
    drop_last=True,
)
# Training: for every batch, embed the volume with the model and pull the
# image embedding towards the ground-truth text embedding via a cosine loss.
num_batches = len(train_loader)
num_timesteps = hyp['timesteps']

print(f'Start training on {device} with {len(dataset)} dummy samples...')
if num_batches == 0:
    # An empty loader would otherwise run every epoch silently and report an
    # average loss of 0, which looks like a perfect result.
    print(f'Warning: {len(dataset)} dummy samples produce no batches; nothing will be trained.')

for epoch in range(hyp['epoch']):
    epoch_loss = 0.0
    for i, (volume, embd) in enumerate(train_loader):
        # perf_counter is a monotonic clock intended for measuring intervals,
        # unlike time.time() which can jump when the system clock is adjusted.
        t0 = time.perf_counter()

        volume = volume.float().to(device)
        embd = embd.to(device)  # [B, 1024] GT text embedding

        # One random timestep per sample, created directly on the target
        # device instead of on the CPU followed by a copy.
        t = torch.randint(0, num_timesteps, (volume.shape[0],), device=device)
        _, img_embd = model(x=volume, y=volume, t=t)  # img_embd: [B, 1024]

        # Cosine similarity loss: align img_embd with GT text embedding
        loss = 1 - F.cosine_similarity(img_embd, embd, dim=-1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar back once and reuse it for both the running sum and
        # the log line; each .item() call copies the value to the host.
        loss_value = loss.item()
        epoch_loss += loss_value

        dt = time.perf_counter() - t0
        print(f" Batch {i:04d} | Loss: {loss_value:.6f} | Time: {dt:.2f}s")

    # max(..., 1) keeps the division defined when the loader has no batches.
    avg_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch:04d} | Avg Loss: {avg_loss:.6f}")