# playground/hackaprompt/completers.py
# (JerMa88, commit 2b50232, verified)
# Rewrote the program as I do not need the other two completors anyway;
# removed the OpenAI migrate function.
import logging
import os
from abc import ABC, abstractmethod
from gradio_client import Client
from openai import OpenAI
import tiktoken
from transformers import T5Tokenizer
# Authentication tokens from environment variables
HUB_TOKEN = os.getenv("HUB_TOKEN")  # NOTE(review): unused in this module — presumably consumed elsewhere; confirm
FLAN_HF_SPACE = "jerpint/i-like-flan"  # Gradio Space id; unused here after the rewrite — TODO confirm it is still referenced
# Module-level logger following the standard getLogger(__name__) convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TextCompletion(ABC):
    """Interface for prompt-completion backends.

    Concrete completers must report token counts and produce completions;
    they may optionally override the score multiplier.
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return how many tokens *prompt* consumes for this backend."""

    @abstractmethod
    def complete(self, prompt: str):
        """Return the backend's completion for *prompt*."""

    def get_score_multiplier(self) -> float:
        """Default scoring weight; subclasses may override."""
        return 1.0
class ChatGPTCompletor(TextCompletion):
    """Completer backed by the OpenAI chat-completions API.

    Sends the prompt as a single user message with temperature 0 and
    scores at a 2x multiplier relative to the base completer.
    """

    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model
        # Initialize the client inside the class using the provided key
        self.client = OpenAI(api_key=self.openai_api_key)
        # Lazily-built tiktoken encoding cache; see get_token_count.
        self._encoding = None

    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens *prompt* occupies for this model.

        The tiktoken encoding is resolved once and cached: the model is
        fixed per instance, so repeating encoding_for_model's registry
        lookup on every call is wasted work. Resolution stays lazy so
        construction cannot fail on an unknown model name.
        """
        if self._encoding is None:
            self._encoding = tiktoken.encoding_for_model(self.model)
        return len(self._encoding.encode(prompt))

    def complete(self, prompt: str):
        """Send *prompt* as one user message and return the reply text."""
        messages = [
            {"role": "user", "content": prompt},
        ]
        # temperature=0 keeps responses deterministic across attempts.
        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model,
            temperature=0
        )
        # Get the response text
        response_text = response.choices[0].message.content
        return response_text

    def get_score_multiplier(self) -> float:
        return 2.0
# Kept for compatibility if other parts of the app reference the dictionary
# Maps model identifier -> completer class (instances are built by get_completer).
completers = {
    "gpt-3.5-turbo": ChatGPTCompletor,
}
def get_completer(model: str, openai_api_key: str = ""):
    """Return a completer instance for *model*.

    Only ``gpt-3.5-turbo`` is wired up in this trimmed-down build; any
    other model name raises :class:`NotImplementedError`. The API key
    falls back to the ``OPENAI_API_KEY`` environment variable when the
    caller does not supply one.
    """
    logger.info(f"Loading completer for {model=}")
    # Guard clause: reject anything but the competition model up front.
    if model != "gpt-3.5-turbo":
        raise NotImplementedError(f"Model {model} is not supported. Please use gpt-3.5-turbo.")
    # Prefer the explicit key; otherwise pull it from the environment.
    key = openai_api_key if openai_api_key else os.getenv("OPENAI_API_KEY")
    return ChatGPTCompletor(openai_api_key=key, model=model)