AlekseyCalvin commited on
Commit
d940e49
·
verified ·
1 Parent(s): 93ffa36

Update deforum_engine.py

Browse files
Files changed (1) hide show
  1. deforum_engine.py +162 -63
deforum_engine.py CHANGED
@@ -1,90 +1,189 @@
1
- import torch, os, uuid, zipfile, cv2, gc
2
  import numpy as np
3
- from diffusers import AutoPipelineForImage2Image, LCMScheduler, EulerAncestralDiscreteScheduler
4
  from PIL import Image
5
  import utils
6
 
7
  class DeforumRunner:
8
  def __init__(self, device="cpu"):
9
- self.device, self.pipe, self.stop_requested = device, None, False
10
- self.current_cfg = (None, None, None)
 
 
 
 
 
 
 
11
 
12
- def load_model(self, m_id, l_id, s_name):
13
- if (m_id, l_id, s_name) == self.current_cfg and self.pipe: return
14
- print(f"Loading {m_id}...")
15
  if self.pipe: del self.pipe; gc.collect()
16
-
17
- # Resilient loading for non-standard model repos
18
  try:
19
- self.pipe = AutoPipelineForImage2Image.from_pretrained(m_id, safety_checker=None, torch_dtype=torch.float32)
 
 
20
  except:
21
- self.pipe = AutoPipelineForImage2Image.from_pretrained(m_id, safety_checker=None, torch_dtype=torch.float32, use_safetensors=False)
 
 
 
22
 
23
- if l_id and l_id != "None":
24
- self.pipe.load_lora_weights(l_id); self.pipe.fuse_lora()
25
-
26
- self.pipe.scheduler = (LCMScheduler if s_name=="LCM" else EulerAncestralDiscreteScheduler).from_config(self.pipe.scheduler.config)
27
- self.pipe.to(self.device); self.pipe.enable_attention_slicing()
28
- self.current_cfg = (m_id, l_id, s_name)
29
 
30
- def stop(self): self.stop_requested = True
 
 
 
 
 
 
 
 
 
31
 
32
- def render(self, prompt_dict, neg, max_f, w, h, z_s, a_s, tx_s, ty_s, str_s, noi_s, steps, color_m, border_m, blend_f, init_upload, m_id, l_id, s_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  self.stop_requested = False
34
- self.load_model(m_id, l_id, s_name)
35
 
36
- # 1. Schedules & Prompt Interpolation
37
- s = {k: utils.parse_weight_string(v, max_f) for k, v in zip(['z','a','tx','ty','st','no'], [z_s, a_s, tx_s, ty_s, str_s, noi_s])}
38
- prompt_embeddings = utils.interpolate_prompts(self.pipe, prompt_dict, max_f)
39
- neg_tokens = self.pipe.tokenizer(neg, padding="max_length", max_length=self.pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt").input_ids.to(self.device)
40
- neg_emb = self.pipe.text_encoder(neg_tokens)[0]
41
-
42
  run_id = uuid.uuid4().hex[:6]
43
  os.makedirs(f"out_{run_id}", exist_ok=True)
44
- prev_img = init_upload.resize((w, h), Image.LANCZOS) if init_upload else None
45
- color_anchor, frames = None, []
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- for i in range(max_f):
 
48
  if self.stop_requested: break
49
 
50
- # --- Deforum Recursive Step ---
51
- # 1. Warp
52
- warped = utils.anim_frame_warp_2d(prev_img, {'angle':s['a'][i],'zoom':s['z'][i],'tx':s['tx'][i],'ty':s['ty'][i]}, border_m) if prev_img else Image.new("RGB",(w,h),(128,128,128))
 
 
 
 
53
 
54
- # 2. Preparation (Color + Noise)
55
- init_diff = utils.maintain_colors(warped, color_anchor, color_m) if color_anchor else warped
56
 
57
- # 3. Prevent 0-step crash: Strength must be high enough relative to steps
58
- strength = s['st'][i] if prev_img else 0.99
59
- if (steps * strength) < 1.0: strength = 1.1 / steps # Force at least 1 step
60
 
61
- # 4. Generate with pre-blended prompt embeddings
62
- gen_img = self.pipe(
63
- prompt_embeds=prompt_embeddings[i],
64
- negative_prompt_embeds=neg_emb,
65
- image=init_diff,
66
- num_inference_steps=steps,
67
- strength=strength,
68
- guidance_scale=1.5, # Keep low for LCM to avoid "Colorful Noise"
69
- width=w, height=h
70
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # 5. Stability: Temporal Blend
73
- if blend_f > 0 and prev_img:
74
- gen_img = Image.blend(gen_img, prev_img, blend_f)
75
 
76
- prev_img = gen_img
77
- if not color_anchor: color_anchor = gen_img
78
- frames.append(gen_img)
79
- yield gen_img, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- # (Finalization code same as before...)
82
- vid_p = f"out_{run_id}/video.mp4"
83
- v_out = cv2.VideoWriter(vid_p, cv2.VideoWriter_fourcc(*'mp4v'), 12, (w, h))
84
- for f in frames: v_out.write(cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR))
85
- v_out.release()
86
- z_p = f"out_{run_id}/frames.zip"
87
- with zipfile.ZipFile(z_p, 'w') as zf:
88
  for j, f in enumerate(frames):
89
- name = f"f_{j:05d}.png"; f.save(name); zf.write(name); os.remove(name)
90
- yield frames[-1], vid_p, z_p
 
 
 
1
+ import torch, os, uuid, zipfile, cv2, gc, random
2
  import numpy as np
3
+ from diffusers import AutoPipelineForImage2Image, LCMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler
4
  from PIL import Image
5
  import utils
6
 
7
  class DeforumRunner:
8
  def __init__(self, device="cpu"):
9
+ self.device = device
10
+ self.pipe = None
11
+ self.stop_requested = False
12
+ self.current_config = (None, None, None) # Model, LoRA, Scheduler
13
+
14
+ def load_model(self, model_id, lora_id, scheduler_name):
15
+ # Avoid reloading if not changed
16
+ if (model_id, lora_id, scheduler_name) == self.current_config and self.pipe is not None:
17
+ return
18
 
19
+ print(f"Loading Model: {model_id} with {scheduler_name}")
 
 
20
  if self.pipe: del self.pipe; gc.collect()
21
+
 
22
  try:
23
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(
24
+ model_id, safety_checker=None, torch_dtype=torch.float32
25
+ )
26
  except:
27
+ # Fallback for non-safetensor repos
28
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(
29
+ model_id, safety_checker=None, torch_dtype=torch.float32, use_safetensors=False
30
+ )
31
 
32
+ # Load LoRA
33
+ if lora_id and lora_id != "None":
34
+ try:
35
+ self.pipe.load_lora_weights(lora_id)
36
+ self.pipe.fuse_lora()
37
+ except Exception as e: print(f"LoRA Load Fail: {e}")
38
 
39
+ # Set Scheduler
40
+ s_config = self.pipe.scheduler.config
41
+ if scheduler_name == "LCM":
42
+ self.pipe.scheduler = LCMScheduler.from_config(s_config)
43
+ elif scheduler_name == "Euler A":
44
+ self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(s_config)
45
+ elif scheduler_name == "DDIM":
46
+ self.pipe.scheduler = DDIMScheduler.from_config(s_config)
47
+ elif scheduler_name == "DPM++ 2M":
48
+ self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(s_config)
49
 
50
+ self.pipe.to(self.device)
51
+ self.pipe.enable_attention_slicing()
52
+ self.current_config = (model_id, lora_id, scheduler_name)
53
+
54
+ def stop(self):
55
+ self.stop_requested = True
56
+
57
+ def render(self,
58
+ prompts, neg_prompt, max_frames, width, height,
59
+ zoom_s, angle_s, tx_s, ty_s, strength_s, noise_s,
60
+ fps, steps, cadence,
61
+ color_mode, border_mode, seed_behavior, init_image,
62
+ model_id, lora_id, scheduler_name):
63
+
64
  self.stop_requested = False
65
+ self.load_model(model_id, lora_id, scheduler_name)
66
 
67
+ # 1. Parse Schedules
68
+ keys = ['z', 'a', 'tx', 'ty', 'str', 'noi']
69
+ inputs = [zoom_s, angle_s, tx_s, ty_s, strength_s, noise_s]
70
+ sched = {k: utils.parse_weight_string(v, max_frames) for k, v in zip(keys, inputs)}
71
+
72
+ # 2. Setup Run
73
  run_id = uuid.uuid4().hex[:6]
74
  os.makedirs(f"out_{run_id}", exist_ok=True)
75
+
76
+ # Init Image & Canvas
77
+ if init_image:
78
+ prev_img = init_image.resize((width, height), Image.LANCZOS)
79
+ else:
80
+ # Start with neutral grey noise if no init
81
+ prev_img = Image.fromarray(np.random.randint(100, 150, (height, width, 3), dtype=np.uint8))
82
+
83
+ color_anchor = prev_img.copy()
84
+ frames = []
85
+
86
+ # Seed Setup
87
+ current_seed = random.randint(0, 2**32 - 1)
88
+
89
+ print(f"Starting Run {run_id}. Cadence: {cadence}")
90
 
91
+ # 3. Main Loop
92
+ for i in range(max_frames):
93
  if self.stop_requested: break
94
 
95
+ # Update Seed
96
+ if seed_behavior == "iter": current_seed += 1
97
+ elif seed_behavior == "random": current_seed = random.randint(0, 2**32 - 1)
98
+ # else fixed
99
+
100
+ # Get Current Params
101
+ args = {'angle': sched['a'][i], 'zoom': sched['z'][i], 'tx': sched['tx'][i], 'ty': sched['ty'][i]}
102
 
103
+ # --- Deforum Logic ---
 
104
 
105
+ # 1. WARP (Happens every frame)
106
+ # Warp the *previous* result
107
+ warped_img = utils.anim_frame_warp_2d(prev_img, args, border_mode)
108
 
109
+ # 2. DECIDE: Generate or Skip (Cadence)
110
+ # If Cadence=1, we generate every frame.
111
+ # If Cadence=2, we generate on 0, 2, 4... and just warp on 1, 3, 5
112
+
113
+ if i % cadence == 0:
114
+ # --- GENERATION STEP ---
115
+
116
+ # A. Color Match (Pre-Diffusion)
117
+ init_for_diff = utils.maintain_colors(warped_img, color_anchor, color_mode)
118
+
119
+ # B. Add Noise
120
+ init_for_diff = utils.add_noise(init_for_diff, sched['noi'][i])
121
+
122
+ # C. Prompt
123
+ # Find latest prompt key <= current frame
124
+ p_keys = sorted([k for k in prompts.keys() if k <= i])
125
+ curr_prompt = prompts[p_keys[-1]]
126
+
127
+ # D. Strength Logic
128
+ # Prevent 0-step crash
129
+ curr_strength = sched['str'][i]
130
+ if (steps * curr_strength) < 1.0: curr_strength = 1.1 / steps
131
+
132
+ # E. Diffuse
133
+ generator = torch.Generator(device=self.device).manual_seed(current_seed)
134
+
135
+ # Using 1.5 - 2.0 guidance for LCM/SDXS to prevent frying
136
+ cfg = 1.5 if "LCM" in scheduler_name else 7.5
137
+
138
+ gen_image = self.pipe(
139
+ prompt=curr_prompt,
140
+ negative_prompt=neg_prompt,
141
+ image=init_for_diff,
142
+ num_inference_steps=steps,
143
+ strength=curr_strength,
144
+ guidance_scale=cfg,
145
+ width=width, height=height,
146
+ generator=generator
147
+ ).images[0]
148
+
149
+ # F. Color Match (Post-Diffusion stability)
150
+ if color_mode != 'None':
151
+ gen_image = utils.maintain_colors(gen_image, color_anchor, color_mode)
152
+
153
+ else:
154
+ # --- CADENCE STEP (Turbo) ---
155
+ # Just use the warped image. This is the "In-between" frame.
156
+ # In true Deforum, we might blend this with the *next* generation,
157
+ # but for real-time/CPU, returning the warped frame is the standard "Turbo" behavior.
158
+ gen_image = warped_img
159
 
160
+ # Update State
161
+ prev_img = gen_image
162
+ frames.append(gen_image)
163
 
164
+ yield gen_image, None, None
165
+
166
+ # 4. Finalize
167
+ vid_path = f"out_{run_id}/video.mp4"
168
+ self.save_video(frames, vid_path, fps)
169
+ zip_path = f"out_{run_id}/frames.zip"
170
+ self.save_zip(frames, zip_path)
171
+
172
+ yield frames[-1], vid_path, zip_path
173
+
174
+ def save_video(self, frames, path, fps):
175
+ if not frames: return
176
+ w, h = frames[0].size
177
+ out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
178
+ for f in frames:
179
+ out.write(cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR))
180
+ out.release()
181
 
182
+ def save_zip(self, frames, path):
183
+ import io
184
+ with zipfile.ZipFile(path, 'w') as zf:
 
 
 
 
185
  for j, f in enumerate(frames):
186
+ name = f"f_{j:05d}.png"
187
+ buf = io.BytesIO()
188
+ f.save(buf, format="PNG")
189
+ zf.writestr(name, buf.getvalue())