Spaces:
Sleeping
Sleeping
| """ | |
| OpenAIAI - OpenAI API(ChatGPT)用アダプター | |
| """ | |
| from typing import List, Tuple, Optional, Dict, Any | |
| import os | |
| import math | |
| from .base import BaseAI | |
| class OpenAIAI(BaseAI): | |
| """ | |
| OpenAI API(ChatGPT)用アダプター | |
| 特徴: | |
| - API経由でモデルにアクセス | |
| - logprobsパラメータでトークン確率を取得可能(GPT-4以降) | |
| - user/assistantを明確に分離する形式を推奨(messages配列形式) | |
| """ | |
| _instances = {} # モデルごとのインスタンスをキャッシュ | |
| def __new__(cls, model_name: str = None, api_key: str = None): | |
| """シングルトンパターンでクライアントを常駐""" | |
| model = model_name or os.getenv("OPENAI_MODEL", "gpt-4") | |
| key = api_key or os.getenv("OPENAI_API_KEY") | |
| cache_key = f"{model}:{key}" | |
| if cache_key not in cls._instances: | |
| cls._instances[cache_key] = super().__new__(cls) | |
| cls._instances[cache_key]._initialized = False | |
| return cls._instances[cache_key] | |
| def __init__(self, model_name: str = None, api_key: str = None): | |
| """ | |
| 初期化 | |
| Args: | |
| model_name: モデル名(例: "gpt-4", "gpt-3.5-turbo") | |
| api_key: OpenAI APIキー | |
| """ | |
| if hasattr(self, '_initialized') and self._initialized: | |
| return | |
| self.model_name = model_name or os.getenv("OPENAI_MODEL", "gpt-4") | |
| self.api_key = api_key or os.getenv("OPENAI_API_KEY") | |
| self._initialized = True | |
| if not self.api_key: | |
| raise ValueError("OPENAI_API_KEYが設定されていません") | |
| # OpenAIクライアントを初期化 | |
| try: | |
| import openai | |
| self.client = openai.OpenAI(api_key=self.api_key) | |
| print(f"[OpenAIAI] 初期化完了: モデル={self.model_name}") | |
| except ImportError: | |
| raise ImportError("openaiパッケージがインストールされていません。pip install openai を実行してください") | |
| except Exception as e: | |
| raise ValueError(f"OpenAIクライアントの初期化に失敗しました: {e}") | |
| def get_model(cls, model_name: str = None, api_key: str = None) -> 'OpenAIAI': | |
| """モデルインスタンスを取得(常駐キャッシュから)""" | |
| return cls(model_name, api_key) | |
| def clear_cache(cls): | |
| """キャッシュをクリア(開発・テスト用)""" | |
| cls._instances.clear() | |
| def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]: | |
| """ | |
| 文章とkを引数に、{token, 確率}のリストを返す | |
| Args: | |
| text: 入力文章(プロンプト) | |
| k: 取得するトークン数 | |
| Returns: | |
| List[Tuple[str, float]]: (トークン, 確率)のリスト | |
| """ | |
| try: | |
| # OpenAI APIでは、messages形式でリクエストする必要がある | |
| # textが既にmessages形式かどうかを判定 | |
| if isinstance(text, str): | |
| # 文字列の場合は、userメッセージとして扱う | |
| messages = [{"role": "user", "content": text}] | |
| else: | |
| messages = text | |
| # API呼び出し(logprobs=Trueでトークン確率を取得) | |
| response = self.client.chat.completions.create( | |
| model=self.model_name, | |
| messages=messages, | |
| logprobs=True, | |
| top_logprobs=k, | |
| max_tokens=1, # 次のトークン1つだけを取得 | |
| ) | |
| # logprobsから確率を計算 | |
| items: List[Tuple[str, float]] = [] | |
| if response.choices and response.choices[0].logprobs: | |
| logprobs = response.choices[0].logprobs.content[0] if response.choices[0].logprobs.content else None | |
| if logprobs: | |
| # top_logprobsから確率を取得 | |
| for token_info in logprobs.top_logprobs: | |
| token = self._clean_text(token_info.token) | |
| if not token: | |
| continue | |
| # logprobを確率に変換 | |
| prob = math.exp(token_info.logprob) | |
| items.append((token, float(prob))) | |
| # 確率を正規化 | |
| if items: | |
| total_prob = sum(prob for _, prob in items) | |
| if total_prob > 0: | |
| normalized_items: List[Tuple[str, float]] = [] | |
| for token, prob in items: | |
| normalized_prob = prob / total_prob | |
| normalized_items.append((token, normalized_prob)) | |
| return normalized_items | |
| return items | |
| except Exception as e: | |
| print(f"[OpenAIAI] トークン確率取得エラー: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return [] | |
| def build_chat_prompt( | |
| self, | |
| user_content: str, | |
| system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください", | |
| assistant_content: Optional[str] = None | |
| ) -> List[Dict[str, str]]: | |
| """ | |
| チャットプロンプトを構築(OpenAI messages形式) | |
| 注意: OpenAIでは、user/assistantを明確に分離するmessages配列形式を推奨します。 | |
| このメソッドは文字列ではなく、messages配列を返します。 | |
| Args: | |
| user_content: ユーザーのメッセージ | |
| system_content: システムプロンプト | |
| assistant_content: アシスタントの既存応答(会話履歴用、オプション) | |
| Returns: | |
| List[Dict[str, str]]: OpenAI messages形式の配列 | |
| """ | |
| messages = [] | |
| # Systemメッセージ(最初に1回だけ) | |
| if system_content: | |
| messages.append({ | |
| "role": "system", | |
| "content": system_content | |
| }) | |
| # 会話履歴がある場合(assistant_contentが指定されている場合) | |
| if assistant_content: | |
| # 前回のuserメッセージとassistant応答を追加 | |
| # 注意: この実装では、assistant_contentのみを追加 | |
| # 実際の会話履歴管理は呼び出し側で行う必要があります | |
| messages.append({ | |
| "role": "assistant", | |
| "content": assistant_content | |
| }) | |
| # 現在のUserメッセージ | |
| messages.append({ | |
| "role": "user", | |
| "content": user_content | |
| }) | |
| return messages | |