File size: 2,562 Bytes
be5d479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from Diffusion.networks import get_net
from Dataloader.dataLoader import DummyOMDataset_indiv
import argparse
import yaml
import os
import time

# Command-line interface: path to the YAML hyper-parameter config and the
# number of synthetic samples the dummy dataset should generate.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    default="Config/config_om_contrastive.yaml",
)
parser.add_argument(
    "--dummy-samples",
    type=int,
    default=100,
    help="Number of dummy samples",
)
args = parser.parse_args()

# Load the hyper-parameter dictionary from the YAML config given on the CLI.
# safe_load only builds plain Python objects (no arbitrary object construction).
# An explicit encoding keeps parsing independent of the platform's locale
# default (e.g. cp1252 on Windows would choke on non-ASCII comments).
with open(args.config, 'r', encoding='utf-8') as config_file:
    hyp = yaml.safe_load(config_file)

# Every later access is hyp['<key>'], so anything other than a mapping (an
# empty file yields None) would otherwise fail later with an opaque TypeError.
if not isinstance(hyp, dict):
    raise ValueError(
        f"Config file {args.config!r} must contain a YAML mapping, "
        f"got {type(hyp).__name__}"
    )

# Setup device: prefer XPU, fallback to CUDA, then CPU.
# The XPU check is guarded by hasattr because not every PyTorch build exposes
# torch.xpu. On the CUDA path the concrete device string comes from the
# config's optional 'device' key (e.g. 'cuda:1'). It defaults to 'cuda' so a
# config written for XPU/CPU runs (which never read that key) does not crash
# with KeyError on a CUDA machine.
if hasattr(torch, 'xpu') and torch.xpu.is_available():
    device = torch.device('xpu')
    print(f"Using XPU device: {torch.xpu.get_device_name(0)}")
elif torch.cuda.is_available():
    device = torch.device(hyp.get('device', 'cuda'))
    print("Using CUDA device")
else:
    device = torch.device('cpu')
    print("Using CPU device")

# Unpack the hyper-parameters that are used repeatedly further down.
data_name, net_name = hyp['data_name'], hyp['net_name']
ndims, img_size = hyp['ndims'], hyp['img_size']

# Checkpoint directory "Models/<data_name>_<net_name>/", created if missing.
model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
os.makedirs(model_save_path, exist_ok=True)

# Model: look up the network class by its configured name, build it from the
# config's hyper-parameters and place it on the selected device.
Net = get_net(net_name)
model = Net(
    n_steps=hyp['timesteps'],
    ndims=ndims,
    num_input_chn=hyp['num_input_chn'],
    res=img_size,
).to(device)

# Adam over all model parameters; learning rate taken from the config.
optimizer = Adam(model.parameters(), lr=hyp['lr'])

# Data: a synthetic dataset used for XPU testing in place of real data.
# Batches are shuffled each epoch; drop_last discards a trailing partial batch.
dataset = DummyOMDataset_indiv(out_sz=img_size, num_samples=args.dummy_samples)
train_loader = DataLoader(
    dataset,
    batch_size=hyp['batchsize'],
    shuffle=True,
    drop_last=True,
)

# Training: align the model's image embedding with the ground-truth text
# embedding supplied by the dataset, minimising (1 - cosine similarity).
# Prints the loss and wall-clock time per batch, then the mean loss per epoch.
print(f'Start training on {device} with {len(dataset)} dummy samples...')
for epoch in range(hyp['epoch']):
    epoch_loss = 0.0

    for i, (volume, embd) in enumerate(train_loader):
        t0 = time.time()
        volume = volume.float().to(device)
        # Cast to float32 like `volume`: F.cosine_similarity needs both operands
        # to share a dtype, and the dataset's embedding dtype is not fixed here.
        embd = embd.float().to(device)  # [B, 1024] GT text embedding
        # Sample timesteps directly on the target device (no host->device copy).
        t = torch.randint(0, hyp['timesteps'], (volume.shape[0],), device=device)

        _, img_embd = model(x=volume, y=volume, t=t)  # img_embd: [B, 1024]

        # Cosine similarity loss: align img_embd with GT text embedding
        loss = 1 - F.cosine_similarity(img_embd, embd, dim=-1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() copies to the host and therefore synchronises with the device.
        # Read it once per batch. Because it syncs, the timing below also
        # includes the asynchronously launched device work.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        dt = time.time() - t0
        print(f"  Batch {i:04d} | Loss: {batch_loss:.6f} | Time: {dt:.2f}s")
    # max(..., 1) avoids division by zero when drop_last leaves no batches
    # (fewer dummy samples than the batch size).
    avg_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch:04d} | Avg Loss: {avg_loss:.6f}")