LLMView_multi_model / package /ai /__init__.py
WatNeru's picture
first commit
0447f30
"""
AIモデルアダプターのファクトリー関数
環境変数に基づいて適切なモデルを自動選択
"""
import os
from typing import Optional
from .base import BaseAI
from .transformers_ai import TransformersAI
from .openai_ai import OpenAIAI
from .anthropic_ai import AnthropicAI
from .google_ai import GoogleAI
def get_ai_model(model_type: Optional[str] = None, **kwargs) -> BaseAI:
"""
環境変数または引数に基づいて適切なAIモデルを取得
Args:
model_type: モデルタイプ("transformers", "openai", "anthropic", "google")
Noneの場合は環境変数MODEL_TYPEから取得
**kwargs: 各モデル固有の引数
- transformers: model_path
- openai: model_name, api_key
- anthropic: model_name, api_key
- google: model_name, api_key
Returns:
BaseAI: 選択されたモデルのインスタンス
Examples:
# 環境変数から自動選択
ai = get_ai_model()
# 明示的に指定
ai = get_ai_model("transformers", model_path="Qwen/Qwen2.5-3B-Instruct")
ai = get_ai_model("openai", model_name="gpt-4", api_key="sk-...")
"""
# モデルタイプを決定
if model_type is None:
model_type = os.getenv("MODEL_TYPE", "transformers")
model_type = model_type.lower()
# モデルタイプに応じて適切なクラスを返す
if model_type == "transformers":
model_path = kwargs.get("model_path") or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
return TransformersAI.get_model(model_path=model_path)
elif model_type == "openai":
model_name = kwargs.get("model_name") or os.getenv("OPENAI_MODEL", "gpt-4")
api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
return OpenAIAI.get_model(model_name=model_name, api_key=api_key)
elif model_type == "anthropic":
model_name = kwargs.get("model_name") or os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022")
api_key = kwargs.get("api_key") or os.getenv("ANTHROPIC_API_KEY")
return AnthropicAI.get_model(model_name=model_name, api_key=api_key)
elif model_type == "google":
model_name = kwargs.get("model_name") or os.getenv("GOOGLE_MODEL", "gemini-pro")
api_key = kwargs.get("api_key") or os.getenv("GOOGLE_API_KEY")
return GoogleAI.get_model(model_name=model_name, api_key=api_key)
else:
raise ValueError(
f"不明なモデルタイプ: {model_type}. "
f"サポートされているタイプ: transformers, openai, anthropic, google"
)
# 後方互換性のため、BaseAIもエクスポート
__all__ = [
"BaseAI",
"TransformersAI",
"OpenAIAI",
"AnthropicAI",
"GoogleAI",
"get_ai_model",
]