File size: 8,266 Bytes
f99ed48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# model_loader.py
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# --- グローバル変数 (アプリケーション起動時にロードされる) ---
model = None
tokenizer = None
MODEL_ID = os.environ.get(
"MODEL_ID", "Qwen/Qwen3-30B-A3B"
) # 環境変数からモデルIDを取得、なければデフォルト
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LOAD_IN_4BIT = os.environ.get("LOAD_IN_4BIT", "false").lower() == "true"
LOAD_IN_8BIT = os.environ.get("LOAD_IN_8BIT", "false").lower() == "true"
# 4bitと8bitが同時にTrueになるのを防ぐ (どちらか一方、またはどちらもFalse)
if LOAD_IN_4BIT and LOAD_IN_8BIT:
print(
"Warning: Both LOAD_IN_4BIT and LOAD_IN_8BIT are set to true. Prioritizing 4-bit."
)
LOAD_IN_8BIT = False
elif not LOAD_IN_4BIT and not LOAD_IN_8BIT:
print(
"Info: No explicit quantization (4-bit/8-bit) requested via environment variables. Loading in default precision (e.g., bfloat16 on GPU)."
)
def load_model():
"""
アプリケーション起動時にモデルとトークナイザーをロードする。
"""
global model, tokenizer
if model is None or tokenizer is None:
quantization_info = "No Quantization"
if LOAD_IN_4BIT:
quantization_info = "4-bit Quantization"
elif LOAD_IN_8BIT:
quantization_info = "8-bit Quantization"
print(
f"Loading model: {MODEL_ID} on device: {DEVICE} with {quantization_info}..."
)
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model_kwargs = {
"trust_remote_code": True
} # 基本的にTrueにしておくことが多い
quantization_config = None
if DEVICE == "cuda":
model_kwargs["device_map"] = "auto"
if LOAD_IN_4BIT:
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_kwargs["torch_dtype"] = "auto" # 4bitと併用する計算時の型
# bnb_4bit_compute_dtype など、より詳細なbitsandbytes設定も環境変数で制御可能
elif LOAD_IN_8BIT:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
# 8bitの場合、torch_dtypeは自動で設定されることが多いが、明示も可
else: # 量子化なしGPU
model_kwargs["torch_dtype"] = torch.bfloat16
# model = AutoModelForCausalLM.from_pretrained(
# MODEL_ID,
# torch_dtype=torch.bfloat16, # または torch.float16
# load_in_4bit=True, # 4ビット量子化でロード (bitsandbytesが必要)
# # load_in_8bit=True, # 8ビット量子化の場合
# device_map="auto", # 自動でGPUに割り当て
# trust_remote_code=True, # モデルによっては必要
# )
else: # CPUの場合 (量子化はGPU推奨だが、一応対応)
# CPUでのbitsandbytes量子化は限定的、または非推奨
if LOAD_IN_4BIT or LOAD_IN_8BIT:
print(
"Warning: bitsandbytes quantization (4-bit/8-bit) is primarily for GPU. Attempting on CPU may be slow or unstable."
)
# model_kwargs["device_map"] = {"": "cpu"} # 明示的にCPUを指定
pass # .to(DEVICE) で対応
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, **model_kwargs, quantization_config=quantization_config
)
if DEVICE == "cpu" and not (
LOAD_IN_4BIT or LOAD_IN_8BIT
): # CPUで量子化なしの場合
model = model.to(DEVICE)
model.eval() # 評価モード
print(f"Model {MODEL_ID} loaded successfully.")
except Exception as e:
print(f"Error loading model {MODEL_ID}: {e}")
# エラー発生時は model と tokenizer が None のままになる
# アプリケーションのヘルスチェックなどでこれを確認できるようにするのも良い
raise RuntimeError(f"Failed to load model: {e}")
def generate_text(
prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.3,
top_p: float = 0.9,
repetition_penalty: float = 1.0,
) -> str:
"""
ロードされたモデルを使ってテキストを生成する。
"""
if model is None or tokenizer is None:
raise RuntimeError("Model not loaded. Cannot generate text.")
try:
# プロンプトの形式はモデルによって調整が必要
# 例: Instructモデルの場合、特定のテンプレートがあることが多い
# ここでは単純にユーザープロンプトのみを使用
# inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
# より一般的なチャット形式のプロンプト適用 (モデルに合わせて調整)
# StableLM Instruct Gamma のプロンプト形式例 (あくまで一例)
# 参考: https://huggingface.co/stabilityai/japanese-stablelm-instruct-gamma-7b
messages = [{"role": "user", "content": prompt}]
# モデルによっては tokenizer.apply_chat_template が使える
try:
# 多くのモデルではtokenizer.apply_chat_templateが使える
prompt_formatted = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
# Thinking Modeの切り替えここでできる
# enable_thinking=False,
)
except Exception:
# 古いモデルや特殊なモデルでapply_chat_templateがない場合の手動フォーマット例
# これはモデルのドキュメントを確認して適切な形式にする
print(
f"Warning: tokenizer.apply_chat_template failed for {MODEL_ID}. Using raw prompt or basic formatting."
)
if (
"stablelm-instruct" in MODEL_ID.lower() or "elyza" in MODEL_ID.lower()
): # ELYZAやStableLMの例
prompt_formatted = f"ユーザー: {prompt}\nシステム: "
elif (
"qwen" in MODEL_ID.lower() and "chat" in MODEL_ID.lower()
): # Qwen-Chatの例
prompt_formatted = (
f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
)
else: # デフォルトはそのまま
prompt_formatted = prompt
inputs = tokenizer(
prompt_formatted, return_tensors="pt", add_special_tokens=False
).to(DEVICE) # add_special_tokensはテンプレートによる
# テキスト生成
# pad_token_id はeos_token_idと同じに設定することが多い (警告抑制)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"do_sample": True
if temperature > 0
else False, # temperatureが0超ならサンプリング
"pad_token_id": tokenizer.pad_token_id,
}
outputs = model.generate(**inputs, **generation_kwargs)
# 生成されたテキストのみをデコード (入力プロンプト部分を除く)
# inputs.input_ids.shape[1] は入力トークンの長さ
output_text = tokenizer.decode(
outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
)
return output_text.strip()
except Exception as e:
print(f"Error during text generation: {e}")
# traceback.print_exc() # 詳細なエラー表示
raise RuntimeError(f"Text generation failed: {e}")
|