File size: 418 Bytes
e14f899
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import torch
from peft import PeftModel
from diffusers_lite.wan.modules.model import WanModel, WanAttentionBlock


def get_no_split_modules(transformer):
    """Return the module classes that must not be split across devices.

    Peels off any PEFT adapter wrapper(s) to reach the underlying base
    model, then maps the model type to the block classes that
    device-placement logic has to keep intact.

    Args:
        transformer: A model instance, possibly wrapped in one or more
            ``PeftModel`` layers.

    Returns:
        A tuple of module classes to treat as indivisible units
        (``(WanAttentionBlock,)`` for ``WanModel``).

    Raises:
        ValueError: If the unwrapped model is not a supported type.
    """
    # PEFT adapters can be nested, so unwrap repeatedly until we hit
    # the base model.
    while isinstance(transformer, PeftModel):
        transformer = transformer.base_model.model
    # Guard clause: reject anything we don't know how to shard.
    if not isinstance(transformer, WanModel):
        raise ValueError(f"Unsupported transformer type: {type(transformer)}")
    return (WanAttentionBlock, )