Spaces:
Paused
Paused
Switched back to slerp for latents interpolation
Browse files
app.py
CHANGED
|
@@ -45,6 +45,17 @@ def InitializeOutpainting():
|
|
| 45 |
pipeline = StableDiffusionInpaintPipeline.from_pretrained(modelNames[modelIndex])
|
| 46 |
#safety_checker=lambda images, **kwargs: (images, False))
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps):
|
| 49 |
global lastImage, lastSeed, generator, oldLatentWalk, activeLatents
|
| 50 |
|
|
@@ -55,7 +66,7 @@ def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInf
|
|
| 55 |
GenerateNewLatentsForInference()
|
| 56 |
|
| 57 |
if oldLatentWalk != latentWalk:
|
| 58 |
-
activeLatents =
|
| 59 |
oldLatentWalk = latentWalk
|
| 60 |
|
| 61 |
if lastSeed != generatorSeed:
|
|
|
|
| 45 |
pipeline = StableDiffusionInpaintPipeline.from_pretrained(modelNames[modelIndex])
|
| 46 |
#safety_checker=lambda images, **kwargs: (images, False))
|
| 47 |
|
| 48 |
+
# Based on: https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
|
| 49 |
+
# Further optimized to trade a divide operation for a multiply
|
| 50 |
+
def slerp(start, end, alpha):
|
| 51 |
+
start_norm = torch.norm(start, dim=1, keepdim=True)
|
| 52 |
+
end_norm = torch.norm(end, dim=1, keepdim=True)
|
| 53 |
+
omega = torch.acos((start*end/(start_norm*end_norm)).sum(1))
|
| 54 |
+
sinOmega = torch.sin(omega)
|
| 55 |
+
first = torch.sin((1.0-alpha)*omega)/sinOmega
|
| 56 |
+
second = torch.sin(alpha*omega)/sinOmega
|
| 57 |
+
return first.unsqueeze(1)*start + second.unsqueeze(1)*end
|
| 58 |
+
|
| 59 |
def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps):
|
| 60 |
global lastImage, lastSeed, generator, oldLatentWalk, activeLatents
|
| 61 |
|
|
|
|
| 66 |
GenerateNewLatentsForInference()
|
| 67 |
|
| 68 |
if oldLatentWalk != latentWalk:
|
| 69 |
+
activeLatents = slerp(oldLatents, latents, latentWalk)
|
| 70 |
oldLatentWalk = latentWalk
|
| 71 |
|
| 72 |
if lastSeed != generatorSeed:
|