Spaces:
Sleeping
Sleeping
| """ | |
| GoogleAI - Google API(Gemini)用アダプター | |
| """ | |
| from typing import List, Tuple, Optional | |
| import os | |
| import math | |
| from .base import BaseAI | |
| class GoogleAI(BaseAI): | |
| """ | |
| Google API(Gemini)用アダプター | |
| 特徴: | |
| - API経由でモデルにアクセス | |
| - logprobsパラメータでトークン確率を取得可能 | |
| - user/assistantを分離しない方が良い場合もある(テキスト形式) | |
| - systemとuserを結合したテキスト形式を推奨 | |
| """ | |
| _instances = {} # モデルごとのインスタンスをキャッシュ | |
| def __new__(cls, model_name: str = None, api_key: str = None): | |
| """シングルトンパターンでクライアントを常駐""" | |
| model = model_name or os.getenv("GOOGLE_MODEL", "gemini-pro") | |
| key = api_key or os.getenv("GOOGLE_API_KEY") | |
| cache_key = f"{model}:{key}" | |
| if cache_key not in cls._instances: | |
| cls._instances[cache_key] = super().__new__(cls) | |
| cls._instances[cache_key]._initialized = False | |
| return cls._instances[cache_key] | |
| def __init__(self, model_name: str = None, api_key: str = None): | |
| """ | |
| 初期化 | |
| Args: | |
| model_name: モデル名(例: "gemini-pro") | |
| api_key: Google APIキー | |
| """ | |
| if hasattr(self, '_initialized') and self._initialized: | |
| return | |
| self.model_name = model_name or os.getenv("GOOGLE_MODEL", "gemini-pro") | |
| self.api_key = api_key or os.getenv("GOOGLE_API_KEY") | |
| self._initialized = True | |
| if not self.api_key: | |
| raise ValueError("GOOGLE_API_KEYが設定されていません") | |
| # Google Generative AIクライアントを初期化 | |
| try: | |
| import google.generativeai as genai | |
| genai.configure(api_key=self.api_key) | |
| self.model = genai.GenerativeModel(self.model_name) | |
| print(f"[GoogleAI] 初期化完了: モデル={self.model_name}") | |
| except ImportError: | |
| raise ImportError("google-generativeaiパッケージがインストールされていません。pip install google-generativeai を実行してください") | |
| except Exception as e: | |
| raise ValueError(f"Google Generative AIクライアントの初期化に失敗しました: {e}") | |
| def get_model(cls, model_name: str = None, api_key: str = None) -> 'GoogleAI': | |
| """モデルインスタンスを取得(常駐キャッシュから)""" | |
| return cls(model_name, api_key) | |
| def clear_cache(cls): | |
| """キャッシュをクリア(開発・テスト用)""" | |
| cls._instances.clear() | |
| def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]: | |
| """ | |
| 文章とkを引数に、{token, 確率}のリストを返す | |
| Args: | |
| text: 入力文章(プロンプト) | |
| k: 取得するトークン数 | |
| Returns: | |
| List[Tuple[str, float]]: (トークン, 確率)のリスト | |
| """ | |
| try: | |
| # Gemini APIでトークン確率を取得 | |
| # 注意: Gemini APIのlogprobs取得方法は他のAPIと異なる可能性があります | |
| response = self.model.generate_content( | |
| text, | |
| generation_config={ | |
| "max_output_tokens": 1, # 次のトークン1つだけを取得 | |
| "temperature": 0.0, # 確定的な結果を得るため | |
| } | |
| ) | |
| # 注意: Gemini APIのlogprobs取得方法は公式ドキュメントを確認してください | |
| # ここでは仮の実装です | |
| items: List[Tuple[str, float]] = [] | |
| # 実際の実装では、responseからlogprobsを取得する必要があります | |
| # 現在のGemini APIでは、logprobsの直接取得が難しい可能性があります | |
| # 代替案: 複数回のサンプリングで確率を推定 | |
| print("[GoogleAI] 警告: Gemini APIのlogprobs取得は実装が不完全です") | |
| return items | |
| except Exception as e: | |
| print(f"[GoogleAI] トークン確率取得エラー: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return [] | |
| def build_chat_prompt( | |
| self, | |
| user_content: str, | |
| system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください", | |
| assistant_content: Optional[str] = None | |
| ) -> str: | |
| """ | |
| チャットプロンプトを構築(Gemini形式) | |
| 注意: Geminiでは、user/assistantを分離しない方が良い場合もあります。 | |
| systemとuserを結合したテキスト形式を推奨します。 | |
| Args: | |
| user_content: ユーザーのメッセージ | |
| system_content: システムプロンプト | |
| assistant_content: アシスタントの既存応答(会話履歴用、オプション) | |
| Returns: | |
| str: Gemini形式のプロンプト(テキスト) | |
| """ | |
| prompt_parts = [] | |
| # Systemメッセージ(最初に1回だけ) | |
| if system_content: | |
| prompt_parts.append(f"システム: {system_content}") | |
| prompt_parts.append("") | |
| # 会話履歴がある場合(assistant_contentが指定されている場合) | |
| if assistant_content: | |
| prompt_parts.append(f"ユーザー: {user_content}") | |
| prompt_parts.append(f"アシスタント: {assistant_content}") | |
| prompt_parts.append("") | |
| # 現在のUserメッセージ | |
| prompt_parts.append(f"ユーザー: {user_content}") | |
| prompt_parts.append("アシスタント:") | |
| prompt_text = "\n".join(prompt_parts) | |
| return prompt_text | |