Spaces:
Runtime error
Runtime error
Commit ·
3a72088
1
Parent(s): 1473645
Update app.py
Browse files
app.py
CHANGED
|
@@ -29,11 +29,11 @@ model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
|
|
| 29 |
model = model.half().cuda().eval().requires_grad_(False)
|
| 30 |
clip_model = clip.load(model.clip_model, jit=False, device='cpu')[0]
|
| 31 |
|
| 32 |
-
def run_all(prompt, steps, n_images, weight):
|
| 33 |
import random
|
| 34 |
seed = int(random.randint(0, 2147483647))
|
| 35 |
target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda()
|
| 36 |
-
|
| 37 |
def cfg_model_fn(x, t):
|
| 38 |
"""The CFG wrapper function."""
|
| 39 |
n = x.shape[0]
|
|
@@ -44,14 +44,41 @@ def run_all(prompt, steps, n_images, weight):
|
|
| 44 |
v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
|
| 45 |
v = v_uncond + (v_cond - v_uncond) * weight
|
| 46 |
return v
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
gc.collect()
|
| 49 |
torch.cuda.empty_cache()
|
| 50 |
torch.manual_seed(seed)
|
| 51 |
x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
|
| 52 |
t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
|
| 53 |
step_list = utils.get_spliced_ddpm_cosine_schedule(t)
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
images_out = []
|
| 56 |
for i, out in enumerate(outs):
|
| 57 |
images_out.append(utils.to_pil_image(out))
|
|
@@ -65,15 +92,10 @@ iface = gr.Interface(
|
|
| 65 |
fn=run_all,
|
| 66 |
inputs=[
|
| 67 |
gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
|
| 68 |
-
gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=
|
| 69 |
-
gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1,step=1),
|
| 70 |
-
gr.inputs.Slider(label="Weight", default=5, maximum=15, minimum=0, step=1),
|
| 71 |
-
|
| 72 |
-
#gr.inputs.Dropdown(label="Flavor",choices=["ginger", "cumin", "holywater", "zynth", "wyvern", "aaron", "moth", "juu", "custom"]),
|
| 73 |
-
#markdown,
|
| 74 |
-
#gr.inputs.Dropdown(label="Style",choices=["Default","Balanced","Detailed","Consistent Creativity","Realistic","Smooth","Subtle MSE","Hyper Fast Results"],default="Hyper Fast Results"),
|
| 75 |
-
#gr.inputs.Radio(label="Width", choices=[32,64,128,256,512],default=512),
|
| 76 |
-
#gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=512),
|
| 77 |
],
|
| 78 |
outputs=gallery,
|
| 79 |
title="Generate images from text with V-Diffusion CC12M CFG",
|
|
|
|
| 29 |
model = model.half().cuda().eval().requires_grad_(False)
|
| 30 |
clip_model = clip.load(model.clip_model, jit=False, device='cpu')[0]
|
| 31 |
|
| 32 |
+
def run_all(prompt, steps, n_images, weight, clip_guided):
|
| 33 |
import random
|
| 34 |
seed = int(random.randint(0, 2147483647))
|
| 35 |
target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda()
|
| 36 |
+
clip_embed = target_embed.repeat([n, 1])
|
| 37 |
def cfg_model_fn(x, t):
|
| 38 |
"""The CFG wrapper function."""
|
| 39 |
n = x.shape[0]
|
|
|
|
| 44 |
v_uncond, v_cond = model(x_in, t_in, clip_embed_in).chunk(2, dim=0)
|
| 45 |
v = v_uncond + (v_cond - v_uncond) * weight
|
| 46 |
return v
|
| 47 |
+
|
| 48 |
+
def make_cond_model_fn(model, cond_fn):
|
| 49 |
+
def cond_model_fn(x, t, **extra_args):
|
| 50 |
+
with torch.enable_grad():
|
| 51 |
+
x = x.detach().requires_grad_()
|
| 52 |
+
v = model(x, t, **extra_args)
|
| 53 |
+
alphas, sigmas = utils.t_to_alpha_sigma(t)
|
| 54 |
+
pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
|
| 55 |
+
cond_grad = cond_fn(x, t, pred, **extra_args).detach()
|
| 56 |
+
v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
|
| 57 |
+
return v
|
| 58 |
+
return cond_model_fn
|
| 59 |
+
def cond_fn(x, t, pred, clip_embed):
|
| 60 |
+
if min(pred.shape[2:4]) < 256:
|
| 61 |
+
pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
|
| 62 |
+
clip_in = normalize(make_cutouts((pred + 1) / 2))
|
| 63 |
+
image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
|
| 64 |
+
losses = spherical_dist_loss(image_embeds, clip_embed[None])
|
| 65 |
+
loss = losses.mean(0).sum() * args.clip_guidance_scale
|
| 66 |
+
grad = -torch.autograd.grad(loss, x)[0]
|
| 67 |
+
return grad
|
| 68 |
+
|
| 69 |
gc.collect()
|
| 70 |
torch.cuda.empty_cache()
|
| 71 |
torch.manual_seed(seed)
|
| 72 |
x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
|
| 73 |
t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
|
| 74 |
step_list = utils.get_spliced_ddpm_cosine_schedule(t)
|
| 75 |
+
if(not clip_guided):
|
| 76 |
+
outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
|
| 77 |
+
else:
|
| 78 |
+
extra_args = {'clip_embed': clip_embed}
|
| 79 |
+
cond_fn_ = cond_fn
|
| 80 |
+
model_fn = make_cond_model_fn(model, cond_fn_)
|
| 81 |
+
outs = sampling.plms_sample(model_fn, x, steps, extra_args)
|
| 82 |
images_out = []
|
| 83 |
for i, out in enumerate(outs):
|
| 84 |
images_out.append(utils.to_pil_image(out))
|
|
|
|
| 92 |
fn=run_all,
|
| 93 |
inputs=[
|
| 94 |
gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
|
| 95 |
+
gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=40,maximum=80,minimum=1,step=1),
|
| 96 |
+
gr.inputs.Slider(label="Number of images in parallel", default=2, maximum=4, minimum=1, step=1),
|
| 97 |
+
gr.inputs.Slider(label="Weight - how closely the image should resemble the prompt", default=5, maximum=15, minimum=0, step=1),
|
| 98 |
+
gr.inputs.Checkbox(label="CLIP Guided - improves coherence with prompt, makes it slower"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
],
|
| 100 |
outputs=gallery,
|
| 101 |
title="Generate images from text with V-Diffusion CC12M CFG",
|