File size: 4,362 Bytes
4f2b2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from lightning_modules import AnyOrderInsertionFlowModule, MaskedDiffusionModule
from sampling import (
    any_order_mask_insertion_euler_sampling,
    mdm_euler_sampling,
    any_order_mask_insertion_tau_leaping_sampling,
    mdm_tau_leaping_sampling,
)
from schedule import GeometricSchedule, LinearSchedule
from data.text import setup_tokeniser
import torch
import torch.nn as nn
import argparse
import json


# Global torch settings: allow faster, reduced-precision float32 matmuls and
# raise the element threshold at which printed tensors are summarised.
torch.set_float32_matmul_precision("high")
torch.set_printoptions(threshold=10_000)

# Command-line interface of the sampling script.
parser = argparse.ArgumentParser(description="Generate samples")
parser.add_argument(
    "--checkpoint_path",
    type=str,
    required=True,
    help="Path to the model checkpoint file",
)
parser.add_argument(
    "--total_samples",
    "-n",
    type=int,
    default=1024,
    help="Total number of samples to generate",
)
parser.add_argument(
    "--model_type",
    choices=["flow", "mdm"],
    default="flow",
    help="Model type to use: 'flow' or 'mdm'",
)
parser.add_argument(
    "--output_file",
    "-o",
    type=str,
    default="generated_samples.json",
    help="Path to save generated samples JSON",
)
parser.add_argument(
    "--batch_size",
    "-b",
    type=int,
    # The fallback applied later in this script is 32 per visible GPU, or 32
    # when running on CPU; keep this text in sync with that logic.
    help="Batch size; defaults to 32 per visible GPU (32 on CPU)",
)
parser.add_argument(
    "--step_size",
    type=int,
    default=2048,
    help="Number of sampling steps",
)
parser.add_argument(
    "--sampler_type",
    type=str,
    default="euler",
    choices=["euler", "tau-leaping"],
    help="Sampler to use: 'euler' or 'tau-leaping'",
)
args = parser.parse_args()
checkpoint_path = args.checkpoint_path
output_file = args.output_file

# Look up the Lightning module class for the requested model type and restore
# it from the checkpoint given on the command line.
_model_classes = {
    "flow": AnyOrderInsertionFlowModule,
    "mdm": MaskedDiffusionModule,
}
if args.model_type not in _model_classes:
    raise ValueError(f"Unknown model_type: {args.model_type}")
model = _model_classes[args.model_type].load_from_checkpoint(checkpoint_path)

# Call the module's swap_to_ema() hook (presumably activates the EMA weights;
# confirm in lightning_modules), then switch to evaluation mode for sampling.
model.swap_to_ema()
model.eval()

# Choose the compute device: CUDA when it is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# With more than one visible GPU, replicate the model through DataParallel.
# NOTE(review): nn.DataParallel does not forward custom attributes, so code
# that reads e.g. `model.interpolant` after this point has to go through
# `model.module` when the wrapper is active.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

# Tokeniser used to decode sampled token IDs back to text.
tokeniser = setup_tokeniser("gpt2")

# Samples drawn per sampler call: an explicit (non-zero) --batch_size wins;
# otherwise use 32 per visible GPU, or 32 when no GPU is available.
per_gpu_batch_size = 32
processing_batch_size = args.batch_size or per_gpu_batch_size * (
    torch.cuda.device_count() if torch.cuda.is_available() else 1
)

total_samples = args.total_samples
step_size = args.step_size

# Accumulators filled by the sampling loop: decoded text and raw token IDs.
all_samples, all_raw_samples = [], []

# Dispatch table from (model type, sampler type) to the sampling routine.
# A pair missing from the table leaves sample_fn as None.
_samplers = {
    ("flow", "euler"): any_order_mask_insertion_euler_sampling,
    ("mdm", "euler"): mdm_euler_sampling,
    ("flow", "tau-leaping"): any_order_mask_insertion_tau_leaping_sampling,
    ("mdm", "tau-leaping"): mdm_tau_leaping_sampling,
}
sample_fn = _samplers.get((args.model_type, args.sampler_type))


# nn.DataParallel does not forward custom attributes such as `interpolant`,
# so read the interpolant settings from the underlying module when the model
# has been wrapped for multi-GPU use.
base_model = model.module if isinstance(model, nn.DataParallel) else model
interpolant = base_model.interpolant
pad_id = interpolant.pad_token

for start in range(0, total_samples, processing_batch_size):
    samples, _ = sample_fn(
        model,
        steps=step_size,
        mask=interpolant.mask_token,
        pad=pad_id,
        batch_size=processing_batch_size,
        max_length=interpolant.max_length,
        return_trace=False,
    )
    # The sampler is always asked for a full batch; drop the surplus rows of
    # the final batch so that exactly total_samples samples are kept.
    samples = samples[: total_samples - start]

    # Store the raw token IDs exactly as sampled (before any post-processing).
    all_raw_samples.extend(samples.cpu().tolist())

    if args.model_type != "flow":
        # Post-process mdm samples: every token after the first pad token in
        # a row is overwritten with pad.
        batch, length = samples.shape
        pos = (
            torch.arange(length, device=samples.device)
            .unsqueeze(0)
            .expand(batch, length)
        )
        # Position of the first pad per row; rows without any pad get
        # `length`, so nothing in them is overwritten.
        first_pad = torch.where(
            samples == pad_id, pos, torch.full_like(pos, length)
        ).min(dim=1).values
        samples[pos > first_pad.unsqueeze(1)] = pad_id

    # Decode the (possibly post-processed) token IDs to text.
    decoded_samples = tokeniser.batch_decode(samples, skip_special_tokens=True)
    all_samples.extend(decoded_samples)

# Write the decoded text samples.
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(all_samples, f, indent=2)

# Also write the raw token IDs to a sibling "<name>_raw.json" file. Only a
# trailing ".json" is treated as the extension; when the output path has no
# such suffix the marker is appended, so the raw dump can never overwrite the
# decoded samples written above.
json_suffix = ".json"
if output_file.endswith(json_suffix):
    raw_output_file = output_file[: -len(json_suffix)] + "_raw" + json_suffix
else:
    raw_output_file = output_file + "_raw" + json_suffix
with open(raw_output_file, "w", encoding="utf-8") as f:
    json.dump(all_raw_samples, f, indent=2)