File size: 928 Bytes
89280a9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | from .base import LayerDecomposer
from .qwen import QwenDecomposer
# Lookup table from a lowercase model-family string to the LayerDecomposer
# subclass that handles it. Keys are meant to line up with the identifiers
# found in a model's config or name.
# Keep specific keys ahead of generic fallbacks (e.g. "qwen2" before "qwen")
# so that any ordered substring scan prefers the most specific match.
# Every Qwen variant, plus the generic "qwen" fallback, uses QwenDecomposer.
DECOMPOSER_REGISTRY = dict.fromkeys(("qwen2", "qwen3", "qwen"), QwenDecomposer)
# Llama reuses the Qwen decomposer: per the original author it usually has the
# same layer structure (pre-norm, RMSNorm, MLP) — NOTE(review): confirm per model.
DECOMPOSER_REGISTRY["llama"] = QwenDecomposer
def get_decomposer(model_name_or_obj) -> LayerDecomposer:
    """
    Return a decomposer instance appropriate for the given model.

    The model is identified by ``str(model_name_or_obj).lower()``, so either a
    name string or an object whose ``str()`` mentions the model family works.
    Each key of ``DECOMPOSER_REGISTRY`` is tested as a substring of that name,
    in the registry's insertion order, and the first matching entry's class is
    instantiated.

    Args:
        model_name_or_obj: A model name, or any object whose string form
            identifies the model family.

    Returns:
        A new instance of the matching decomposer class. If no registry key
        matches, a ``QwenDecomposer`` is returned and a ``UserWarning`` is
        issued (previously this was printed to stdout).
    """
    import warnings

    # Simple string-based matching for now.
    name = str(model_name_or_obj).lower()

    # Resolve through the registry so that adding an entry there is enough to
    # support a new model family. Insertion order decides precedence, so the
    # qwen keys are checked before "llama", as in the earlier hard-coded checks.
    for key, decomposer_cls in DECOMPOSER_REGISTRY.items():
        if key in name:
            return decomposer_cls()

    # Default fallback (compatibility is hoped for, not verified). Use the
    # warnings machinery so callers can filter or capture this instead of
    # getting unconditional stdout output.
    warnings.warn(
        f"No specific decomposer for {name}. Using Qwen/Llama default.",
        stacklevel=2,
    )
    return QwenDecomposer()
|