littlelittlecloud
committed on
Commit
·
e3192e0
1
Parent(s):
db2ecf2
refactor and add DDPM and DDIMSampler
Browse files- .gitattributes +2 -0
- DDIMSampler.cs +34 -0
- DDPM.cs +48 -0
- Program.cs +11 -29
- cat.png +2 -2
- ddim_v_sampler.ckpt +2 -2
.gitattributes
CHANGED
|
@@ -22,6 +22,7 @@
|
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
@@ -32,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 26 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 27 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
cat.png filter=lfs diff=lfs merge=lfs -text
|
DDIMSampler.cs
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
using System;
using TorchSharp;

/// <summary>
/// Deterministic DDIM sampler that iteratively denoises a latent image using a
/// v-prediction <see cref="DDPM"/> model with classifier-free guidance.
/// </summary>
public class DDIMSampler
{
    private readonly DDPM _model;

    // Number of diffusion timesteps the underlying model was trained with.
    private const int TIME_STEPS = 1000;

    private readonly torch.Device _device;

    // Default classifier-free guidance scale, used when Sample() is called
    // without an explicit scale. BUG FIX: the original constructor accepted
    // this parameter but silently discarded it.
    private readonly float _scale;

    /// <summary>
    /// Creates a sampler bound to the given model (and its device).
    /// </summary>
    /// <param name="model">Loaded DDPM model providing the denoising network.</param>
    /// <param name="scale">Default classifier-free guidance scale.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _device = model.Device;
        _scale = scale;
    }

    /// <summary>
    /// Runs DDIM sampling: walks the diffusion schedule backwards in strides of
    /// <c>TIME_STEPS / steps</c>, denoising <paramref name="img"/> at each step.
    /// </summary>
    /// <param name="img">Initial noise latent — assumed to already be on the model's device (TODO confirm caller contract).</param>
    /// <param name="condition">CLIP embedding of the prompt.</param>
    /// <param name="unconditional_condition">CLIP embedding of the empty prompt (classifier-free guidance baseline).</param>
    /// <param name="steps">Number of DDIM steps; must be in (0, TIME_STEPS].</param>
    /// <param name="scale">Guidance scale; when null, the constructor's default is used.</param>
    /// <returns>The denoised latent tensor.</returns>
    /// <exception cref="ArgumentOutOfRangeException">If <paramref name="steps"/> is out of range.</exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float? scale = null)
    {
        if (steps <= 0 || steps > TIME_STEPS)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, $"steps must be in (0, {TIME_STEPS}].");
        }

        var guidance = scale ?? _scale;
        var gap = TIME_STEPS / steps;

        // Inference only — no autograd graph needed.
        using (torch.no_grad())
        {
            for (var i = TIME_STEPS - 1; i >= 0; i -= gap)
            {
                // BUG FIX: dispose all intermediate tensors each iteration;
                // the original leaked every temporary for up to 1000 steps.
                using var iterationScope = torch.NewDisposeScope();

                var t_cur = torch.full(1, i, dtype: torch.ScalarType.Int64, device: _device);
                var t_prev = torch.full(1, Math.Max(i - gap, 0), dtype: torch.ScalarType.Int64, device: _device);

                // Classifier-free guidance: one batched forward pass over
                // [uncond; cond], then blend the two predictions.
                (var e_t_uncond, var e_t) = _model.DiffusionModel(img, condition, unconditional_condition, t_cur);
                var model_output = e_t_uncond + guidance * (e_t - e_t_uncond);

                // v-prediction parameterization: recover eps and x0 from v.
                e_t = _model.PredictEPSFromZANDV(img, t_cur, model_output);
                var pred_x0 = _model.PredictStartFromZANDV(img, t_cur, model_output);

                // Re-noise the predicted x0 down to the previous timestep; the
                // result must outlive this iteration's dispose scope.
                img = _model.QSample(pred_x0, t_prev, e_t).MoveToOuterDisposeScope();
            }

            return img;
        }
    }
}
|
DDPM.cs
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
using System;
using TorchSharp;

/// <summary>
/// Thin wrapper around a TorchScript checkpoint that exports the methods
/// needed for v-prediction DDPM/DDIM sampling (diffusion net, CLIP encoder,
/// VAE decoder, and the q-sampling helpers).
/// </summary>
public class DDPM : IDisposable
{
    private readonly torch.jit.ScriptModule _model;

    /// <summary>Device the model weights live on.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads the TorchScript module from <paramref name="modelPath"/>, moves it
    /// to <paramref name="device"/>, and switches it to eval mode.
    /// </summary>
    /// <param name="modelPath">Path to the TorchScript checkpoint (.ckpt).</param>
    /// <param name="device">Target device (e.g. cuda:0).</param>
    /// <exception cref="ArgumentException">If <paramref name="modelPath"/> is null or whitespace.</exception>
    /// <exception cref="ArgumentNullException">If <paramref name="device"/> is null.</exception>
    public DDPM(string modelPath, torch.Device device)
    {
        if (string.IsNullOrWhiteSpace(modelPath))
        {
            throw new ArgumentException("Model path must be a non-empty string.", nameof(modelPath));
        }

        Device = device ?? throw new ArgumentNullException(nameof(device));
        _model = TorchSharp.torch.jit.load(modelPath);
        _model.to(Device);
        _model.eval();
    }

    /// <summary>
    /// Runs the denoising network once for both the unconditional and the
    /// conditional branch by batching them into a single forward pass.
    /// </summary>
    /// <returns>Tuple of (unconditional prediction, conditional prediction).</returns>
    public (torch.Tensor e_T_Uncondition, torch.Tensor e_T) DiffusionModel(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, torch.Tensor t)
    {
        // Stack [uncond; cond] so one forward call serves both branches,
        // then split the output back into the two halves.
        var x_in = torch.cat(new[] { img, img });
        var condition_in = torch.cat(new[] { unconditional_condition, condition });
        var t_in = torch.cat(new[] { t, t });
        var res = _model.invoke<torch.Tensor>("diffusion_model", x_in, t_in, condition_in).chunk(2);
        return (res[0], res[1]);
    }

    /// <summary>Decodes a latent back into image space via the checkpoint's decoder.</summary>
    public torch.Tensor DecodeImage(torch.Tensor img)
    {
        return _model.invoke<torch.Tensor>("decode_image", img);
    }

    /// <summary>Encodes a token tensor into a CLIP text embedding.</summary>
    public torch.Tensor ClipEncoder(torch.Tensor tokenTensor)
    {
        return _model.invoke<torch.Tensor>("clip_encoder", tokenTensor);
    }

    /// <summary>Forward-diffuses <paramref name="z"/> to timestep <paramref name="t"/> with noise <paramref name="v"/>.</summary>
    public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return _model.invoke<torch.Tensor>("q_sample", z, t, v);
    }

    /// <summary>Recovers the noise estimate (eps) from a v-prediction output.</summary>
    public torch.Tensor PredictEPSFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return _model.invoke<torch.Tensor>("predict_eps_from_z_and_v", z, t, v);
    }

    /// <summary>Recovers the clean-image estimate (x0) from a v-prediction output.</summary>
    public torch.Tensor PredictStartFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v)
    {
        return _model.invoke<torch.Tensor>("predict_start_from_z_and_v", z, t, v);
    }

    /// <summary>
    /// Releases the native TorchScript module. BUG FIX: the original never
    /// disposed the ScriptModule, leaking its native resources.
    /// </summary>
    public void Dispose()
    {
        _model.Dispose();
        GC.SuppressFinalize(this);
    }
}
|
Program.cs
CHANGED
|
@@ -3,11 +3,11 @@ using System.Collections.Generic;
|
|
| 3 |
using System.IO;
|
| 4 |
using System.Linq;
|
| 5 |
using TorchSharp;
|
|
|
|
| 6 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
| 7 |
var device = TorchSharp.torch.device("cuda:0");
|
| 8 |
-
var
|
| 9 |
-
|
| 10 |
-
ddpm_v_sampler.eval();
|
| 11 |
|
| 12 |
var start_token = 49406;
|
| 13 |
var end_token = 49407;
|
|
@@ -34,30 +34,12 @@ var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype
|
|
| 34 |
unconditional_tokenTensor = unconditional_tokenTensor.reshape((long)batch, -1);
|
| 35 |
var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
|
| 36 |
var t = torch.ones(batch, dtype: torch.ScalarType.Int32, device: device);
|
| 37 |
-
var condition =
|
| 38 |
-
var unconditional_condition =
|
| 39 |
-
Console.WriteLine(condition);
|
| 40 |
-
var timesteps = 1000;
|
| 41 |
var ddim_steps = 50;
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
var t_prev = torch.full(batch, i - gap >= 0? i - gap: 0, dtype: torch.ScalarType.Int64, device: device);
|
| 49 |
-
img = (torch.Tensor)ddpm_v_sampler.invoke("ddim_sampler", img, condition, unconditional_condition, t_cur, t_prev);
|
| 50 |
-
Console.WriteLine($"step {i}");
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
var decoded_images = (torch.Tensor)ddpm_v_sampler.invoke("decode_image", img);
|
| 54 |
-
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
| 55 |
-
|
| 56 |
-
for(int i = 0; i!= batch; ++i)
|
| 57 |
-
{
|
| 58 |
-
// c * h * w
|
| 59 |
-
var image = decoded_images[i];
|
| 60 |
-
image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
|
| 61 |
-
torchvision.io.write_image(image, $"{i}.png", torchvision.ImageFormat.Png);
|
| 62 |
-
}
|
| 63 |
-
}
|
|
|
|
| 3 |
using System.IO;
|
| 4 |
using System.Linq;
|
| 5 |
using TorchSharp;
|
| 6 |
+
|
| 7 |
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
|
| 8 |
var device = TorchSharp.torch.device("cuda:0");
|
| 9 |
+
var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
|
| 10 |
+
var ddimSampler = new DDIMSampler(ddpm);
|
|
|
|
| 11 |
|
| 12 |
var start_token = 49406;
|
| 13 |
var end_token = 49407;
|
|
|
|
| 34 |
unconditional_tokenTensor = unconditional_tokenTensor.reshape((long)batch, -1);
|
| 35 |
var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
|
| 36 |
var t = torch.ones(batch, dtype: torch.ScalarType.Int32, device: device);
|
| 37 |
+
var condition = ddpm.ClipEncoder(tokenTensor);
|
| 38 |
+
var unconditional_condition = ddpm.ClipEncoder(unconditional_tokenTensor);
|
|
|
|
|
|
|
| 39 |
var ddim_steps = 50;
|
| 40 |
+
img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
|
| 41 |
+
var decoded_images = (torch.Tensor)ddpm.DecodeImage(img);
|
| 42 |
+
decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
|
| 43 |
+
var image = decoded_images[0];
|
| 44 |
+
image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
|
| 45 |
+
torchvision.io.write_image(image, $"0.png", torchvision.ImageFormat.Png);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cat.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
ddim_v_sampler.ckpt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22b16b2fc18c3b20c0eb74ed49a8f1834388fbfd84a49110340943f22fd30fa1
|
| 3 |
+
size 5216915007
|