| using System; |
| using TorchSharp; |
|
|
/// <summary>
/// Wraps a TorchScript module and exposes its <c>diffusion_model</c>, <c>q_sample</c>,
/// <c>predict_eps_from_z_and_v</c> and <c>predict_start_from_z_and_v</c> methods to .NET callers.
/// </summary>
/// <remarks>
/// The module is moved to <see cref="Device"/> and put in eval mode at construction.
/// Eval mode does not disable autograd; callers that do not need gradients should wrap
/// calls in <c>torch.no_grad()</c> themselves.
/// </remarks>
public class DDPM : IDisposable
{
    private torch.jit.ScriptModule _model;

    /// <summary>The device the scripted module was moved to at construction.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads the TorchScript module at <paramref name="modelPath"/>, moves it to
    /// <paramref name="device"/> and switches it to evaluation mode.
    /// </summary>
    /// <param name="modelPath">Path to the serialized TorchScript file.</param>
    /// <param name="device">Device to place the module on.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> or <paramref name="device"/> is null.</exception>
    public DDPM(string modelPath, torch.Device device)
    {
        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }
        if (device is null)
        {
            throw new ArgumentNullException(nameof(device));
        }

        Device = device;
        torch.jit.ScriptModule model = TorchSharp.torch.jit.load(modelPath);
        try
        {
            model.to(Device);
            model.eval();
        }
        catch
        {
            // Release the native module if device placement fails (e.g. CUDA unavailable)
            // instead of leaking it until finalization.
            model.Dispose();
            throw;
        }
        _model = model;
    }

    /// <summary>The loaded module; throws once this instance has been disposed.</summary>
    private torch.jit.ScriptModule Model =>
        _model ?? throw new ObjectDisposedException(nameof(DDPM));

    /// <summary>
    /// Runs the scripted <c>diffusion_model</c> once on a doubled batch holding the
    /// unconditional and the conditional inputs (presumably for classifier-free guidance),
    /// then splits the output back into the two halves.
    /// </summary>
    /// <param name="img">Input sample; duplicated along dimension 0.</param>
    /// <param name="condition">Conditioning tensor; placed in the second half of the batch.</param>
    /// <param name="unconditional_condition">Unconditional tensor; placed in the first half of the batch.</param>
    /// <param name="t">Timestep tensor; duplicated along dimension 0.</param>
    /// <returns>The first output half (unconditional) and the second (conditional).</returns>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public (torch.Tensor e_T_Uncondition, torch.Tensor e_T) DiffusionModel(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, torch.Tensor t)
    {
        var model = Model;

        // Every tensor created inside this scope (the concatenated inputs and the raw
        // model output) is disposed on exit; only the two returned halves escape it.
        using (torch.NewDisposeScope())
        {
            var batchedImg = torch.cat(new[] { img, img });
            var batchedCondition = torch.cat(new[] { unconditional_condition, condition });
            var batchedT = torch.cat(new[] { t, t });

            var output = model.invoke<torch.Tensor>("diffusion_model", batchedImg, batchedT, batchedCondition);
            var halves = output.chunk(2);
            return (halves[0].MoveToOuterDisposeScope(), halves[1].MoveToOuterDisposeScope());
        }
    }

    /// <summary>Invokes the scripted <c>q_sample</c> method with <paramref name="z"/>, <paramref name="t"/>, <paramref name="v"/>.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return Model.invoke<torch.Tensor>("q_sample", z, t, v);
    }

    /// <summary>Invokes the scripted <c>predict_eps_from_z_and_v</c> method.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public torch.Tensor PredictEPSFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return Model.invoke<torch.Tensor>("predict_eps_from_z_and_v", z, t, v);
    }

    /// <summary>Invokes the scripted <c>predict_start_from_z_and_v</c> method.</summary>
    /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
    public torch.Tensor PredictStartFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return Model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
    }

    /// <summary>Releases the native TorchScript module. Safe to call more than once.</summary>
    public void Dispose()
    {
        // Null-conditional so a repeated Dispose is a no-op rather than a NullReferenceException.
        _model?.Dispose();
        _model = null;
    }
}
|
|