# In the model's config (example: ERNIE 4.5-style decoder blocks)
# Tensor-parallel plan: maps glob-style submodule patterns to a sharding
# style — "colwise" shards the projection across output features,
# "rowwise" across input features (so colwise -> rowwise pairs compose
# without an intermediate all-gather).
base_model_tp_plan = {
    f"layers.*.{module}": shard_style
    for module, shard_style in [
        ("self_attn.q_proj", "colwise"),
        ("self_attn.k_proj", "colwise"),
        ("self_attn.v_proj", "colwise"),
        ("self_attn.o_proj", "rowwise"),
        ("mlp.gate_proj", "colwise"),
        ("mlp.up_proj", "colwise"),
        ("mlp.down_proj", "rowwise"),
    ]
}
    
    # Runtime: load the checkpoint tensor-parallel and run one forward pass.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your/model-or-local-checkpoint"

# NOTE(review): a config-level `base_model_tp_plan` is NOT applied
# automatically — tensor parallelism must be requested explicitly with
# tp_plan="auto", and the script must be launched under a distributed
# process group (e.g. `torchrun --nproc-per-node=N this_script.py`) so
# each rank can hold its shard.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    tp_plan="auto",  # shard weights per the plan defined in the config above
)

tok = AutoTokenizer.from_pretrained(model_id)
# Move inputs to the same device the (sharded) model ended up on.
inputs = tok("Hello", return_tensors="pt").to(model.device)
out = model(**inputs)