File size: 2,946 Bytes
0447f30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
AIモデルアダプターのファクトリー関数
環境変数に基づいて適切なモデルを自動選択
"""
import os
from typing import Optional
from .base import BaseAI
from .transformers_ai import TransformersAI
from .openai_ai import OpenAIAI
from .anthropic_ai import AnthropicAI
from .google_ai import GoogleAI


def get_ai_model(model_type: Optional[str] = None, **kwargs) -> BaseAI:
    """
    環境変数または引数に基づいて適切なAIモデルを取得
    
    Args:
        model_type: モデルタイプ("transformers", "openai", "anthropic", "google")
                    Noneの場合は環境変数MODEL_TYPEから取得
        **kwargs: 各モデル固有の引数
                  - transformers: model_path
                  - openai: model_name, api_key
                  - anthropic: model_name, api_key
                  - google: model_name, api_key
    
    Returns:
        BaseAI: 選択されたモデルのインスタンス
        
    Examples:
        # 環境変数から自動選択
        ai = get_ai_model()
        
        # 明示的に指定
        ai = get_ai_model("transformers", model_path="Qwen/Qwen2.5-3B-Instruct")
        ai = get_ai_model("openai", model_name="gpt-4", api_key="sk-...")
    """
    # モデルタイプを決定
    if model_type is None:
        model_type = os.getenv("MODEL_TYPE", "transformers")
    
    model_type = model_type.lower()
    
    # モデルタイプに応じて適切なクラスを返す
    if model_type == "transformers":
        model_path = kwargs.get("model_path") or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
        return TransformersAI.get_model(model_path=model_path)
    
    elif model_type == "openai":
        model_name = kwargs.get("model_name") or os.getenv("OPENAI_MODEL", "gpt-4")
        api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        return OpenAIAI.get_model(model_name=model_name, api_key=api_key)
    
    elif model_type == "anthropic":
        model_name = kwargs.get("model_name") or os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022")
        api_key = kwargs.get("api_key") or os.getenv("ANTHROPIC_API_KEY")
        return AnthropicAI.get_model(model_name=model_name, api_key=api_key)
    
    elif model_type == "google":
        model_name = kwargs.get("model_name") or os.getenv("GOOGLE_MODEL", "gemini-pro")
        api_key = kwargs.get("api_key") or os.getenv("GOOGLE_API_KEY")
        return GoogleAI.get_model(model_name=model_name, api_key=api_key)
    
    else:
        raise ValueError(
            f"不明なモデルタイプ: {model_type}. "
            f"サポートされているタイプ: transformers, openai, anthropic, google"
        )


# 後方互換性のため、BaseAIもエクスポート
__all__ = [
    "BaseAI",
    "TransformersAI",
    "OpenAIAI",
    "AnthropicAI",
    "GoogleAI",
    "get_ai_model",
]