File size: 5,413 Bytes
402d0ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import numpy as np
import cv2
import numexpr
import re
from PIL import Image, ImageOps

# --- Math & Schedule Parsing ---

def get_inbetweens(key_frames, max_frames, integer=False):
    """Linearly interpolate a per-frame series from sparse keyframes.

    Args:
        key_frames: Mapping ``{frame_index: value}``. Keys need not be sorted and
            may lie outside ``[0, max_frames)``; such keys still determine the
            slope of the segments they bound, but only in-range frames are written.
        max_frames: Length of the returned series.
        integer: If True, the result is converted with ``astype(int)``
            (truncation toward zero).

    Returns:
        1-D numpy array of length ``max_frames``:
        - frames before the first key hold the first key's value;
        - frames between two consecutive keys are linearly interpolated
          (start key inclusive, end key exclusive);
        - frames from the last key onward hold the last key's value.
        An empty ``key_frames`` yields an all-zero series.
    """
    if not key_frames:
        series = np.zeros(max_frames)
        return series.astype(int) if integer else series

    items = sorted(key_frames.items())
    series = np.full(max_frames, items[0][1], dtype=float)

    for (start, v_start), (end, v_end) in zip(items, items[1:]):
        # Only write the part of the segment that lands inside the series, but
        # interpolate over the segment's FULL length so clipping a segment
        # (key < 0 or key > max_frames) does not distort its slope.
        lo, hi = max(start, 0), min(end, max_frames)
        if lo >= hi:
            continue
        offsets = np.arange(lo, hi) - start
        # Multiply before dividing so integer-valued ramps stay exact.
        series[lo:hi] = v_start + (v_end - v_start) * offsets / (end - start)

    # Hold the final value from the last key to the end of the series.
    last_key, last_val = items[-1]
    if last_key < max_frames:
        series[max(last_key, 0):] = last_val

    return series.astype(int) if integer else series

def _split_top_level(text, sep=","):
    """Split ``text`` on ``sep`` characters that are not nested inside parentheses."""
    parts, current, depth = [], [], 0
    for ch in text:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(depth - 1, 0)
        elif ch == sep and depth == 0:
            parts.append("".join(current))
            current = []
            continue
        current.append(ch)
    parts.append("".join(current))
    return parts


def _strip_outer_parens(text):
    """Remove parenthesis pairs that enclose the whole of ``text``.

    ``"(sin(t/10))"`` -> ``"sin(t/10)"`` and ``"((0.5))"`` -> ``"0.5"``, while
    ``"(a)+(b)"`` is returned unchanged because its first ``(`` closes early.
    """
    while len(text) >= 2 and text[0] == "(" and text[-1] == ")":
        depth = 0
        for ch in text[:-1]:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            if depth == 0:
                # The opening "(" closed before the final ")": not an outer pair.
                return text
        text = text[1:-1]
    return text


def parse_weight_string(string, max_frames):
    """
    Parse a keyframe schedule such as '0:(0.5), 10:(1.0)' or '0:(sin(t/10))'.

    Each comma-separated ``frame:(formula)`` entry starts a segment that runs
    until the next keyframe (or ``max_frames``). Every formula is evaluated with
    numexpr, where ``t`` is the frame index, ``s`` is the segment's starting
    keyframe and ``pi`` is ``np.pi``; numexpr itself supplies functions such as
    ``sin``/``cos``/``tan``. Commas nested inside parentheses do not split
    entries, and only parentheses enclosing the whole formula are removed.

    If a formula cannot be evaluated, it is parsed as a plain float; if that
    also fails, the segment holds the value of the frame just before it
    (0.0 at frame 0). Entries that are not ``<int>:<formula>`` are ignored.
    If frame 0 has no entry, it defaults to ``"0"``. Keyframes outside
    ``[0, max_frames)`` are clipped rather than indexed.

    Security: formulas are handed to ``numexpr.evaluate``. Older numexpr
    releases were reported to allow code execution through ``evaluate``, so
    treat schedule strings as trusted input only.

    Returns a numpy float array of length ``max_frames``.
    """
    keyframes = {}
    for part in _split_top_level(re.sub(r'\s+', '', string)):
        frame_str, sep, formula = part.partition(':')
        if not sep:
            continue
        try:
            frame = int(frame_str)
        except ValueError:
            continue  # not a "frame:(value)" entry; skip it leniently
        keyframes[frame] = _strip_outer_parens(formula)

    # Frame 0 must always be defined so the series has a starting value.
    keyframes.setdefault(0, "0")

    series = np.zeros(max_frames)
    starts = sorted(keyframes)
    ends = starts[1:] + [max_frames]
    for f_start, f_end in zip(starts, ends):
        # Clamp to [0, max_frames): out-of-range keys must not index the array
        # (negative indices would wrap, indices >= max_frames would raise).
        lo, hi = max(f_start, 0), min(f_end, max_frames)
        if lo >= hi:
            continue
        formula = keyframes[f_start]
        try:
            # One vectorised evaluation per segment: t is the array of frames.
            values = numexpr.evaluate(
                formula,
                local_dict={'t': np.arange(lo, hi), 's': f_start, 'pi': np.pi},
            )
            series[lo:hi] = np.broadcast_to(np.asarray(values, dtype=float), (hi - lo,))
        except Exception:
            # numexpr raises a variety of exception types for bad formulas.
            try:
                series[lo:hi] = float(formula)
            except ValueError:
                series[lo:hi] = series[lo - 1] if lo > 0 else 0.0

    return series

# --- Image Processing ---

def _to_lab_uint8(img):
    """Convert an RGB image (PIL image or array-like) to an 8-bit OpenCV LAB array."""
    if isinstance(img, Image.Image) and img.mode != "RGB":
        # L / RGBA / P images would otherwise reach cvtColor with the wrong
        # channel count; convert("RGB") drops any alpha channel.
        img = img.convert("RGB")
    return cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_RGB2LAB)


def maintain_colors(prev_img, color_match_sample):
    """
    Shift the colours of ``prev_img`` toward those of ``color_match_sample``.

    Both images are converted to LAB, and ``prev_img``'s L, a and b channels are
    each offset so that their means equal the sample's per-channel means. This
    is a mean shift only: contrast and spread are not matched. Use it to hold
    an animation's colours near an anchor image (e.g. frame 0) and limit colour
    drift.

    Args:
        prev_img: Image to adjust (PIL image or RGB array-like).
        color_match_sample: Reference image whose mean LAB colour is the target.

    Returns:
        A new RGB PIL image with the same size as ``prev_img``.
    """
    prev_lab = _to_lab_uint8(prev_img)
    sample_lab = _to_lab_uint8(color_match_sample)

    # Per-channel (L, a, b) means, accumulated in float64.
    prev_mean = prev_lab.reshape(-1, 3).mean(axis=0, dtype=np.float64)
    sample_mean = sample_lab.reshape(-1, 3).mean(axis=0, dtype=np.float64)

    shifted = prev_lab + (sample_mean - prev_mean)
    # Round rather than truncate so the cast does not bias every channel
    # downward by ~0.5 level; clip to the valid 8-bit LAB range.
    matched_lab = np.clip(np.rint(shifted), 0, 255).astype(np.uint8)
    return Image.fromarray(cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB))

def add_noise(img, noise_amt, *, rng=None):
    """Add zero-mean Gaussian noise to an 8-bit image.

    The noise gives the diffusion model texture to latch onto.

    Args:
        img: PIL image (or array-like) with 0-255 pixel values.
        noise_amt: Noise standard deviation as a fraction of the full 0-255
            range (sigma = ``noise_amt * 255``). Values <= 0 disable noise.
        rng: Optional ``numpy.random.Generator`` for reproducible noise. When
            omitted, the global ``np.random`` state is used, as before.

    Returns:
        ``img`` itself when ``noise_amt <= 0``; otherwise a new PIL image with
        the noisy pixels rounded and clipped to 0-255.
    """
    if noise_amt <= 0:
        return img

    pixels = np.array(img).astype(np.float32)
    sigma = noise_amt * 255
    if rng is None:
        noise = np.random.normal(0, sigma, pixels.shape)
    else:
        noise = rng.normal(0, sigma, pixels.shape)

    # Round instead of truncating: truncation would darken the image by
    # ~0.5 level on average every time noise is applied.
    noisy = pixels + noise.astype(np.float32)
    noisy_img = np.clip(np.rint(noisy), 0, 255).astype(np.uint8)
    return Image.fromarray(noisy_img)

def anim_frame_warp_2d(prev_img_pil, args_dict):
    """
    Apply one frame of 2D camera motion (rotate, zoom, pan) to an image.

    Args:
        prev_img_pil: Source image (anything ``np.array`` accepts, e.g. a PIL image).
        args_dict: Per-frame motion parameters. All keys are optional:
            'angle' (degrees, as interpreted by cv2.getRotationMatrix2D; default 0),
            'zoom' (scale factor; default 1.0), and
            'translation_x' / 'translation_y' (pixel offsets added to the
            affine matrix's translation column; default 0).

    Returns:
        A new PIL image the same size as the input. Pixels that map outside the
        source are filled by mirroring the border (BORDER_REFLECT_101).
    """
    frame = np.array(prev_img_pil)
    h, w = frame.shape[:2]

    # Rotation and zoom pivot about the (integer) image centre.
    matrix = cv2.getRotationMatrix2D(
        (w // 2, h // 2),
        args_dict.get('angle', 0),
        args_dict.get('zoom', 1.0),
    )

    # Pan: add the x/y offsets to the last column of the 2x3 affine matrix.
    matrix[:, 2] += (
        args_dict.get('translation_x', 0),
        args_dict.get('translation_y', 0),
    )

    warped = cv2.warpAffine(frame, matrix, (w, h), borderMode=cv2.BORDER_REFLECT_101)
    return Image.fromarray(warped)