from lightning_modules import AnyOrderInsertionFlowModule, MaskedDiffusionModule
from sampling import (
    any_order_mask_insertion_euler_sampling,
    mdm_euler_sampling,
    any_order_mask_insertion_tau_leaping_sampling,
    mdm_tau_leaping_sampling,
)
from schedule import GeometricSchedule, LinearSchedule
from data.text import setup_tokeniser
import torch
import torch.nn as nn
import argparse
import json

torch.set_float32_matmul_precision("high")
torch.set_printoptions(threshold=10_000)
# Command-line arguments (the checkpoint path is no longer hard-coded)
parser = argparse.ArgumentParser(description="Generate samples")
parser.add_argument(
    "--checkpoint_path",
    type=str,
    required=True,
    help="Path to the model checkpoint file",
)
parser.add_argument(
    "--total_samples",
    "-n",
    type=int,
    default=1024,
    help="Total number of samples to generate",
)
parser.add_argument(
    "--model_type",
    choices=["flow", "mdm"],
    default="flow",
    help="Model type to use: 'flow' or 'mdm'",
)
parser.add_argument(
    "--output_file",
    "-o",
    type=str,
    default="generated_samples.json",
    help="Path to save generated samples JSON",
)
parser.add_argument(
    "--batch_size",
    "-b",
    type=int,
    help="Processing batch size; defaults to 32 per available GPU (or 32 on CPU)",
)
parser.add_argument(
    "--step_size",
    type=int,
    default=2048,
    help="Number of sampling steps",
)
parser.add_argument(
    "--sampler_type",
    type=str,
    default="euler",
    choices=["euler", "tau-leaping"],
)
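# Illustrative invocation (script name and checkpoint path are placeholders,
# not part of the repository):
#   python generate.py --checkpoint_path checkpoints/last.ckpt \
#       --model_type flow --sampler_type euler -n 1024 -o generated_samples.json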
args = parser.parse_args()
checkpoint_path = args.checkpoint_path
output_file = args.output_file
# Load chosen model
if args.model_type == "flow":
    model = AnyOrderInsertionFlowModule.load_from_checkpoint(checkpoint_path)
elif args.model_type == "mdm":
    model = MaskedDiffusionModule.load_from_checkpoint(checkpoint_path)
else:
    raise ValueError(f"Unknown model_type: {args.model_type}")

model.swap_to_ema()
model.eval()
# distribute model across GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)
# nn.DataParallel does not forward attribute access to the wrapped module,
# so keep a handle to the underlying module for reading interpolant settings
base_model = model.module if isinstance(model, nn.DataParallel) else model

tokeniser = setup_tokeniser("gpt2")
# determine processing batch size
per_gpu_batch_size = 32
if args.batch_size:
    processing_batch_size = args.batch_size
elif torch.cuda.is_available():
    processing_batch_size = per_gpu_batch_size * torch.cuda.device_count()
else:
    processing_batch_size = per_gpu_batch_size

step_size = args.step_size
total_samples = args.total_samples

all_samples = []
all_raw_samples = []
sample_fn = None
# select the sampling function matching the model and sampler choice
match (args.model_type, args.sampler_type):
    case ("flow", "euler"):
        sample_fn = any_order_mask_insertion_euler_sampling
    case ("mdm", "euler"):
        sample_fn = mdm_euler_sampling
    case ("flow", "tau-leaping"):
        sample_fn = any_order_mask_insertion_tau_leaping_sampling
    case ("mdm", "tau-leaping"):
        sample_fn = mdm_tau_leaping_sampling
# generate samples in batches
for start in range(0, total_samples, processing_batch_size):
    # shrink the final batch so exactly total_samples are generated
    current_batch_size = min(processing_batch_size, total_samples - start)
    samples, _ = sample_fn(
        model,
        steps=step_size,
        mask=base_model.interpolant.mask_token,
        pad=base_model.interpolant.pad_token,
        batch_size=current_batch_size,
        max_length=base_model.interpolant.max_length,
        return_trace=False,
    )
    # store raw token IDs for both model types
    all_raw_samples.extend(samples.cpu().tolist())
    if args.model_type == "mdm":
        # post-process mdm samples: replace tokens after the first pad with pad
        pad_id = base_model.interpolant.pad_token
        B, L = samples.shape
        pos = torch.arange(L, device=samples.device).unsqueeze(0).expand(B, L)
        pad_mask = samples == pad_id
        idxs = torch.where(pad_mask, pos, torch.full_like(pos, L))
        first_idx = idxs.min(dim=1).values
        mask_after = pos > first_idx.unsqueeze(1)
        samples[mask_after] = pad_id
    # Decode and strip samples
    decoded_samples = tokeniser.batch_decode(samples, skip_special_tokens=True)
    all_samples.extend(decoded_samples)
with open(output_file, "w") as f:
    json.dump(all_samples, f, indent=2)
# also write raw token outputs
with open(output_file.replace(".json", "_raw.json"), "w") as f:
    json.dump(all_raw_samples, f, indent=2)
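# Quick sanity check on the output afterwards (illustrative; uses the default
# output file name):
#   python -c "import json; print(len(json.load(open('generated_samples.json'))))"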