LLMView / package /ai.py
WatNeru's picture
細かいミス修正
c7f69f6
from typing import List, Tuple, Any, Optional
import os
import subprocess
from config import Config
class AI:
"""AIクラス - モデルをロードして文章とkを引数にトークンと確率のリストを返す(常駐版)"""
_instances = {} # モデルパスごとのインスタンスをキャッシュ(常駐)
def __new__(cls, model_path: str = None):
"""シングルトンパターンでモデルを常駐"""
path = model_path or Config.get_default_model_path()
if path not in cls._instances:
cls._instances[path] = super().__new__(cls)
cls._instances[path]._initialized = False
return cls._instances[path]
def __init__(self, model_path: str = None):
"""
モデルをロードして初期化(一度だけ実行、常駐)
Args:
model_path: モデルファイルのパス(Noneの場合はデフォルトパスを使用)
"""
if hasattr(self, '_initialized') and self._initialized:
return
self.model_path = model_path or Config.get_default_model_path()
self.model = self._load_model(self.model_path)
self._initialized = True
if self.model is None:
raise ValueError(f"モデルのロードに失敗しました: {self.model_path}")
@classmethod
def get_model(cls, model_path: str = None) -> 'AI':
"""モデルインスタンスを取得(常駐キャッシュから)"""
return cls(model_path)
@classmethod
def clear_cache(cls):
"""キャッシュをクリア(開発・テスト用)"""
cls._instances.clear()
def _load_model(self, model_path: str) -> Optional[Any]:
"""モデルをロード(Transformers使用、Hubから直接読み込み)"""
try:
if not model_path:
return None
# モデルパスがリポジトリID("user/repo"形式)か、ローカルパスかを判定
is_repo_id = "/" in model_path and not os.path.exists(model_path)
# リポジトリIDの場合は os.path.exists() チェックをスキップ
if not is_repo_id and not os.path.exists(model_path):
print(f"[AI] モデルパスが存在しません: {model_path}")
return None
# transformersを使用してモデルをロード
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# GPUが利用可能かチェック
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
print("[AI] GPU検出: CUDAを使用します")
else:
print("[AI] GPU未検出: CPUモードで実行します")
print(f"[AI] モデルをロード中: {model_path}")
print(f"[AI] デバイス: {device}")
# モデルパスがリポジトリID("user/repo"形式)か、ローカルパスかを判定
hf_token = os.getenv("HF_TOKEN")
is_repo_id = "/" in model_path and not os.path.exists(model_path)
if is_repo_id:
print(f"[AI] Hugging Face Hub から直接読み込み: {model_path}")
else:
print(f"[AI] ローカルパスから読み込み: {model_path}")
# トークナイザーとモデルをロード(Hubから直接読み込む)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
token=hf_token,
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
token=hf_token,
)
if device == "cpu":
model = model.to(device)
# モデルとトークナイザーをタプルで返す
print(f"[AI] モデルロード成功 ({device}モード)")
return (model, tokenizer)
except Exception as e:
import traceback
print(f"[AI] transformersでのロードに失敗: {e}")
traceback.print_exc()
return None
except Exception as e:
import traceback
print(f"[AI] モデルロードエラー: {e}")
traceback.print_exc()
return None
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
"""
文章とkを引数に、{token, 確率}のリストを返す
Args:
text: 入力文章
k: 取得するトークン数
Returns:
List[Tuple[str, float]]: (トークン, 確率)のリスト
"""
if self.model is None:
return []
try:
# transformers モデルの場合
if isinstance(self.model, tuple) and len(self.model) == 2:
model, tokenizer = self.model
import torch
# テキストをトークン化
inputs = tokenizer(text, return_tensors="pt")
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# モデルで推論(勾配計算なし)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits[0, -1, :] # 最後のトークンのlogits
# logitsを確率に変換(softmax)
probs = torch.softmax(logits, dim=-1)
# 上位k個のトークンを取得
top_probs, top_indices = torch.topk(probs, k)
# トークンIDを文字列に変換
items: List[Tuple[str, float]] = []
# Llama 3.2の特殊トークンを定義
LLAMA_SPECIAL_TOKENS = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|eot_id|>",
"<|start_header_id|>",
"<|end_header_id|>",
]
def _clean_text(text: str) -> str:
"""制御文字・不可視文字・置換文字を厳密に取り除く(正規タグは保持)"""
if not text:
return ""
# 制御文字(0x00-0x1F、0x7F-0x9F)を除去
# ただし、改行・タブ・復帰は許可
cleaned = []
for ch in text:
code = ord(ch)
# 許可する制御文字: 改行(0x0A), タブ(0x09), 復帰(0x0D)
if code in [0x09, 0x0A, 0x0D]:
cleaned.append(ch)
# 通常の印刷可能文字(0x20-0x7E、およびその他のUnicode印刷可能文字)
elif ch.isprintable():
# 置換文字(U+FFFD)を除去
if ch != "\uFFFD":
cleaned.append(ch)
# その他の制御文字や不可視文字は除去
result = "".join(cleaned)
# ゼロ幅文字を除去
result = result.replace("\u200B", "") # Zero-width space
result = result.replace("\u200C", "") # Zero-width non-joiner
result = result.replace("\u200D", "") # Zero-width joiner
result = result.replace("\uFEFF", "") # Zero-width no-break space
# その他の不可視文字(結合文字など)を除去
result = result.replace("\u200E", "") # Left-to-right mark
result = result.replace("\u200F", "") # Right-to-left mark
result = result.replace("\u202A", "") # Left-to-right embedding
result = result.replace("\u202B", "") # Right-to-left embedding
result = result.replace("\u202C", "") # Pop directional formatting
result = result.replace("\u202D", "") # Left-to-right override
result = result.replace("\u202E", "") # Right-to-left override
return result.strip()
for idx, prob in zip(top_indices, top_probs):
token_id = idx.item()
# 正規タグ(<|eot_id|>など)を保持するため、skip_special_tokens=False
token = tokenizer.decode([token_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
token = _clean_text(token)
# 空文字列のトークンは除外
if not token:
continue
prob_value = prob.item()
items.append((token, float(prob_value)))
# 確率を正規化
if items:
total_prob = sum(prob for _, prob in items)
if total_prob > 0:
normalized_items: List[Tuple[str, float]] = []
for token, prob in items:
normalized_prob = prob / total_prob
normalized_items.append((token, normalized_prob))
return normalized_items
return items
else:
print("モデルがサポートされていません")
return []
except Exception as e:
print(f"トークン確率取得エラー: {e}")
import traceback
traceback.print_exc()
return []
def _softmax_from_logprobs(self, logprobs: List[float]) -> List[float]:
"""logprobsをsoftmaxで確率に変換"""
if not logprobs:
return []
# 数値安定性のため最大値を引く
max_logprob = max(logprobs)
exp_logprobs = [exp(logprob - max_logprob) for logprob in logprobs]
sum_exp = sum(exp_logprobs)
if sum_exp == 0:
return [0.0] * len(logprobs)
return [exp_logprob / sum_exp for exp_logprob in exp_logprobs]
def exp(x: float) -> float:
"""指数関数の近似実装(math.expの代替)"""
import math
return math.exp(x)