File size: 6,155 Bytes
0447f30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
GoogleAI - Google API(Gemini)用アダプター
"""
from typing import List, Tuple, Optional
import os
import math
from .base import BaseAI


class GoogleAI(BaseAI):
    """
    Google API(Gemini)用アダプター
    
    特徴:
    - API経由でモデルにアクセス
    - logprobsパラメータでトークン確率を取得可能
    - user/assistantを分離しない方が良い場合もある(テキスト形式)
    - systemとuserを結合したテキスト形式を推奨
    """
    
    _instances = {}  # モデルごとのインスタンスをキャッシュ
    
    def __new__(cls, model_name: str = None, api_key: str = None):
        """シングルトンパターンでクライアントを常駐"""
        model = model_name or os.getenv("GOOGLE_MODEL", "gemini-pro")
        key = api_key or os.getenv("GOOGLE_API_KEY")
        
        cache_key = f"{model}:{key}"
        if cache_key not in cls._instances:
            cls._instances[cache_key] = super().__new__(cls)
            cls._instances[cache_key]._initialized = False
        
        return cls._instances[cache_key]
    
    def __init__(self, model_name: str = None, api_key: str = None):
        """
        初期化
        
        Args:
            model_name: モデル名(例: "gemini-pro")
            api_key: Google APIキー
        """
        if hasattr(self, '_initialized') and self._initialized:
            return
        
        self.model_name = model_name or os.getenv("GOOGLE_MODEL", "gemini-pro")
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        self._initialized = True
        
        if not self.api_key:
            raise ValueError("GOOGLE_API_KEYが設定されていません")
        
        # Google Generative AIクライアントを初期化
        try:
            import google.generativeai as genai
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(self.model_name)
            print(f"[GoogleAI] 初期化完了: モデル={self.model_name}")
        except ImportError:
            raise ImportError("google-generativeaiパッケージがインストールされていません。pip install google-generativeai を実行してください")
        except Exception as e:
            raise ValueError(f"Google Generative AIクライアントの初期化に失敗しました: {e}")
    
    @classmethod
    def get_model(cls, model_name: str = None, api_key: str = None) -> 'GoogleAI':
        """モデルインスタンスを取得(常駐キャッシュから)"""
        return cls(model_name, api_key)
    
    @classmethod
    def clear_cache(cls):
        """キャッシュをクリア(開発・テスト用)"""
        cls._instances.clear()
    
    def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
        """
        文章とkを引数に、{token, 確率}のリストを返す
        
        Args:
            text: 入力文章(プロンプト)
            k: 取得するトークン数
            
        Returns:
            List[Tuple[str, float]]: (トークン, 確率)のリスト
        """
        try:
            # Gemini APIでトークン確率を取得
            # 注意: Gemini APIのlogprobs取得方法は他のAPIと異なる可能性があります
            response = self.model.generate_content(
                text,
                generation_config={
                    "max_output_tokens": 1,  # 次のトークン1つだけを取得
                    "temperature": 0.0,  # 確定的な結果を得るため
                }
            )
            
            # 注意: Gemini APIのlogprobs取得方法は公式ドキュメントを確認してください
            # ここでは仮の実装です
            items: List[Tuple[str, float]] = []
            
            # 実際の実装では、responseからlogprobsを取得する必要があります
            # 現在のGemini APIでは、logprobsの直接取得が難しい可能性があります
            # 代替案: 複数回のサンプリングで確率を推定
            
            print("[GoogleAI] 警告: Gemini APIのlogprobs取得は実装が不完全です")
            return items
            
        except Exception as e:
            print(f"[GoogleAI] トークン確率取得エラー: {e}")
            import traceback
            traceback.print_exc()
            return []
    
    def build_chat_prompt(
        self, 
        user_content: str, 
        system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください",
        assistant_content: Optional[str] = None
    ) -> str:
        """
        チャットプロンプトを構築(Gemini形式)
        
        注意: Geminiでは、user/assistantを分離しない方が良い場合もあります。
        systemとuserを結合したテキスト形式を推奨します。
        
        Args:
            user_content: ユーザーのメッセージ
            system_content: システムプロンプト
            assistant_content: アシスタントの既存応答(会話履歴用、オプション)
            
        Returns:
            str: Gemini形式のプロンプト(テキスト)
        """
        prompt_parts = []
        
        # Systemメッセージ(最初に1回だけ)
        if system_content:
            prompt_parts.append(f"システム: {system_content}")
            prompt_parts.append("")
        
        # 会話履歴がある場合(assistant_contentが指定されている場合)
        if assistant_content:
            prompt_parts.append(f"ユーザー: {user_content}")
            prompt_parts.append(f"アシスタント: {assistant_content}")
            prompt_parts.append("")
        
        # 現在のUserメッセージ
        prompt_parts.append(f"ユーザー: {user_content}")
        prompt_parts.append("アシスタント:")
        
        prompt_text = "\n".join(prompt_parts)
        return prompt_text