File size: 928 Bytes
89280a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import warnings

from .base import LayerDecomposer
from .qwen import QwenDecomposer

# Registry: lower-case model-type key (or class) -> Decomposer class.
# Keys are meant to line up with the model config's type or with the
# model-name matching logic.
DECOMPOSER_REGISTRY = dict(
    qwen2=QwenDecomposer,
    qwen3=QwenDecomposer,
    # Generic fallback for other Qwen variants.
    qwen=QwenDecomposer,
    # Llama usually has an identical layer structure (PreNorm, RMS, MLP),
    # so the Qwen decomposer is reused.
    llama=QwenDecomposer,
)

def get_decomposer(model_name_or_obj) -> LayerDecomposer:
    """
    Factory to return the appropriate decomposer for a model.

    Resolution is driven by ``DECOMPOSER_REGISTRY``, so registering a new
    model type there is enough to make it resolvable here.

    Args:
        model_name_or_obj: A model name/path string, matched case-insensitively
            by substring against the registry keys, or a model object. For an
            object, ``config.model_type`` is tried first when present (assumes
            a Hugging Face-style config -- TODO confirm). Otherwise ``str()``
            of the object is matched like a name.

    Returns:
        A new decomposer instance. Models matching no registry key fall back
        to ``QwenDecomposer`` and emit a ``UserWarning``.
    """
    # An exact match on an attached config's model_type avoids building
    # str() of a potentially very large model object.
    config = getattr(model_name_or_obj, "config", None)
    model_type = getattr(config, "model_type", None)
    if isinstance(model_type, str):
        decomposer_cls = DECOMPOSER_REGISTRY.get(model_type.lower())
        if decomposer_cls is not None:
            return decomposer_cls()

    name = str(model_name_or_obj).lower()
    # Substring match in registry insertion order; more specific keys
    # (e.g. "qwen2") should be registered before generic ones (e.g. "qwen").
    for key, decomposer_cls in DECOMPOSER_REGISTRY.items():
        if key in name:
            return decomposer_cls()

    # Default fallback: compatibility with the Qwen/Llama layout is hoped
    # for, not guaranteed. Name the type rather than dumping a model repr.
    if isinstance(model_name_or_obj, str):
        label = model_name_or_obj
    else:
        label = type(model_name_or_obj).__name__
    warnings.warn(
        f"No specific decomposer for {label!r}. Using Qwen/Llama default.",
        stacklevel=2,
    )
    return QwenDecomposer()