llamate / model_loader.py
m1b2lover's picture
Upload 8 files
f99ed48 verified
# model_loader.py
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# --- グローバル変数 (アプリケーション起動時にロードされる) ---
model = None
tokenizer = None
MODEL_ID = os.environ.get(
"MODEL_ID", "Qwen/Qwen3-30B-A3B"
) # 環境変数からモデルIDを取得、なければデフォルト
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LOAD_IN_4BIT = os.environ.get("LOAD_IN_4BIT", "false").lower() == "true"
LOAD_IN_8BIT = os.environ.get("LOAD_IN_8BIT", "false").lower() == "true"
# 4bitと8bitが同時にTrueになるのを防ぐ (どちらか一方、またはどちらもFalse)
if LOAD_IN_4BIT and LOAD_IN_8BIT:
print(
"Warning: Both LOAD_IN_4BIT and LOAD_IN_8BIT are set to true. Prioritizing 4-bit."
)
LOAD_IN_8BIT = False
elif not LOAD_IN_4BIT and not LOAD_IN_8BIT:
print(
"Info: No explicit quantization (4-bit/8-bit) requested via environment variables. Loading in default precision (e.g., bfloat16 on GPU)."
)
def load_model():
"""
アプリケーション起動時にモデルとトークナイザーをロードする。
"""
global model, tokenizer
if model is None or tokenizer is None:
quantization_info = "No Quantization"
if LOAD_IN_4BIT:
quantization_info = "4-bit Quantization"
elif LOAD_IN_8BIT:
quantization_info = "8-bit Quantization"
print(
f"Loading model: {MODEL_ID} on device: {DEVICE} with {quantization_info}..."
)
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model_kwargs = {
"trust_remote_code": True
} # 基本的にTrueにしておくことが多い
quantization_config = None
if DEVICE == "cuda":
model_kwargs["device_map"] = "auto"
if LOAD_IN_4BIT:
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_kwargs["torch_dtype"] = "auto" # 4bitと併用する計算時の型
# bnb_4bit_compute_dtype など、より詳細なbitsandbytes設定も環境変数で制御可能
elif LOAD_IN_8BIT:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
# 8bitの場合、torch_dtypeは自動で設定されることが多いが、明示も可
else: # 量子化なしGPU
model_kwargs["torch_dtype"] = torch.bfloat16
# model = AutoModelForCausalLM.from_pretrained(
# MODEL_ID,
# torch_dtype=torch.bfloat16, # または torch.float16
# load_in_4bit=True, # 4ビット量子化でロード (bitsandbytesが必要)
# # load_in_8bit=True, # 8ビット量子化の場合
# device_map="auto", # 自動でGPUに割り当て
# trust_remote_code=True, # モデルによっては必要
# )
else: # CPUの場合 (量子化はGPU推奨だが、一応対応)
# CPUでのbitsandbytes量子化は限定的、または非推奨
if LOAD_IN_4BIT or LOAD_IN_8BIT:
print(
"Warning: bitsandbytes quantization (4-bit/8-bit) is primarily for GPU. Attempting on CPU may be slow or unstable."
)
# model_kwargs["device_map"] = {"": "cpu"} # 明示的にCPUを指定
pass # .to(DEVICE) で対応
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, **model_kwargs, quantization_config=quantization_config
)
if DEVICE == "cpu" and not (
LOAD_IN_4BIT or LOAD_IN_8BIT
): # CPUで量子化なしの場合
model = model.to(DEVICE)
model.eval() # 評価モード
print(f"Model {MODEL_ID} loaded successfully.")
except Exception as e:
print(f"Error loading model {MODEL_ID}: {e}")
# エラー発生時は model と tokenizer が None のままになる
# アプリケーションのヘルスチェックなどでこれを確認できるようにするのも良い
raise RuntimeError(f"Failed to load model: {e}")
def generate_text(
prompt: str,
max_new_tokens: int = 100,
temperature: float = 0.3,
top_p: float = 0.9,
repetition_penalty: float = 1.0,
) -> str:
"""
ロードされたモデルを使ってテキストを生成する。
"""
if model is None or tokenizer is None:
raise RuntimeError("Model not loaded. Cannot generate text.")
try:
# プロンプトの形式はモデルによって調整が必要
# 例: Instructモデルの場合、特定のテンプレートがあることが多い
# ここでは単純にユーザープロンプトのみを使用
# inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
# より一般的なチャット形式のプロンプト適用 (モデルに合わせて調整)
# StableLM Instruct Gamma のプロンプト形式例 (あくまで一例)
# 参考: https://huggingface.co/stabilityai/japanese-stablelm-instruct-gamma-7b
messages = [{"role": "user", "content": prompt}]
# モデルによっては tokenizer.apply_chat_template が使える
try:
# 多くのモデルではtokenizer.apply_chat_templateが使える
prompt_formatted = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
# Thinking Modeの切り替えここでできる
# enable_thinking=False,
)
except Exception:
# 古いモデルや特殊なモデルでapply_chat_templateがない場合の手動フォーマット例
# これはモデルのドキュメントを確認して適切な形式にする
print(
f"Warning: tokenizer.apply_chat_template failed for {MODEL_ID}. Using raw prompt or basic formatting."
)
if (
"stablelm-instruct" in MODEL_ID.lower() or "elyza" in MODEL_ID.lower()
): # ELYZAやStableLMの例
prompt_formatted = f"ユーザー: {prompt}\nシステム: "
elif (
"qwen" in MODEL_ID.lower() and "chat" in MODEL_ID.lower()
): # Qwen-Chatの例
prompt_formatted = (
f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
)
else: # デフォルトはそのまま
prompt_formatted = prompt
inputs = tokenizer(
prompt_formatted, return_tensors="pt", add_special_tokens=False
).to(DEVICE) # add_special_tokensはテンプレートによる
# テキスト生成
# pad_token_id はeos_token_idと同じに設定することが多い (警告抑制)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"do_sample": True
if temperature > 0
else False, # temperatureが0超ならサンプリング
"pad_token_id": tokenizer.pad_token_id,
}
outputs = model.generate(**inputs, **generation_kwargs)
# 生成されたテキストのみをデコード (入力プロンプト部分を除く)
# inputs.input_ids.shape[1] は入力トークンの長さ
output_text = tokenizer.decode(
outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
)
return output_text.strip()
except Exception as e:
print(f"Error during text generation: {e}")
# traceback.print_exc() # 詳細なエラー表示
raise RuntimeError(f"Text generation failed: {e}")