"""
Model loading and prompt construction for Qwen3-30B-A3B-Thinking.

Provides:
  - load_model_and_tokenizer(): robust loader with dtype/device auto-handling
  - build_thinking_prompt(problem, enable_thinking): chat-template wrapper
  - generate(): a simple generation helper
"""
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from configs.model import MODEL_CONFIG, GEN_CONFIG


def load_model_and_tokenizer(
    model_dir: Optional[Path] = None,
    dtype: Optional[str] = None,
    device_map: str = "auto",
    verbose: bool = True,
):
    """
    Load Qwen3-30B-A3B-Thinking-2507 from local dir.

    Args:
        model_dir: override MODEL_CONFIG["local_dir"]
        dtype: "bfloat16" | "float16" | "auto"; overrides MODEL_CONFIG
        device_map: HuggingFace device_map, default "auto"
    """
    mdir = Path(model_dir) if model_dir else Path(MODEL_CONFIG["local_dir"])
    dt = dtype or MODEL_CONFIG["load_dtype"]
    torch_dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
        "auto": "auto",
    }.get(dt, torch.bfloat16)

    if verbose:
        print(f"[model_io] Loading tokenizer: {mdir}")
    tokenizer = AutoTokenizer.from_pretrained(
        mdir, trust_remote_code=MODEL_CONFIG["trust_remote_code"]
    )
    if verbose:
        print(f"[model_io] Loading model dtype={dt} device_map={device_map}")
    model = AutoModelForCausalLM.from_pretrained(
        mdir,
        torch_dtype=torch_dtype,
        device_map=device_map,
        trust_remote_code=MODEL_CONFIG["trust_remote_code"],
    )
    model.eval()

    # Validate architecture matches config
    cfg = model.config
    assert cfg.num_hidden_layers == MODEL_CONFIG["num_layers"], \
        f"num_layers mismatch: model has {cfg.num_hidden_layers}, config says {MODEL_CONFIG['num_layers']}"
    if verbose:
        ne = getattr(cfg, "num_experts", None) or getattr(cfg, "n_routed_experts", None)
        nk = getattr(cfg, "num_experts_per_tok", None) or getattr(cfg, "top_k", None)
        print(f"[model_io] layers={cfg.num_hidden_layers}, experts={ne}, top_k={nk}")

    return model, tokenizer


def build_thinking_prompt(
    tokenizer,
    problem: str,
    system_prompt: Optional[str] = None,
    enable_thinking: bool = True,
) -> str:
    """
    Construct chat-template prompt. Qwen3-Thinking uses enable_thinking
    to insert the <think> channel.
    """
    sys_msg = system_prompt or MODEL_CONFIG["default_system_prompt"]
    messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": f"Problem: {problem}\n\nSolve step by step."},
    ]
    # Some Qwen3 chat templates accept enable_thinking kwarg
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )
    except TypeError:
        # Fallback: plain chat template (no thinking switch)
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
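

# Illustrative helper (an addition, not part of the original module):
# Qwen3-Thinking checkpoints generally emit their reasoning first and close it
# with a </think> tag before the final answer, so callers often want the two
# parts separated. Adjust the tag handling if your checkpoint's chat template
# differs.
def split_thinking(text: str) -> tuple:
    """Split generated text into (thinking, answer) on the last </think> tag."""
    marker = "</think>"
    if marker in text:
        thinking, _, answer = text.rpartition(marker)
        return thinking.strip(), answer.strip()
    # No closing tag found: treat the whole completion as the answer.
    return "", text.strip()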


def generate(
    model, tokenizer, prompt: str,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    do_sample: Optional[bool] = None,
) -> str:
    """Generate a completion. Returns only newly generated text (no prompt)."""
    max_new = max_new_tokens if max_new_tokens is not None else GEN_CONFIG["max_new_tokens"]
    t = temperature if temperature is not None else GEN_CONFIG["temperature"]
    p = top_p if top_p is not None else GEN_CONFIG["top_p"]
    ds = do_sample if do_sample is not None else GEN_CONFIG["do_sample"]

    enc = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_new,
            temperature=t,
            top_p=p,
            do_sample=ds,
            pad_token_id=tokenizer.eos_token_id,
        )
    gen_ids = out[0, enc["input_ids"].shape[1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)
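

# Minimal usage sketch, assuming the checkpoint referenced by
# MODEL_CONFIG["local_dir"] is present locally and that configs.model defines
# the dicts imported above. Intended as a quick smoke test, not part of the
# module's public API.
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    prompt = build_thinking_prompt(
        tokenizer, "What is 17 * 24?", enable_thinking=True
    )
    completion = generate(model, tokenizer, prompt)
    print(completion)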