Spaces:
Paused
Paused
Commit ·
bf89172
1
Parent(s): 9cd412c
Add more logic to clip embeds
Browse files
app.py
CHANGED
|
@@ -59,8 +59,32 @@ make_cutouts = MakeCutouts(clip_model.visual.input_resolution, 16, 1.)
|
|
| 59 |
def run_all(prompt, steps, n_images, weight, clip_guided):
|
| 60 |
import random
|
| 61 |
seed = int(random.randint(0, 2147483647))
|
| 62 |
-
target_embed = clip_model.encode_text(clip.tokenize(prompt)).float().cuda()
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def cfg_model_fn(x, t):
|
| 65 |
"""The CFG wrapper function."""
|
| 66 |
n = x.shape[0]
|
|
|
|
| 59 |
def run_all(prompt, steps, n_images, weight, clip_guided):
|
| 60 |
import random
|
| 61 |
seed = int(random.randint(0, 2147483647))
|
| 62 |
+
target_embed = clip_model.encode_text(clip.tokenize(prompt)).float()#.cuda()
|
| 63 |
+
|
| 64 |
+
if(clip_guided):
|
| 65 |
+
prompts = [prompt]
|
| 66 |
+
def parse_prompt(prompt):
|
| 67 |
+
if prompt.startswith('http://') or prompt.startswith('https://'):
|
| 68 |
+
vals = prompt.rsplit(':', 2)
|
| 69 |
+
vals = [vals[0] + ':' + vals[1], *vals[2:]]
|
| 70 |
+
else:
|
| 71 |
+
vals = prompt.rsplit(':', 1)
|
| 72 |
+
vals = vals + ['', '1'][len(vals):]
|
| 73 |
+
return vals[0], float(vals[1])
|
| 74 |
+
|
| 75 |
+
for prompt in prompts:
|
| 76 |
+
txt, weight = parse_prompt(prompt)
|
| 77 |
+
target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
|
| 78 |
+
weights.append(weight)
|
| 79 |
+
|
| 80 |
+
target_embeds = torch.cat(target_embeds)
|
| 81 |
+
weights = torch.tensor(weights, device=device)
|
| 82 |
+
if weights.sum().abs() < 1e-3:
|
| 83 |
+
raise RuntimeError('The weights must not sum to 0.')
|
| 84 |
+
weights /= weights.sum().abs()
|
| 85 |
+
clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
|
| 86 |
+
clip_embed = target_embed.repeat([n_images, 1])
|
| 87 |
+
|
| 88 |
def cfg_model_fn(x, t):
|
| 89 |
"""The CFG wrapper function."""
|
| 90 |
n = x.shape[0]
|