File size: 2,326 Bytes
9d2d5e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45131ff
9d2d5e5
 
45131ff
9d2d5e5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os

import torch
from safetensors.torch import save_file, load_file

def merge_model_components(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path
):
    """
    Merge UNet, VAE, and text encoder weights into a single safetensors file.

    UNet keys are copied as-is; VAE and text-encoder keys are namespaced with
    'vae.' and 'text_encoder.' prefixes (unless already present) so the three
    state dicts cannot shadow one another.

    Args:
        unet_path: Path to the main model/UNet safetensors file.
        vae_path: Path to the VAE safetensors file.
        text_encoder_path: Path to the text encoder/CLIP safetensors file.
        output_path: Path where the merged file will be saved.

    Raises:
        ValueError: If two tensors would end up under the same key, which
            would otherwise silently drop one of them from the checkpoint.
    """
    print("Loading UNet/Model weights...")
    unet_state = load_file(unet_path)

    print("Loading VAE weights...")
    vae_state = load_file(vae_path)

    print("Loading Text Encoder weights...")
    text_encoder_state = load_file(text_encoder_path)

    print("Merging state dictionaries...")
    # Start from a copy of the UNet weights (kept unprefixed), then fold in
    # the other components under their namespace prefixes.
    merged_state = dict(unet_state)
    for component_state, prefix in (
        (vae_state, 'vae.'),
        (text_encoder_state, 'text_encoder.'),
    ):
        for key, value in component_state.items():
            full_key = key if key.startswith(prefix) else f'{prefix}{key}'
            # Fail loudly instead of silently overwriting an existing tensor:
            # a collision here means the merged model would lose weights.
            if full_key in merged_state:
                raise ValueError(
                    f"Duplicate tensor name in merged model: {full_key}"
                )
            merged_state[full_key] = value

    print(f"Total parameters in merged model: {len(merged_state)}")
    print(f"Saving merged model to {output_path}...")

    # Save the merged state dictionary
    save_file(merged_state, output_path)

    print("✅ Merge complete!")
    print(f"File saved to: {output_path}")

    # Report the on-disk size of the merged file (os imported at file level).
    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")


# Example usage: merge a Flux depth-model checkpoint with its VAE and CLIP
# encoder into one file.
if __name__ == "__main__":
    merged_output = "flux1-depth-dev_merged_model.safetensors"
    merge_model_components(
        unet_path="flux1-depth-dev.safetensors",
        vae_path="vae/diffusion_pytorch_model.safetensors",
        text_encoder_path="text_encoder/model.safetensors",
        output_path=merged_output,
    )