using System;
using TorchSharp;
/// <summary>
/// Iterative sampler that denoises a latent with a <see cref="DDPM"/> model, using
/// classifier-free guidance over a strided subset of the model's 1000 timesteps.
/// </summary>
public class DDIMSampler
{
    /// <summary>Number of timesteps the sampling schedule is strided over.</summary>
    private const int TimeSteps = 1000;

    private readonly DDPM _model;
    private readonly torch.Device _device;

    /// <summary>
    /// Creates a sampler bound to <paramref name="model"/>. Timestep tensors are allocated
    /// on <c>model.Device</c>.
    /// </summary>
    /// <param name="model">Model providing the denoiser and the prediction / q-sample helpers.</param>
    /// <param name="scale">
    /// Not used by this class; the guidance scale is supplied per call to
    /// <see cref="Sample"/>. Kept so existing call sites continue to compile.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _device = model.Device;
    }

    /// <summary>
    /// Runs the reverse process starting at timestep 999 and stepping down by a fixed
    /// stride. Returns the tensor produced by the last step.
    /// </summary>
    /// <param name="img">Starting tensor; its first dimension is used as the batch size.</param>
    /// <param name="condition">Conditioning forwarded to the model's denoiser.</param>
    /// <param name="unconditional_condition">Unconditional conditioning forwarded to the denoiser.</param>
    /// <param name="steps">
    /// Requested number of steps, in [1, 1000]. The stride is <c>1000 / steps</c> (integer
    /// division) and the loop runs <c>999 / stride + 1</c> times. When <paramref name="steps"/>
    /// does not divide 1000, that can be more iterations than requested.
    /// </param>
    /// <param name="scale">Guidance scale: <c>guided = uncond + scale * (cond - uncond)</c>.</param>
    /// <returns>The tensor produced by the final step.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="img"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="steps"/> is less than 1 or greater than 1000.
    /// </exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
    {
        if (img is null)
        {
            throw new ArgumentNullException(nameof(img));
        }

        // steps > TimeSteps would make the stride 0 (infinite loop); steps <= 0 would divide
        // by zero or walk the timesteps upward.
        if (steps < 1 || steps > TimeSteps)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, $"steps must be between 1 and {TimeSteps}.");
        }

        int stride = TimeSteps / steps;
        long batch = img.shape[0];

        using (torch.enable_grad(false))
        {
            for (int t = TimeSteps - 1; t >= 0; t -= stride)
            {
                // Tensors created in this iteration (timestep tensors, anything the model
                // allocates while this scope is active, intermediate predictions) are disposed
                // when the scope ends. Only the new latent is moved out so it survives.
                using (torch.NewDisposeScope())
                {
                    // The last step targets timestep 0 instead of a negative index.
                    int prevStep = Math.Max(t - stride, 0);
                    var tCur = torch.full(batch, t, dtype: torch.ScalarType.Int64, device: _device);
                    var tPrev = torch.full(batch, prevStep, dtype: torch.ScalarType.Int64, device: _device);

                    // First tuple element is used as the unconditional output, second as the conditional one.
                    (var outUncond, var outCond) = _model.DiffusionModel(img, condition, unconditional_condition, tCur);
                    var guided = outUncond + scale * (outCond - outUncond);

                    // NOTE(review): helper names suggest the model output is a v-prediction,
                    // converted here to noise (eps) and a clean-sample estimate (x0) — confirm in DDPM.
                    var eps = _model.PredictEPSFromZANDV(img, tCur, guided);
                    var predX0 = _model.PredictStartFromZANDV(img, tCur, guided);

                    // Re-noise the x0 estimate to the previous timestep using the predicted noise.
                    img = _model.QSample(predX0, tPrev, eps).MoveToOuterDisposeScope();
                }
            }
        }

        return img;
    }
}