AlekseyCalvin committed on
Commit
08b7792
·
verified ·
1 Parent(s): aaa45ef

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +77 -0
utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import numexpr
4
+ import re
5
+ import torch
6
+ from PIL import Image
7
+
8
def parse_weight_string(string, max_frames):
    """Parse a keyframe schedule like ``"0:(0.5), 30:(sin(t/10))"`` into a
    per-frame value series.

    Args:
        string: Comma-separated ``frame:(value_or_expression)`` pairs.
            Whitespace is ignored and malformed parts are skipped.
            Expressions may use ``t`` (frame index), ``pi``, ``sin`` and
            ``cos``; they are evaluated with numexpr.
        max_frames: Total number of frames; the returned array has this
            length.

    Returns:
        np.ndarray of shape ``(max_frames,)`` with one value per frame.
        Each keyframe's value applies from its frame up to (but not
        including) the next keyframe; frame 0 defaults to 0 if unspecified.
    """
    string = re.sub(r'\s+', '', str(string))
    total = int(max_frames)

    keyframes = {}
    for part in string.split(','):
        if ':' not in part:
            continue
        f_str, v_str = part.split(':', 1)
        try:
            frame = int(f_str)
        except ValueError:  # non-numeric frame index: skip this part
            continue
        keyframes[frame] = v_str.strip('()')
    if 0 not in keyframes:
        keyframes[0] = "0"  # guarantee a defined value for frame 0

    series = np.zeros(total)
    sorted_k = sorted(keyframes)
    for i, f_start in enumerate(sorted_k):
        if f_start >= total:
            # Keyframe beyond the clip length: nothing left to fill.
            break
        f_end = sorted_k[i + 1] if i + 1 < len(sorted_k) else total
        # Clamp so an out-of-range next keyframe can't index past the array
        # (the old code raised IndexError here).
        f_end = min(f_end, total)
        formula = keyframes[f_start]
        for f in range(max(f_start, 0), f_end):
            try:
                # Plain numeric constants — including negatives, which the
                # old isdigit() check wrongly rejected — need no evaluator.
                series[f] = float(formula)
            except ValueError:
                try:
                    series[f] = float(numexpr.evaluate(
                        formula,
                        local_dict={'t': f, 'pi': np.pi,
                                    'sin': np.sin, 'cos': np.cos}))
                except Exception:
                    # Unevaluable expression: hold the previous frame's value.
                    series[f] = series[f - 1] if f > 0 else 0.0
    return series
32
+
33
def interpolate_prompts(pipe, prompt_dict, max_frames):
    """Blend CLIP text embeddings between prompt keyframes.

    Args:
        pipe: Diffusion pipeline exposing ``tokenizer``, ``text_encoder``
            and ``device``.
        prompt_dict: Mapping of frame index -> prompt string.
        max_frames: Number of frames to produce embeddings for.

    Returns:
        List of ``max_frames`` embedding tensors.  Frames between two
        keyframes receive a linear interpolation of the neighbouring
        keyframe embeddings; frames outside the keyframe range hold the
        nearest keyframe's embedding.
    """
    keys = sorted(prompt_dict)

    # Encode every keyframe prompt exactly once up front.
    embeddings = {}
    for key in keys:
        ids = pipe.tokenizer(
            prompt_dict[key],
            padding="max_length",
            max_length=pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(pipe.device)
        with torch.no_grad():
            embeddings[key] = pipe.text_encoder(ids)[0]

    result = []
    for frame in range(max_frames):
        earlier = [k for k in keys if k <= frame]
        later = [k for k in keys if k > frame]
        if earlier and later:
            lo, hi = earlier[-1], later[0]
            weight = (frame - lo) / (hi - lo)
            # Linear blend between the two neighbouring keyframe embeddings.
            result.append(torch.lerp(embeddings[lo], embeddings[hi], weight))
        elif later:
            # Before the first keyframe: hold the first prompt's embedding.
            result.append(embeddings[later[0]])
        else:
            # At or past the last keyframe: hold the last prompt's embedding.
            result.append(embeddings[earlier[-1]])
    return result
60
+
61
def maintain_colors(img, anchor, mode='LAB'):
    """Shift *img*'s per-channel colour statistics toward *anchor*
    (anti colour-drift for frame-by-frame animation).

    Args:
        img: PIL image for the current frame.
        anchor: PIL reference image whose colour balance should be kept,
            or None to disable matching.
        mode: 'LAB' to match per-channel means in LAB space; 'None' (or any
            unrecognised mode) returns *img* unchanged.

    Returns:
        A colour-matched PIL image, or *img* untouched.
    """
    if mode == 'None' or anchor is None:
        return img
    img_np = np.array(img)
    anc_np = np.array(anchor)
    if mode == 'LAB':
        # Work in float so the per-channel mean shift isn't truncated when
        # written back into a uint8 array (the old in-place assignment
        # floored every pixel toward zero).
        img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB).astype(np.float64)
        anc_lab = cv2.cvtColor(anc_np, cv2.COLOR_RGB2LAB)
        for ch in range(3):
            shift = np.mean(anc_lab[:, :, ch]) - np.mean(img_lab[:, :, ch])
            img_lab[:, :, ch] = np.clip(img_lab[:, :, ch] + shift, 0, 255)
        matched = np.round(img_lab).astype(np.uint8)
        return Image.fromarray(cv2.cvtColor(matched, cv2.COLOR_LAB2RGB))
    return img
70
+
71
def anim_frame_warp_2d(img, args, mode='Reflect'):
    """Apply a 2D zoom/rotate/translate warp to a PIL image.

    Args:
        img: PIL image to warp.
        args: Dict with optional 'angle' (degrees), 'zoom' (scale factor)
            and 'tx'/'ty' (pixel translation); missing keys default to the
            identity transform.
        mode: Border handling — 'Reflect', 'Replicate' or 'Wrap'
            (KeyError on anything else).

    Returns:
        The warped frame as a PIL image of the same size.
    """
    frame = np.array(img)
    height, width = frame.shape[:2]
    center = (width // 2, height // 2)
    matrix = cv2.getRotationMatrix2D(center, args.get('angle', 0),
                                     args.get('zoom', 1))
    # Fold the translation into the affine matrix's offset column.
    matrix[0, 2] += args.get('tx', 0)
    matrix[1, 2] += args.get('ty', 0)
    borders = {
        'Reflect': cv2.BORDER_REFLECT_101,
        'Replicate': cv2.BORDER_REPLICATE,
        'Wrap': cv2.BORDER_WRAP,
    }
    warped = cv2.warpAffine(frame, matrix, (width, height),
                            borderMode=borders[mode])
    return Image.fromarray(warped)