File size: 8,266 Bytes
f99ed48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# model_loader.py
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os

# --- グローバル変数 (アプリケーション起動時にロードされる) ---
model = None
tokenizer = None
MODEL_ID = os.environ.get(
    "MODEL_ID", "Qwen/Qwen3-30B-A3B"
)  # 環境変数からモデルIDを取得、なければデフォルト
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

LOAD_IN_4BIT = os.environ.get("LOAD_IN_4BIT", "false").lower() == "true"
LOAD_IN_8BIT = os.environ.get("LOAD_IN_8BIT", "false").lower() == "true"

# 4bitと8bitが同時にTrueになるのを防ぐ (どちらか一方、またはどちらもFalse)
if LOAD_IN_4BIT and LOAD_IN_8BIT:
    print(
        "Warning: Both LOAD_IN_4BIT and LOAD_IN_8BIT are set to true. Prioritizing 4-bit."
    )
    LOAD_IN_8BIT = False
elif not LOAD_IN_4BIT and not LOAD_IN_8BIT:
    print(
        "Info: No explicit quantization (4-bit/8-bit) requested via environment variables. Loading in default precision (e.g., bfloat16 on GPU)."
    )


def load_model():
    """
    アプリケーション起動時にモデルとトークナイザーをロードする。
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        quantization_info = "No Quantization"
        if LOAD_IN_4BIT:
            quantization_info = "4-bit Quantization"
        elif LOAD_IN_8BIT:
            quantization_info = "8-bit Quantization"

        print(
            f"Loading model: {MODEL_ID} on device: {DEVICE} with {quantization_info}..."
        )
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            model_kwargs = {
                "trust_remote_code": True
            }  # 基本的にTrueにしておくことが多い
            quantization_config = None
            if DEVICE == "cuda":
                model_kwargs["device_map"] = "auto"
                if LOAD_IN_4BIT:
                    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
                    model_kwargs["torch_dtype"] = "auto"  # 4bitと併用する計算時の型
                    # bnb_4bit_compute_dtype など、より詳細なbitsandbytes設定も環境変数で制御可能
                elif LOAD_IN_8BIT:
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                    # 8bitの場合、torch_dtypeは自動で設定されることが多いが、明示も可
                else:  # 量子化なしGPU
                    model_kwargs["torch_dtype"] = torch.bfloat16

                # model = AutoModelForCausalLM.from_pretrained(
                #     MODEL_ID,
                #     torch_dtype=torch.bfloat16,  # または torch.float16
                #     load_in_4bit=True,  # 4ビット量子化でロード (bitsandbytesが必要)
                #     # load_in_8bit=True, # 8ビット量子化の場合
                #     device_map="auto",  # 自動でGPUに割り当て
                #     trust_remote_code=True,  # モデルによっては必要
                # )
            else:  # CPUの場合 (量子化はGPU推奨だが、一応対応)
                # CPUでのbitsandbytes量子化は限定的、または非推奨
                if LOAD_IN_4BIT or LOAD_IN_8BIT:
                    print(
                        "Warning: bitsandbytes quantization (4-bit/8-bit) is primarily for GPU. Attempting on CPU may be slow or unstable."
                    )
                # model_kwargs["device_map"] = {"": "cpu"} # 明示的にCPUを指定
                pass  # .to(DEVICE) で対応

            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID, **model_kwargs, quantization_config=quantization_config
            )

            if DEVICE == "cpu" and not (
                LOAD_IN_4BIT or LOAD_IN_8BIT
            ):  # CPUで量子化なしの場合
                model = model.to(DEVICE)

            model.eval()  # 評価モード
            print(f"Model {MODEL_ID} loaded successfully.")
        except Exception as e:
            print(f"Error loading model {MODEL_ID}: {e}")
            # エラー発生時は model と tokenizer が None のままになる
            # アプリケーションのヘルスチェックなどでこれを確認できるようにするのも良い
            raise RuntimeError(f"Failed to load model: {e}")


def generate_text(
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 0.3,
    top_p: float = 0.9,
    repetition_penalty: float = 1.0,
) -> str:
    """
    ロードされたモデルを使ってテキストを生成する。
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded. Cannot generate text.")

    try:
        # プロンプトの形式はモデルによって調整が必要
        # 例: Instructモデルの場合、特定のテンプレートがあることが多い
        # ここでは単純にユーザープロンプトのみを使用
        # inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

        # より一般的なチャット形式のプロンプト適用 (モデルに合わせて調整)
        # StableLM Instruct Gamma のプロンプト形式例 (あくまで一例)
        # 参考: https://huggingface.co/stabilityai/japanese-stablelm-instruct-gamma-7b
        messages = [{"role": "user", "content": prompt}]
        # モデルによっては tokenizer.apply_chat_template が使える
        try:
            # 多くのモデルではtokenizer.apply_chat_templateが使える
            prompt_formatted = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                # Thinking Modeの切り替えここでできる
                # enable_thinking=False,
            )
        except Exception:
            # 古いモデルや特殊なモデルでapply_chat_templateがない場合の手動フォーマット例
            # これはモデルのドキュメントを確認して適切な形式にする
            print(
                f"Warning: tokenizer.apply_chat_template failed for {MODEL_ID}. Using raw prompt or basic formatting."
            )
            if (
                "stablelm-instruct" in MODEL_ID.lower() or "elyza" in MODEL_ID.lower()
            ):  # ELYZAやStableLMの例
                prompt_formatted = f"ユーザー: {prompt}\nシステム: "
            elif (
                "qwen" in MODEL_ID.lower() and "chat" in MODEL_ID.lower()
            ):  # Qwen-Chatの例
                prompt_formatted = (
                    f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
                )
            else:  # デフォルトはそのまま
                prompt_formatted = prompt

        inputs = tokenizer(
            prompt_formatted, return_tensors="pt", add_special_tokens=False
        ).to(DEVICE)  # add_special_tokensはテンプレートによる

        # テキスト生成
        # pad_token_id はeos_token_idと同じに設定することが多い (警告抑制)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True
            if temperature > 0
            else False,  # temperatureが0超ならサンプリング
            "pad_token_id": tokenizer.pad_token_id,
        }

        outputs = model.generate(**inputs, **generation_kwargs)

        # 生成されたテキストのみをデコード (入力プロンプト部分を除く)
        # inputs.input_ids.shape[1] は入力トークンの長さ
        output_text = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
        )
        return output_text.strip()

    except Exception as e:
        print(f"Error during text generation: {e}")
        # traceback.print_exc() # 詳細なエラー表示
        raise RuntimeError(f"Text generation failed: {e}")