AlekseyCalvin commited on
Commit
1c2d316
·
verified ·
1 Parent(s): a25e242

Update deforum_engine.py

Browse files
Files changed (1) hide show
  1. deforum_engine.py +44 -36
deforum_engine.py CHANGED
@@ -12,40 +12,44 @@ class DeforumRunner:
12
  self.current_model_config = (None, None, None) # Model, LoRA, Scheduler
13
 
14
  def load_model(self, model_id, lora_id, scheduler_name):
15
- # Prevent reloading if config matches
16
  if (model_id, lora_id, scheduler_name) == self.current_model_config and self.pipe is not None:
17
  return
18
 
19
  print(f"Loading Model: {model_id}")
20
 
21
- # Safe Cleanup
22
- if hasattr(self, 'pipe') and self.pipe is not None:
23
  del self.pipe
24
  self.pipe = None
25
  gc.collect()
26
 
27
- # --- STRATEGY 1: Standard Load ---
28
- try:
29
- self.pipe = AutoPipelineForImage2Image.from_pretrained(
30
- model_id, safety_checker=None, torch_dtype=torch.float32
31
- )
32
- except Exception as e1:
33
- print(f"Strategy 1 failed ({e1}). Trying Legacy (BIN)...")
34
- # --- STRATEGY 2: Legacy .bin files (use_safetensors=False) ---
 
 
 
 
 
 
35
  try:
36
- self.pipe = AutoPipelineForImage2Image.from_pretrained(
37
- model_id, safety_checker=None, torch_dtype=torch.float32, use_safetensors=False
38
- )
39
- except Exception as e2:
40
- print(f"Strategy 2 failed ({e2}). Trying FP16 Variant...")
41
- # --- STRATEGY 3: FP16 Variant ---
42
- try:
43
- self.pipe = AutoPipelineForImage2Image.from_pretrained(
44
- model_id, safety_checker=None, torch_dtype=torch.float32, variant="fp16"
45
- )
46
- except Exception as e3:
47
- print(f"All loading strategies failed for {model_id}.")
48
- raise RuntimeError(f"Could not load model {model_id}. Last error: {e3}")
49
 
50
  # Load LoRA
51
  if lora_id and lora_id.strip().lower() != "none":
@@ -53,7 +57,8 @@ class DeforumRunner:
53
  self.pipe.load_lora_weights(lora_id)
54
  self.pipe.fuse_lora()
55
  print(f"LoRA {lora_id} loaded.")
56
- except Exception as e: print(f"LoRA Load Error: {e}")
 
57
 
58
  # Configure Scheduler
59
  s_config = self.pipe.scheduler.config
@@ -67,7 +72,8 @@ class DeforumRunner:
67
  self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(s_config)
68
 
69
  self.pipe.to(self.device)
70
- # Check if attention slicing is supported (TinyVAE sometimes hates it)
 
71
  try: self.pipe.enable_attention_slicing()
72
  except: pass
73
 
@@ -85,13 +91,15 @@ class DeforumRunner:
85
  model_path, lora_path, scheduler_type):
86
 
87
  self.stop_requested = False
 
 
88
  try:
89
  self.load_model(model_path, lora_path, scheduler_type)
90
  except Exception as e:
91
- yield None, None, None, f"Model Load Error: {str(e)}"
92
  return
93
 
94
- # 1. Parse Schedules
95
  try:
96
  schedules = {
97
  'zoom': utils.parse_weight_schedule(zoom_schedule, max_frames),
@@ -108,7 +116,7 @@ class DeforumRunner:
108
  run_id = uuid.uuid4().hex[:6]
109
  os.makedirs(f"output_{run_id}", exist_ok=True)
110
 
111
- # 2. Init Canvas
112
  if init_image:
113
  prev_img = init_image.resize((width, height), Image.LANCZOS)
114
  else:
@@ -120,7 +128,7 @@ class DeforumRunner:
120
  base_seed = random.randint(0, 2**32 - 1)
121
  print(f"Run {run_id} Started. Base Seed: {base_seed}")
122
 
123
- # 3. Main Loop
124
  for frame_idx in range(max_frames):
125
  if self.stop_requested: break
126
 
@@ -133,7 +141,7 @@ class DeforumRunner:
133
  np.random.seed(current_seed)
134
  torch.manual_seed(current_seed)
135
 
136
- # Warp
137
  warp_args = {
138
  'angle': schedules['angle'][frame_idx],
139
  'zoom': schedules['zoom'][frame_idx],
@@ -142,7 +150,7 @@ class DeforumRunner:
142
  }
143
  warped_img = utils.anim_frame_warp_2d(prev_img, warp_args, border_mode)
144
 
145
- # Cadence Logic
146
  if frame_idx % cadence == 0:
147
  # Prepare
148
  init_for_diffusion = utils.maintain_colors(warped_img, color_anchor, color_coherence)
@@ -171,16 +179,16 @@ class DeforumRunner:
171
 
172
  prev_img = gen_image
173
  else:
174
- # Turbo
175
  gen_image = warped_img
176
  prev_img = warped_img
177
 
178
  generated_frames.append(gen_image)
179
 
180
- # Yield Frame
181
  yield gen_image, None, None, f"Frame {frame_idx+1}/{max_frames}"
182
 
183
- # 4. Save Outputs
184
  video_path = f"output_{run_id}/video.mp4"
185
  self.save_video(generated_frames, video_path, fps)
186
 
@@ -193,7 +201,7 @@ class DeforumRunner:
193
  if not frames: return
194
  try:
195
  w, h = frames[0].size
196
- # Try 'mp4v' first
197
  out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
198
  for f in frames:
199
  out.write(cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR))
 
12
  self.current_model_config = (None, None, None) # Model, LoRA, Scheduler
13
 
14
  def load_model(self, model_id, lora_id, scheduler_name):
15
+ # Prevent reloading if config matches exactly
16
  if (model_id, lora_id, scheduler_name) == self.current_model_config and self.pipe is not None:
17
  return
18
 
19
  print(f"Loading Model: {model_id}")
20
 
21
+ # Safe Cleanup: Check existence before deletion to prevent AttributeError
22
+ if getattr(self, 'pipe', None) is not None:
23
  del self.pipe
24
  self.pipe = None
25
  gc.collect()
26
 
27
+ # Loading Strategies: Try them in order until one works
28
+ strategies = [
29
+ # 1. Standard Load (Safetensors default)
30
+ {"safety_checker": None, "torch_dtype": torch.float32},
31
+ # 2. Legacy Load (PyTorch .bin files)
32
+ {"safety_checker": None, "torch_dtype": torch.float32, "use_safetensors": False},
33
+ # 3. FP16 Variant (Common in optimized repos, but we cast to float32 for CPU safety)
34
+ {"safety_checker": None, "torch_dtype": torch.float32, "variant": "fp16"},
35
+ ]
36
+
37
+ success = False
38
+ last_error = None
39
+
40
+ for kwargs in strategies:
41
  try:
42
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(model_id, **kwargs)
43
+ success = True
44
+ print(f"Loaded successfully with config: {kwargs}")
45
+ break
46
+ except Exception as e:
47
+ last_error = e
48
+ continue
49
+
50
+ if not success:
51
+ print(f"All loading strategies failed for {model_id}.")
52
+ raise RuntimeError(f"Could not load model {model_id}. Error: {last_error}")
 
 
53
 
54
  # Load LoRA
55
  if lora_id and lora_id.strip().lower() != "none":
 
57
  self.pipe.load_lora_weights(lora_id)
58
  self.pipe.fuse_lora()
59
  print(f"LoRA {lora_id} loaded.")
60
+ except Exception as e:
61
+ print(f"LoRA Load Error (continuing without LoRA): {e}")
62
 
63
  # Configure Scheduler
64
  s_config = self.pipe.scheduler.config
 
72
  self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(s_config)
73
 
74
  self.pipe.to(self.device)
75
+
76
+ # TinyVAE / specific models crash with attention slicing, so we wrap it
77
  try: self.pipe.enable_attention_slicing()
78
  except: pass
79
 
 
91
  model_path, lora_path, scheduler_type):
92
 
93
  self.stop_requested = False
94
+
95
+ # Load Model (Propagate errors to UI)
96
  try:
97
  self.load_model(model_path, lora_path, scheduler_type)
98
  except Exception as e:
99
+ yield None, None, None, f"Model Error: {str(e)}"
100
  return
101
 
102
+ # Parse Schedules
103
  try:
104
  schedules = {
105
  'zoom': utils.parse_weight_schedule(zoom_schedule, max_frames),
 
116
  run_id = uuid.uuid4().hex[:6]
117
  os.makedirs(f"output_{run_id}", exist_ok=True)
118
 
119
+ # Init Canvas
120
  if init_image:
121
  prev_img = init_image.resize((width, height), Image.LANCZOS)
122
  else:
 
128
  base_seed = random.randint(0, 2**32 - 1)
129
  print(f"Run {run_id} Started. Base Seed: {base_seed}")
130
 
131
+ # Main Loop
132
  for frame_idx in range(max_frames):
133
  if self.stop_requested: break
134
 
 
141
  np.random.seed(current_seed)
142
  torch.manual_seed(current_seed)
143
 
144
+ # 1. Warp
145
  warp_args = {
146
  'angle': schedules['angle'][frame_idx],
147
  'zoom': schedules['zoom'][frame_idx],
 
150
  }
151
  warped_img = utils.anim_frame_warp_2d(prev_img, warp_args, border_mode)
152
 
153
+ # 2. Cadence Logic
154
  if frame_idx % cadence == 0:
155
  # Prepare
156
  init_for_diffusion = utils.maintain_colors(warped_img, color_anchor, color_coherence)
 
179
 
180
  prev_img = gen_image
181
  else:
182
+ # Turbo (Warp Only)
183
  gen_image = warped_img
184
  prev_img = warped_img
185
 
186
  generated_frames.append(gen_image)
187
 
188
+ # Yield Status
189
  yield gen_image, None, None, f"Frame {frame_idx+1}/{max_frames}"
190
 
191
+ # Finalize
192
  video_path = f"output_{run_id}/video.mp4"
193
  self.save_video(generated_frames, video_path, fps)
194
 
 
201
  if not frames: return
202
  try:
203
  w, h = frames[0].size
204
+ # 'mp4v' is widely supported for CPU/OpenCV
205
  out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
206
  for f in frames:
207
  out.write(cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR))