using System;
using TorchSharp;
/// <summary>
/// Thin wrapper around a TorchScript-exported DDPM model. It forwards calls to the
/// scripted methods <c>diffusion_model</c>, <c>q_sample</c>,
/// <c>predict_eps_from_z_and_v</c> and <c>predict_start_from_z_and_v</c>.
/// </summary>
public class DDPM : IDisposable
{
    // Loaded TorchScript module; null once Dispose has run.
    private torch.jit.ScriptModule _model;

    /// <summary>Device the scripted model was moved to at construction time.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads the TorchScript model at <paramref name="modelPath"/>, moves it to
    /// <paramref name="device"/> and switches it to evaluation mode.
    /// </summary>
    /// <param name="modelPath">Path to the serialized TorchScript file.</param>
    /// <param name="device">Device to move the model to.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    public DDPM(string modelPath, torch.Device device)
    {
        if (modelPath == null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        Device = device;
        var model = torch.jit.load(modelPath);
        try
        {
            model.to(Device);
            model.eval();
        }
        catch
        {
            // Do not leak the native module if moving it or switching its mode fails.
            model.Dispose();
            throw;
        }
        _model = model;
    }

    // Accessor used by every wrapper: after Dispose, callers get ObjectDisposedException
    // instead of a NullReferenceException.
    private torch.jit.ScriptModule Model =>
        _model ?? throw new ObjectDisposedException(nameof(DDPM));

    /// <summary>
    /// Runs the scripted <c>diffusion_model</c> once on a doubled batch. The first half
    /// uses <paramref name="unconditional_condition"/> and the second half uses
    /// <paramref name="condition"/>. The output is split back into those two halves.
    /// </summary>
    /// <param name="img">Model input; duplicated along dim 0 (the torch.cat default).</param>
    /// <param name="condition">Conditioning for the second half of the batch.</param>
    /// <param name="unconditional_condition">Conditioning for the first half of the batch.</param>
    /// <param name="t">Timestep tensor; duplicated along dim 0.</param>
    /// <returns>
    /// <c>e_T_Uncondition</c> = output for the unconditional half,
    /// <c>e_T</c> = output for the conditional half (split with <c>chunk(2)</c> on dim 0).
    /// NOTE(review): assumes all inputs share the same leading batch size — not checked here.
    /// </returns>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    public (torch.Tensor e_T_Uncondition, torch.Tensor e_T) DiffusionModel(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, torch.Tensor t)
    {
        var model = Model;

        // The concatenated inputs and the full model output are temporaries. The scope frees
        // their native memory on exit instead of leaving it to the GC finalizer.
        using (var scope = torch.NewDisposeScope())
        {
            var x_in = torch.cat(new[] { img, img });
            var condition_in = torch.cat(new[] { unconditional_condition, condition });
            var t_in = torch.cat(new[] { t, t });
            var res = model.invoke<torch.Tensor>("diffusion_model", x_in, t_in, condition_in).chunk(2);

            // Only the two halves escape the scope. As views they keep the underlying
            // storage alive after the full output handle is disposed.
            return (scope.MoveToOuter(res[0]), scope.MoveToOuter(res[1]));
        }
    }

    /// <summary>Calls the scripted <c>q_sample</c> method with (z, t, v).</summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return Model.invoke<torch.Tensor>("q_sample", z, t, v);
    }

    /// <summary>Calls the scripted <c>predict_eps_from_z_and_v</c> method with (z, t, v).</summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    public torch.Tensor PredictEPSFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return Model.invoke<torch.Tensor>("predict_eps_from_z_and_v", z, t, v);
    }

    /// <summary>Calls the scripted <c>predict_start_from_z_and_v</c> method with (z, t, v).</summary>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    public torch.Tensor PredictStartFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return Model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
    }

    /// <summary>
    /// Releases the native TorchScript module. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        _model?.Dispose();
        _model = null;
    }
}