Amshaker committed on
Commit
cc7566c
·
verified ·
1 Parent(s): fd1d226

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. save_merged_model.py +691 -0
save_merged_model.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE: the original import order is kept on purpose. `from blip3o.model import *`
# must stay ahead of the explicit imports below it so that a name exported by the
# star import can never shadow `os`, `re`, `Path`, `PeftModel`, etc.
from dataclasses import dataclass
from typing import Optional
import torch
from PIL import Image
from transformers import AutoTokenizer
from blip3o.model import *
from peft import PeftModel
import os
from safetensors.torch import load_file
import argparse
from pathlib import Path
import re
12
+
13
@dataclass
class T2IConfig:
    """Settings used to load the model and run text-to-image inference.

    Every field has a default, so ``T2IConfig()`` is valid; callers usually
    override ``base_model_path`` and ``lora_checkpoint_path`` afterwards.
    """

    # Base model path (original model before LoRA training)
    #base_model_path: str = "/proj/cvl/users/x_fahkh2/BLIP3o_SANA/fastvlm-o/blip3o_fast_vlm_unified_v6_60k_blip3o_45k_sharegpt_quad_all_learnable_dynamic_lr_v5_27e_lora16_pretrain_without_sft_ve_learnable_v7_abl4"
    #base_model_path: str = "/proj/cvl/users/x_fahkh2/BLIP3o_SANA/fastvlm-o/blip3o_fast_vlm_unified_v6_60k_blip3o_45k_sharegpt_quad_all_learnable_dynamic_lr_v5_7e_lora16_after_sft_pretrain_ve_learnable_v7_image_edit_512_v3_LLM_Lora"

    base_model_path: str = "/proj/cvl/users/x_fahkh2/BLIP3o_SANA/fastvlm-o/blip3o_fast_vlm_unified_v6_60k_blip3o_45k_sharegpt_quad_all_learnable_dynamic_lr_v5_20e_lora16_after_sft_pretrain_v6_ve_learnable_v7_with_edit_1_5B"
    # dtype passed as ``torch_dtype`` when the model and adapters are loaded.
    dtype: torch.dtype = torch.bfloat16

    # generation config
    # NOTE(review): none of these four values is read by the code in this
    # file; presumably they are consumed elsewhere - confirm before removing.
    scale: int = 0
    seq_len: int = 729
    top_p: float = 0.95
    top_k: int = 1200

    # Set to True to use LoRA checkpoint, False to use base model only
    use_lora_checkpoint: bool = True

    # Directory holding the LoRA adapter / training checkpoints. Declared here
    # (instead of being attached dynamically) so that reading it on a freshly
    # constructed config never raises AttributeError. Kept last so positional
    # construction of the earlier fields is unaffected.
    lora_checkpoint_path: Optional[str] = None
30
+
31
+
32
def find_latest_checkpoint(checkpoint_dir):
    """
    Find the latest checkpoint in the given directory.

    Sub-directories named exactly ``checkpoint-<step>`` are considered; the one
    with the highest step wins. Inside it, ``global_step<step>`` is preferred;
    if that is absent, the highest-numbered ``global_step*`` directory is used.

    Args:
        checkpoint_dir: Path to the directory containing checkpoints

    Returns:
        Path (str) to the latest checkpoint's global_step directory, or None
        if the directory, a checkpoint, or a global_step directory is missing
    """
    root = Path(checkpoint_dir)

    # is_dir() (not exists()) so that a regular file does not crash iterdir().
    if not root.is_dir():
        print(f"⚠️ Warning: Checkpoint directory does not exist: {checkpoint_dir}")
        return None

    # fullmatch: "checkpoint-100-backup" must not be mistaken for step 100.
    step_pattern = re.compile(r"checkpoint-(\d+)")
    candidates = []
    for entry in root.iterdir():
        found = step_pattern.fullmatch(entry.name)
        if found and entry.is_dir():
            candidates.append((int(found.group(1)), entry))

    if not candidates:
        print(f"⚠️ Warning: No checkpoint directories found in {checkpoint_dir}")
        return None

    # The step comes from the directory name; it is never overridden by a
    # hard-coded value, so the function works for any training run.
    latest_step, latest_dir = max(candidates, key=lambda pair: pair[0])

    global_step_dir = latest_dir / f"global_step{latest_step}"
    if not global_step_dir.is_dir():
        # Fallback: accept whatever global_step* directory the checkpoint has.
        inner_pattern = re.compile(r"global_step(\d+)")
        inner = []
        for entry in latest_dir.iterdir():
            found = inner_pattern.fullmatch(entry.name)
            if found and entry.is_dir():
                inner.append((int(found.group(1)), entry))
        if not inner:
            print(f"⚠️ Warning: global_step directory not found at {global_step_dir}")
            return None
        _, global_step_dir = max(inner, key=lambda pair: pair[0])

    print(f"✓ Found latest checkpoint: {latest_dir.name} (step {latest_step})")
    return str(global_step_dir)
75
+
76
+
77
class TextToImageInference:
    """Load a BLIP3o model (optionally with LoRA adapters), export it as a
    standalone merged checkpoint, and generate images from text prompts."""

    # Upper bound on the tensor bytes written into one safetensors shard (5GB).
    MAX_SHARD_BYTES = 5 * 1024 * 1024 * 1024

    def __init__(self, config: "T2IConfig"):
        """Store ``config`` and immediately load model and tokenizer.

        Args:
            config: Run settings. An optional ``device`` attribute is honoured;
                without it the model is placed on ``'cuda:0'``.
        """
        self.config = config
        self.device = getattr(config, "device", None) or 'cuda:0'
        self._load_models()

    # ------------------------------------------------------------------
    # Pure helpers (no model access)
    # ------------------------------------------------------------------
    @staticmethod
    def _strip_module_prefix(key):
        """Remove leading ``module.`` wrapper prefixes from a state-dict key.

        Only the *prefix* is stripped; an inner ``module.`` (for example in
        ``vision_module.weight``) is part of the real parameter name.
        """
        prefix = "module."
        while key.startswith(prefix):
            key = key[len(prefix):]
        return key

    @staticmethod
    def _unwrap_checkpoint(checkpoint):
        """Return the state dict stored inside a checkpoint wrapper, if any."""
        if isinstance(checkpoint, dict):
            for wrapper_key in ("module", "model_state_dict", "state_dict"):
                if wrapper_key in checkpoint:
                    return checkpoint[wrapper_key]
        return checkpoint

    @staticmethod
    def _plan_shards(tensor_bytes, max_shard_bytes):
        """Group keys (in order) into shards of at most ``max_shard_bytes``.

        Args:
            tensor_bytes: Mapping of key -> size in bytes.
            max_shard_bytes: Byte budget per shard.

        Returns:
            List of key lists. A single tensor larger than the budget gets a
            shard of its own; no shard is ever empty.
        """
        shards = []
        current, current_bytes = [], 0
        for key, nbytes in tensor_bytes.items():
            if current and current_bytes + nbytes > max_shard_bytes:
                shards.append(current)
                current, current_bytes = [], 0
            current.append(key)
            current_bytes += nbytes
        if current:
            shards.append(current)
        return shards

    # ------------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------------
    def _read_deepspeed_state_dict(self, checkpoint_path):
        """Load model weights from a DeepSpeed ``global_step`` directory.

        Tries ``mp_rank_00_model_states.pt`` first, then a consolidated
        ``pytorch_model.bin`` in the global_step directory or its parent (the
        parent is where the printed zero_to_fp32 hint writes it).

        Returns:
            The state dict, or None when no loadable file exists.
        """
        zero_script = checkpoint_path.parent / "zero_to_fp32.py"
        if not zero_script.exists():
            print(f" ⚠️ zero_to_fp32.py not found at {zero_script}")
            print(" Looking for consolidated checkpoint...")

        model_states_path = checkpoint_path / "mp_rank_00_model_states.pt"
        consolidated_candidates = [
            checkpoint_path / "pytorch_model.bin",
            checkpoint_path.parent / "pytorch_model.bin",
        ]
        consolidated_path = next((p for p in consolidated_candidates if p.exists()), None)

        if model_states_path.exists():
            print(f" Loading model states from {model_states_path}")
            source = model_states_path
        elif consolidated_path is not None:
            print(f" Loading consolidated checkpoint from {consolidated_path}")
            source = consolidated_path
        else:
            print(f" ⚠️ No consolidated checkpoint found. Please run:")
            print(f" cd {checkpoint_path.parent}")
            print(f" python zero_to_fp32.py {checkpoint_path.name} pytorch_model.bin")
            return None

        # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(source, map_location='cpu')
        deepspeed_state_dict = self._unwrap_checkpoint(checkpoint)
        print(f" ✓ Loaded {len(deepspeed_state_dict)} parameters from DeepSpeed checkpoint")
        return deepspeed_state_dict

    def _write_weights(self, state_dict, output_path):
        """Write ``state_dict`` as safetensors (sharded above MAX_SHARD_BYTES).

        Returns:
            Total size of all tensors in bytes.
        """
        import json
        from safetensors.torch import save_file

        tensor_bytes = {k: v.numel() * v.element_size() for k, v in state_dict.items()}
        total_size = sum(tensor_bytes.values())

        if total_size <= self.MAX_SHARD_BYTES:
            print(f" Model size: {total_size / 1024**3:.2f}GB, saving in single file...")
            save_file(state_dict, str(output_path / "model.safetensors"))
            print(" ✓ Model weights saved")
            return total_size

        print(f" Model size: {total_size / 1024**3:.2f}GB, splitting into shards...")
        # Shards are cut by accumulated bytes, so each one respects the limit.
        shards = self._plan_shards(tensor_bytes, self.MAX_SHARD_BYTES)
        num_shards = len(shards)

        weight_map = {}
        for number, shard_keys in enumerate(shards, start=1):
            shard_filename = f"model-{number:05d}-of-{num_shards:05d}.safetensors"
            save_file({k: state_dict[k] for k in shard_keys}, str(output_path / shard_filename))
            weight_map.update(dict.fromkeys(shard_keys, shard_filename))
            print(f" ✓ Saved shard {number}/{num_shards}: {shard_filename}")

        index = {
            "metadata": {"total_size": total_size},
            "weight_map": weight_map,
        }
        with open(output_path / "model.safetensors.index.json", "w") as f:
            json.dump(index, f, indent=2)
        print(" ✓ Saved model index")
        return total_size

    def _copy_auxiliary_files(self, output_path, state_dict):
        """Copy custom Python files, projector weights and extra configs."""
        import shutil

        base_path = Path(self.config.base_model_path)

        # Copy Python files (modeling, configuration, processing)
        print(" Copying Python files...")
        wanted = ("modeling", "configuration", "processing", "image")
        for py_file in base_path.glob("*.py"):
            if not any(tag in py_file.name.lower() for tag in wanted):
                continue
            try:
                shutil.copy2(py_file, output_path / py_file.name)
                print(f" - {py_file.name}")
            except OSError as e:  # best effort: report and keep going
                print(f" ⚠️ Failed to copy {py_file.name}: {e}")

        # Copy projector files if they exist
        print(" Checking for projector files...")
        search_paths = [base_path, base_path / "merged_model"]
        lora_path = getattr(self.config, 'lora_checkpoint_path', None)
        if lora_path:
            search_paths.append(Path(lora_path))
        search_paths = [p for p in search_paths if p.exists()]

        for bin_file in ("mm_projector.bin", "gen_projector.bin"):
            src = next((p / bin_file for p in search_paths if (p / bin_file).exists()), None)
            if src is not None:
                shutil.copy2(src, output_path / bin_file)
                print(f" - {bin_file}")
            elif any(bin_file.replace('.bin', '') in key for key in state_dict):
                # The projector lives inside the main weights instead.
                print(f" ℹ️ {bin_file} weights are in model state dict")
            else:
                print(f" ⚠️ {bin_file} not found (may not be needed)")

        # Copy config files
        print(" Copying additional config files...")
        for json_file in ("generation_config.json", "preprocessor_config.json"):
            src = base_path / json_file
            if src.exists():
                shutil.copy2(src, output_path / json_file)
                print(f" - {json_file}")

    def save_merged_model(self, output_path: str, deepspeed_checkpoint_path: str = None):
        """
        Merge LoRA weights with base model and save as a standalone model.
        Handles DeepSpeed ZeRO checkpoints if provided.

        The model is moved to CPU while saving and moved back to
        ``self.device`` afterwards, so inference still works after this call.

        Args:
            output_path: Directory where the merged model will be saved
            deepspeed_checkpoint_path: Path to DeepSpeed checkpoint directory (e.g., checkpoint-5719/global_step5719)
        """
        print(f"\n{'='*80}")
        print("SAVING MERGED MODEL")
        print(f"{'='*80}\n")

        output_path = Path(output_path)
        output_path.mkdir(parents=True, exist_ok=True)

        # Step 0: Load DeepSpeed checkpoint if provided
        deepspeed_state_dict = None
        if deepspeed_checkpoint_path is not None:
            print("[0/5] Loading DeepSpeed checkpoint...")
            deepspeed_state_dict = self._read_deepspeed_state_dict(Path(deepspeed_checkpoint_path))

        # Step 1: fold LoRA deltas into the base weights when adapters exist.
        if isinstance(self.model, PeftModel):
            print("[1/5] Merging LoRA weights into base model...")
            merged_model = self.model.merge_and_unload()
            print(" ✓ LoRA weights merged")
            # Move to CPU for saving to avoid CUDA memory issues
            print(" Moving model to CPU for saving...")
        else:
            print("[1/5] Model has no LoRA adapters, saving as-is...")
            merged_model = self.model
        merged_model = merged_model.cpu()

        # Step 2: build the state dict to save.
        print(f"\n[2/5] Preparing model state dict...")
        config = getattr(merged_model, 'config', None)
        if config is None:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(self.config.base_model_path, trust_remote_code=True)

        state_dict = merged_model.state_dict()

        if deepspeed_state_dict is not None:
            print(" Merging with DeepSpeed checkpoint...")
            cleaned_deepspeed_dict = {
                self._strip_module_prefix(key): value
                for key, value in deepspeed_state_dict.items()
            }
            # Matching keys overwrite the LoRA-merged weights; unknown keys are
            # added as new parameters.
            # NOTE(review): if the DeepSpeed keys carry PEFT wrapper names they
            # will not match the merged model's keys - verify on a real run.
            state_dict.update(cleaned_deepspeed_dict)
            print(f" ✓ Merged {len(cleaned_deepspeed_dict)} parameters from DeepSpeed")

        # Remove any PEFT-related keys that might remain
        peft_markers = ('lora_', 'adapter_', 'peft_')
        keys_to_remove = [k for k in state_dict if any(m in k for m in peft_markers)]
        if keys_to_remove:
            print(f" Removing {len(keys_to_remove)} PEFT-related keys...")
            for key in keys_to_remove:
                del state_dict[key]
        print(f" ✓ Final state dict has {len(state_dict)} parameters")

        # Step 3: config
        print(f"\n[3/5] Saving config to: {output_path}")
        config.save_pretrained(str(output_path))
        print(" ✓ Config saved")

        # Step 4: weights
        print(f"\n[4/5] Saving model weights...")
        total_size = self._write_weights(state_dict, output_path)

        # Step 5: tokenizer and side files
        print("\n[5/5] Saving tokenizer and additional files...")
        tokenizer = AutoTokenizer.from_pretrained(self.config.base_model_path)
        tokenizer.save_pretrained(str(output_path))
        print(" ✓ Tokenizer saved")
        self._copy_auxiliary_files(output_path, state_dict)

        print(f"\n{'='*80}")
        print("✅ MODEL SAVED SUCCESSFULLY!")
        print(f"{'='*80}")
        print(f"\nMerged model saved to: {output_path}")
        print(f"Total parameters: {len(state_dict):,}")
        print(f"Model size: {total_size / 1024**3:.2f}GB")

        if deepspeed_state_dict is not None:
            print("\n⚠️ Note: This model includes weights from DeepSpeed checkpoint")

        print("\nYou can now load it with:")
        print(f" from transformers import AutoModelForCausalLM")
        print(f" model = AutoModelForCausalLM.from_pretrained('{output_path}', trust_remote_code=True)")
        print(f"\nOr with your custom class:")
        print(f" model = blip3oFastForInferenceLM.from_pretrained('{output_path}')")
        print(f"\n{'='*80}\n")

        # Module.cpu() moved the parameters in place, so the model that
        # generate_image() uses is now on CPU; put it back on the run device.
        self.model = merged_model.to(self.device)

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def _load_deepspeed_checkpoint(self, model, checkpoint_dir):
        """Load DeepSpeed checkpoint with full model states.

        Args:
            model: Model that receives the weights (loaded non-strictly).
            checkpoint_dir: Run directory containing ``checkpoint-*`` folders.

        Returns:
            ``model``; it is returned unchanged when no weights can be found.
        """
        print(f"Loading DeepSpeed checkpoint from: {checkpoint_dir}")

        # Locate the newest global_step directory instead of assuming a fixed
        # step number.
        global_step_dir = find_latest_checkpoint(checkpoint_dir)
        model_state_path = None
        if global_step_dir is not None:
            model_state_path = os.path.join(global_step_dir, "mp_rank_00_model_states.pt")

        if model_state_path is None or not os.path.exists(model_state_path):
            print(f"⚠️ Warning: Model states not found at {model_state_path}")
            print(" Using zero_to_fp32.py to consolidate checkpoint...")

            import subprocess
            import sys

            zero_script = os.path.join(checkpoint_dir, "zero_to_fp32.py")
            if not os.path.exists(zero_script):
                print(" zero_to_fp32.py not found, skipping full checkpoint loading")
                return model

            output_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
            # sys.executable: run the script with the interpreter (and
            # installed packages) that is running this program.
            result = subprocess.run([sys.executable, zero_script, checkpoint_dir, output_path])
            if result.returncode != 0 or not os.path.exists(output_path):
                print(" zero_to_fp32.py failed, skipping full checkpoint loading")
                return model
            model_state_path = output_path

        print(f"Loading model states from: {model_state_path}")
        # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(model_state_path, map_location="cpu")
        state_dict = self._unwrap_checkpoint(checkpoint)

        # Load non-LoRA weights (DiT, projectors, vision tower, etc.)
        # We'll load these into the base model before applying LoRA
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        print(f"✓ Loaded checkpoint successfully")
        if missing_keys:
            print(f" Missing keys (expected for LoRA): {len(missing_keys)}")
            for key in missing_keys[:5]:
                print(f" - {key}")
            if len(missing_keys) > 5:
                print(f" ... and {len(missing_keys) - 5} more")

        if unexpected_keys:
            print(f" Unexpected keys: {len(unexpected_keys)}")
            for key in unexpected_keys[:5]:
                print(f" - {key}")

        return model

    def _load_models(self):
        """Load model with LoRA adapters and full checkpoint weights.

        Sets ``self.model`` (in eval mode on ``self.device``) and
        ``self.tokenizer``.

        Raises:
            ValueError: LoRA loading is enabled but no checkpoint path is set.
        """
        use_lora = self.config.use_lora_checkpoint
        lora_path = getattr(self.config, "lora_checkpoint_path", None)
        if use_lora and not lora_path:
            raise ValueError(
                "use_lora_checkpoint is True but config.lora_checkpoint_path is not set"
            )

        print("=" * 80)
        if use_lora:
            print(f"Loading base model from: {self.config.base_model_path}")
            print(f"Loading LoRA checkpoint from: {lora_path}")
        else:
            print(f"Loading model without LoRA from: {self.config.base_model_path}")
        print("=" * 80)

        # Step 1: Load base model architecture
        print("\n[1/4] Loading base model architecture...")
        base_model = blip3oFastForInferenceLM.from_pretrained(
            self.config.base_model_path,
            torch_dtype=self.config.dtype,
            device_map="cpu",  # Load to CPU first for checkpoint loading
        )
        print("✓ Base model loaded")

        if use_lora:
            # Step 2: Load full checkpoint weights (DiT, projectors, etc.)
            print("\n[2/4] Loading full checkpoint weights (DiT, projectors, etc.)...")
            base_model = self._load_deepspeed_checkpoint(base_model, lora_path)

            # Step 3: Apply LoRA adapters on top
            print("\n[3/4] Applying LoRA adapters...")
            self.model = PeftModel.from_pretrained(
                base_model,
                lora_path,
                torch_dtype=self.config.dtype,
            )
            print("✓ LoRA adapters applied successfully!")

            # Print parameter info
            lora_params = sum(p.numel() for n, p in self.model.named_parameters() if "lora" in n.lower())
            total_params = sum(p.numel() for p in self.model.parameters())
            print(f" LoRA parameters: {lora_params:,} ({100 * lora_params / total_params:.2f}%)")
        else:
            self.model = base_model

        # Step 4: Move to device and set to eval mode
        print("\n[4/4] Moving model to device and setting eval mode...")
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"✓ Model on {self.device}")

        # Load tokenizer from checkpoint (has all the special tokens)
        tokenizer_path = lora_path if use_lora else self.config.base_model_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"✓ Tokenizer loaded from: {tokenizer_path}")
        print("=" * 80)
        print("\n✅ Model loading complete! Ready for inference.\n")

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def generate_image(self, prompt, steps=30) -> "Image.Image":
        """Generate image from text prompt.

        Args:
            prompt: Caption to render.
            steps: Accepted for API compatibility; it is currently NOT
                forwarded to ``self.model.generate_image``.

        Returns:
            The first element of the model's ``generate_image`` output.
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Please generate image based on the following caption: {prompt}"},
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        batch_messages = [input_text]

        inputs = self.tokenizer(
            batch_messages,
            return_tensors="pt",
            padding=True,
            truncation=True,
            padding_side="left",
        )

        with torch.no_grad():
            output_image = self.model.generate_image(
                inputs.input_ids.to(self.device),
                inputs.attention_mask.to(self.device),
            )

        return output_image[0]
484
+
485
+
486
def consolidate_checkpoint_first(checkpoint_dir):
    """
    Consolidate DeepSpeed checkpoint before loading.
    Run this once if you get errors loading the checkpoint.

    Runs ``<checkpoint_dir>/zero_to_fp32.py`` with ``checkpoint_dir`` as input
    and ``<checkpoint_dir>/pytorch_model.bin`` as output.

    Args:
        checkpoint_dir: Directory that contains ``zero_to_fp32.py``.

    Returns:
        True if the script exited with status 0, False if the script is
        missing or exited with a non-zero status.
    """
    import subprocess
    import sys

    print("=" * 80)
    print("Consolidating DeepSpeed checkpoint...")
    print("=" * 80)

    zero_script = os.path.join(checkpoint_dir, "zero_to_fp32.py")
    output_path = os.path.join(checkpoint_dir, "pytorch_model.bin")

    if not os.path.exists(zero_script):
        print(f"❌ zero_to_fp32.py not found at {zero_script}")
        return False

    print(f"Input: {checkpoint_dir}")
    print(f"Output: {output_path}")

    # sys.executable guarantees the script runs under the same interpreter
    # (and therefore the same installed packages) as this process; a bare
    # "python" may resolve to a different environment or not exist at all.
    result = subprocess.run(
        [sys.executable, zero_script, checkpoint_dir, output_path],
        capture_output=True,
        text=True,
    )

    if result.returncode != 0:
        print(f"❌ Error consolidating checkpoint:")
        print(result.stderr)
        return False

    print(f"✓ Checkpoint consolidated successfully to {output_path}")
    return True
520
+
521
+
522
def main():
    """Merge the latest checkpoint into a standalone model, then optionally
    generate a set of test images and a grid overview.

    Command-line arguments:
        --checkpoint_dir  Run directory containing ``checkpoint-*`` folders
                          (also used as base model and LoRA path).
        --output_dir      Target directory for the merged model.
        --skip_inference  Only save the merged model.

    Raises:
        SystemExit: with status 1 when no valid checkpoint is found.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Merge BLIP3o LoRA model with base model")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Path to the checkpoint directory (e.g., blip3o_fast_vlm_unified_v6_60k_...)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output directory for merged model (default: {checkpoint_dir}/final_merged_model_{step})",
    )
    parser.add_argument(
        "--skip_inference",
        action="store_true",
        help="Skip image generation and only save merged model",
    )
    args = parser.parse_args()
    checkpoint_dir = args.checkpoint_dir

    # Find the latest checkpoint
    print(f"\n{'='*80}")
    print(f"Searching for latest checkpoint in: {checkpoint_dir}")
    print(f"{'='*80}\n")

    latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is None:
        print("❌ Error: Could not find any valid checkpoints")
        # Non-zero status so shell scripts / job schedulers see the failure.
        raise SystemExit(1)

    # Extract step number from checkpoint path
    step_match = re.search(r"global_step(\d+)", latest_checkpoint)
    step_num = step_match.group(1) if step_match else "unknown"

    # Set output directory
    if args.output_dir is None:
        output_dir = os.path.join(checkpoint_dir, f"final_merged_model_{step_num}")
    else:
        output_dir = args.output_dir
    print(f"Output directory: {output_dir}\n")

    # The run directory serves as both the base model and the LoRA location.
    config = T2IConfig()
    config.base_model_path = checkpoint_dir
    config.lora_checkpoint_path = checkpoint_dir

    # Initialize inference
    inference = TextToImageInference(config)

    # Save merged model
    inference.save_merged_model(output_dir, deepspeed_checkpoint_path=latest_checkpoint)

    if args.skip_inference:
        print("\n✅ Merged model saved. Skipping inference as requested.")
        return

    # Generate test images
    prompts = [
        'A surreal scene on a lunar-like surface, where a brown horse is standing on the back of an astronaut...',
        "a photo of four cute cats",
        "a photo of five cute dogs",
        "a photo of a horse",
        "a photo of a tiger",
        "a photo of a wolf",
        "a beautiful mountain landscape",
    ]
    inference_steps = [20]
    image_output_dir = f"Fast-SANA-LoRA-Full-{step_num}"
    os.makedirs(image_output_dir, exist_ok=True)

    all_images = []

    print("\n" + "=" * 80)
    print("Starting image generation...")
    print("=" * 80)

    for idx, prompt in enumerate(prompts):
        print(f"\n[Prompt {idx+1}/{len(prompts)}] {prompt[:60]}...")
        row_images = []
        for inf in inference_steps:
            print(f" Generating with {inf} steps...", end=" ")
            image_sana = inference.generate_image(prompt, steps=inf)
            save_path = os.path.join(image_output_dir, f"prompt{idx:02d}_steps{inf}.png")
            image_sana.save(save_path)
            print(f"✓ Saved")
            row_images.append(image_sana)
        all_images.append(row_images)

    # Create grid visualization
    print("\n" + "=" * 80)
    print("Creating grid visualization...")
    print("=" * 80)

    import matplotlib.pyplot as plt

    # squeeze=False always yields a 2-D array of axes, so indexing is uniform
    # for any number of prompts / step settings (including 1 x 1).
    fig, axes = plt.subplots(
        len(prompts), len(inference_steps), figsize=(15, 10), squeeze=False
    )
    for i, row_images in enumerate(all_images):
        for j, img in enumerate(row_images):
            ax = axes[i][j]
            ax.imshow(img)
            ax.axis("off")
            if i == 0:
                ax.set_title(f"{inference_steps[j]} steps", fontsize=10)

    plt.tight_layout()
    grid_path = os.path.join(image_output_dir, "grid_results.png")
    plt.savefig(grid_path, dpi=150, bbox_inches='tight')
    print(f"✓ Grid saved: {grid_path}")
    plt.close()

    print("\n✅ All done! Check the '{}' folder for results.".format(image_output_dir))
648
+
649
+
650
def compare_base_vs_lora():
    """Compare base model vs LoRA-trained model outputs.

    Loads the model twice (without and with LoRA), renders the same prompts
    with each, and writes the images to ``comparison_base/`` and
    ``comparison_lora/``.
    """
    test_prompts = [
        "a photo of a cute cat",
        "a beautiful mountain landscape",
        "a tiger in the forest",
    ]

    num_inference_steps = 20

    for model_type in ["base", "lora"]:
        config = T2IConfig()
        config.use_lora_checkpoint = (model_type == "lora")
        # The loader reads config.lora_checkpoint_path; when nothing is set,
        # fall back to the base model directory (the same convention main()
        # uses) instead of failing on a missing attribute.
        if not getattr(config, "lora_checkpoint_path", None):
            config.lora_checkpoint_path = config.base_model_path

        output_dir = f"comparison_{model_type}"
        os.makedirs(output_dir, exist_ok=True)

        print(f"\n{'='*80}")
        print(f"Generating with {model_type.upper()} model")
        print(f"{'='*80}")

        inference = TextToImageInference(config)

        for idx, prompt in enumerate(test_prompts):
            print(f"\n[{idx+1}/{len(test_prompts)}] {prompt}")
            # generate_image's keyword is `steps`; passing
            # `num_inference_steps=` raises TypeError.
            image = inference.generate_image(prompt, steps=num_inference_steps)
            save_path = os.path.join(output_dir, f"{model_type}_prompt{idx:02d}.png")
            image.save(save_path)
            print(f"✓ Saved: {save_path}")

        # Clean up to free memory
        del inference
        torch.cuda.empty_cache()

    print("\n✅ Comparison complete!")
    print("Check 'comparison_base' and 'comparison_lora' folders")
688
+
689
+
690
# Script entry point: run the CLI only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()