LLMView_multi_model / package /ai /transformers_ai.py
WatNeru's picture
bugfix
2bdf0a5
"""
TransformersAI - Hugging Face Transformersモデル用アダプター
Llama 3.2、Qwen、Mistral、Gemma等のローカルモデルに対応
"""
from typing import List, Tuple, Any, Optional
import os
from .base import BaseAI
class TransformersAI(BaseAI):
"""
Hugging Face Transformersモデル用アダプター
特徴:
- ローカルでモデルをロード
- logitsから直接確率を取得可能
- user/assistantを明確に分離する形式を推奨(Llama 3.2形式)
"""
_instances = {} # モデルパスごとのインスタンスをキャッシュ(常駐)
def __new__(cls, model_path: str = None):
"""シングルトンパターンでモデルを常駐"""
path = model_path or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
if path not in cls._instances:
cls._instances[path] = super().__new__(cls)
cls._instances[path]._initialized = False
return cls._instances[path]
def __init__(self, model_path: str = None):
"""
モデルをロードして初期化(一度だけ実行、常駐)
Args:
model_path: モデルリポジトリIDまたはローカルパス
"""
if hasattr(self, '_initialized') and self._initialized:
return
self.model_path = model_path or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
self.model = self._load_model(self.model_path)
self._initialized = True
if self.model is None:
raise ValueError(f"モデルのロードに失敗しました: {self.model_path}")
@classmethod
def get_model(cls, model_path: str = None) -> 'TransformersAI':
"""モデルインスタンスを取得(常駐キャッシュから)"""
return cls(model_path)
@classmethod
def clear_cache(cls):
"""キャッシュをクリア(開発・テスト用)"""
cls._instances.clear()
def _load_model(self, model_path: str) -> Optional[Any]:
"""モデルをロード(Transformers使用、Hubから直接読み込み)"""
try:
if not model_path:
return None
# モデルパスがリポジトリID("user/repo"形式)か、ローカルパスかを判定
is_repo_id = "/" in model_path and not os.path.exists(model_path)
# リポジトリIDの場合は os.path.exists() チェックをスキップ
if not is_repo_id and not os.path.exists(model_path):
print(f"[TransformersAI] モデルパスが存在しません: {model_path}")
return None
# transformersを使用してモデルをロード
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# GPUが利用可能かチェック
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
print("[TransformersAI] GPU検出: CUDAを使用します")
else:
print("[TransformersAI] GPU未検出: CPUモードで実行します")
print(f"[TransformersAI] モデルをロード中: {model_path}")
print(f"[TransformersAI] デバイス: {device}")
hf_token = os.getenv("HF_TOKEN")
if is_repo_id:
print(f"[TransformersAI] Hugging Face Hub から直接読み込み: {model_path}")
else:
print(f"[TransformersAI] ローカルパスから読み込み: {model_path}")
# トークナイザーとモデルをロード(Hubから直接読み込む)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
token=hf_token,
)
# device_map="auto"を使用する場合はaccelerateが必要
# accelerateがインストールされていない場合は、device_mapをNoneにして手動でデバイスに移動
try:
import accelerate
use_device_map = "auto" if device == "cuda" else None
except ImportError:
print("[TransformersAI] 警告: accelerateがインストールされていません。device_mapを使用しません。")
use_device_map = None
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map=use_device_map,
token=hf_token,
)
# device_mapを使用しない場合は、手動でデバイスに移動
if use_device_map is None:
model = model.to(device)
# モデルとトークナイザーをタプルで返す
print(f"[TransformersAI] モデルロード成功 ({device}モード)")
return (model, tokenizer)
except Exception as e:
import traceback
print(f"[TransformersAI] transformersでのロードに失敗: {e}")
traceback.print_exc()
return None
except Exception as e:
import traceback
print(f"[TransformersAI] モデルロードエラー: {e}")
traceback.print_exc()
return None
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
"""
文章とkを引数に、{token, 確率}のリストを返す
Args:
text: 入力文章
k: 取得するトークン数
Returns:
List[Tuple[str, float]]: (トークン, 確率)のリスト
"""
if self.model is None:
return []
try:
# transformers モデルの場合
if isinstance(self.model, tuple) and len(self.model) == 2:
model, tokenizer = self.model
import torch
# テキストをトークン化
inputs = tokenizer(text, return_tensors="pt")
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# モデルで推論(勾配計算なし)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits[0, -1, :] # 最後のトークンのlogits
# logitsを確率に変換(softmax)
probs = torch.softmax(logits, dim=-1)
# 上位k個のトークンを取得
top_probs, top_indices = torch.topk(probs, k)
# トークンIDを文字列に変換
items: List[Tuple[str, float]] = []
# 特殊トークンを定義(Llama 3.2、Qwen、Mistral等で使用)
SPECIAL_TOKENS = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|eot_id|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|im_start|>",
"<|im_end|>",
]
def _clean_text_local(text: str) -> str:
"""制御文字・不可視文字・置換文字・特殊トークンを厳密に取り除く"""
if not text:
return ""
# 特殊トークンを除去
for special_token in SPECIAL_TOKENS:
text = text.replace(special_token, "")
# 基底クラスの_clean_textを使用
return self._clean_text(text)
for idx, prob in zip(top_indices, top_probs):
token_id = idx.item()
# skip_special_tokens=Trueで特殊トークンを除外
token = tokenizer.decode([token_id], skip_special_tokens=True, clean_up_tokenization_spaces=False)
token = _clean_text_local(token)
# 空文字列のトークンは除外
if not token:
continue
prob_value = prob.item()
items.append((token, float(prob_value)))
# 確率を正規化
if items:
total_prob = sum(prob for _, prob in items)
if total_prob > 0:
normalized_items: List[Tuple[str, float]] = []
for token, prob in items:
normalized_prob = prob / total_prob
normalized_items.append((token, normalized_prob))
return normalized_items
return items
else:
print("[TransformersAI] モデルがサポートされていません")
return []
except Exception as e:
print(f"[TransformersAI] トークン確率取得エラー: {e}")
import traceback
traceback.print_exc()
return []
def build_chat_prompt(
self,
user_content: str,
system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください",
assistant_content: Optional[str] = None
) -> str:
"""
チャットプロンプトを構築(Llama 3.2形式)
注意: Transformersモデル(特にLlama 3.2、Qwen等)では、
user/assistantを明確に分離する形式を推奨します。
Args:
user_content: ユーザーのメッセージ
system_content: システムプロンプト
assistant_content: アシスタントの既存応答(会話履歴用、オプション)
Returns:
str: Llama 3.2形式のプロンプト
"""
# 既に整形済みのプロンプトが渡されている場合(複数行、ヘッダーを含む)
# そのまま返す
if "<|start_header_id|>" in user_content or "<|eot_id|>" in user_content:
return user_content
# Llama 3.2形式でプロンプトを構築
prompt_parts = []
# Systemメッセージ
if system_content:
prompt_parts.append("<|start_header_id|>system<|end_header_id|>")
prompt_parts.append(system_content)
prompt_parts.append("<|eot_id|>")
# Userメッセージ
prompt_parts.append("<|start_header_id|>user<|end_header_id|>")
prompt_parts.append(user_content)
prompt_parts.append("<|eot_id|>")
# Assistantメッセージ(会話履歴がある場合)
if assistant_content:
prompt_parts.append("<|start_header_id|>assistant<|end_header_id|>")
prompt_parts.append(assistant_content)
prompt_parts.append("<|eot_id|>")
# 新しい応答を生成する場合は、assistantヘッダーだけを追加
prompt_parts.append("<|start_header_id|>assistant<|end_header_id|>")
prompt_text = "\n".join(prompt_parts)
# BOS(<|begin_of_text|>) の重複を抑止: 先頭のBOSを全て除去
# transformers側でBOSが自動付与される場合があるため
BOS = "<|begin_of_text|>"
s = prompt_text.lstrip()
while s.startswith(BOS):
s = s[len(BOS):]
prompt_text = s
return prompt_text