Spaces:
Sleeping
Sleeping
| """ | |
| AIモデルアダプターのファクトリー関数 | |
| 環境変数に基づいて適切なモデルを自動選択 | |
| """ | |
| import os | |
| from typing import Optional | |
| from .base import BaseAI | |
| from .transformers_ai import TransformersAI | |
| from .openai_ai import OpenAIAI | |
| from .anthropic_ai import AnthropicAI | |
| from .google_ai import GoogleAI | |
| def get_ai_model(model_type: Optional[str] = None, **kwargs) -> BaseAI: | |
| """ | |
| 環境変数または引数に基づいて適切なAIモデルを取得 | |
| Args: | |
| model_type: モデルタイプ("transformers", "openai", "anthropic", "google") | |
| Noneの場合は環境変数MODEL_TYPEから取得 | |
| **kwargs: 各モデル固有の引数 | |
| - transformers: model_path | |
| - openai: model_name, api_key | |
| - anthropic: model_name, api_key | |
| - google: model_name, api_key | |
| Returns: | |
| BaseAI: 選択されたモデルのインスタンス | |
| Examples: | |
| # 環境変数から自動選択 | |
| ai = get_ai_model() | |
| # 明示的に指定 | |
| ai = get_ai_model("transformers", model_path="Qwen/Qwen2.5-3B-Instruct") | |
| ai = get_ai_model("openai", model_name="gpt-4", api_key="sk-...") | |
| """ | |
| # モデルタイプを決定 | |
| if model_type is None: | |
| model_type = os.getenv("MODEL_TYPE", "transformers") | |
| model_type = model_type.lower() | |
| # モデルタイプに応じて適切なクラスを返す | |
| if model_type == "transformers": | |
| model_path = kwargs.get("model_path") or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct") | |
| return TransformersAI.get_model(model_path=model_path) | |
| elif model_type == "openai": | |
| model_name = kwargs.get("model_name") or os.getenv("OPENAI_MODEL", "gpt-4") | |
| api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY") | |
| return OpenAIAI.get_model(model_name=model_name, api_key=api_key) | |
| elif model_type == "anthropic": | |
| model_name = kwargs.get("model_name") or os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022") | |
| api_key = kwargs.get("api_key") or os.getenv("ANTHROPIC_API_KEY") | |
| return AnthropicAI.get_model(model_name=model_name, api_key=api_key) | |
| elif model_type == "google": | |
| model_name = kwargs.get("model_name") or os.getenv("GOOGLE_MODEL", "gemini-pro") | |
| api_key = kwargs.get("api_key") or os.getenv("GOOGLE_API_KEY") | |
| return GoogleAI.get_model(model_name=model_name, api_key=api_key) | |
| else: | |
| raise ValueError( | |
| f"不明なモデルタイプ: {model_type}. " | |
| f"サポートされているタイプ: transformers, openai, anthropic, google" | |
| ) | |
| # 後方互換性のため、BaseAIもエクスポート | |
| __all__ = [ | |
| "BaseAI", | |
| "TransformersAI", | |
| "OpenAIAI", | |
| "AnthropicAI", | |
| "GoogleAI", | |
| "get_ai_model", | |
| ] | |