# Spaces: Running
| import logging | |
| import os | |
| from abc import ABC, abstractmethod | |
| from gradio_client import Client | |
| from openai import OpenAI | |
| import tiktoken | |
| from transformers import T5Tokenizer | |
# Module-wide logging: configured once at import time, INFO level and above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authentication token read from the environment (None when the variable is unset).
HUB_TOKEN = os.environ.get("HUB_TOKEN")
# NOTE(review): presumably the Hugging Face Space id of the Flan backend;
# it is not referenced by the code visible in this module - confirm before removing.
FLAN_HF_SPACE = "jerpint/i-like-flan"
class TextCompletion(ABC):
    """Abstract interface that every text-completion backend implements.

    Subclasses must provide ``get_token_count`` and ``complete``; both are
    declared abstract so that an incomplete subclass (or this base class
    itself) fails at instantiation instead of silently returning ``None``.
    ``get_score_multiplier`` has a default that subclasses may override.
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens ``prompt`` occupies for this backend."""
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Return this backend's completion for ``prompt``."""
        ...

    def get_score_multiplier(self) -> float:
        """Return this backend's score multiplier; the base value is 1.0."""
        return 1.0
class ChatGPTCompletor(TextCompletion):
    """Completer backed by the OpenAI chat-completions API.

    Args:
        openai_api_key: API key passed to the ``OpenAI`` client.
        model: Chat model name; used both for the API call and to pick
            the tiktoken encoding when counting tokens.
    """

    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model
        # The client is built per instance from the key that was passed in.
        self.client = OpenAI(api_key=self.openai_api_key)

    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens tiktoken produces for ``prompt``.

        The encoding is looked up from ``self.model``. When tiktoken has no
        mapping for that model name it raises ``KeyError``; in that case the
        count falls back to the ``cl100k_base`` encoding and is therefore
        only an approximation for that model.
        """
        try:
            encoding = tiktoken.encoding_for_model(self.model)
        except KeyError:
            logger.warning(
                "No tiktoken encoding registered for model %r; "
                "falling back to cl100k_base",
                self.model,
            )
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(prompt))

    def complete(self, prompt: str):
        """Send ``prompt`` as a single user message and return the reply text.

        The request uses ``temperature=0``. The return value is the
        ``content`` of the first choice's message.
        NOTE(review): the API types ``content`` as optional, so this may be
        ``None`` - confirm callers handle that.
        """
        messages = [
            {"role": "user", "content": prompt},
        ]

        # Call the API through this instance's client.
        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model,
            temperature=0,
        )

        # Only the first choice is used.
        return response.choices[0].message.content

    def get_score_multiplier(self) -> float:
        """Return this completer's score multiplier (2.0, overriding the base 1.0)."""
        return 2.0
# Registry mapping a model name to its completer class. Kept for
# compatibility in case other parts of the app reference this dictionary.
completers = {"gpt-3.5-turbo": ChatGPTCompletor}
def get_completer(model: str, openai_api_key: str = ""):
    """Build the completer for ``model``.

    Only ``"gpt-3.5-turbo"`` is accepted; it yields a ``ChatGPTCompletor``.

    Args:
        model: Name of the model to load.
        openai_api_key: Key for the OpenAI client. When empty, the
            ``OPENAI_API_KEY`` environment variable is used instead
            (which may itself be unset, giving ``None``).

    Returns:
        A ``ChatGPTCompletor`` configured with the resolved key and ``model``.

    Raises:
        NotImplementedError: If ``model`` is anything other than
            ``"gpt-3.5-turbo"``.
    """
    logger.info(f"Loading completer for {model=}")

    # Reject unsupported models up front (specific competition setup).
    if model != "gpt-3.5-turbo":
        raise NotImplementedError(f"Model {model} is not supported. Please use gpt-3.5-turbo.")

    # An explicitly supplied key wins; otherwise fall back to the environment.
    resolved_key = openai_api_key if openai_api_key else os.getenv("OPENAI_API_KEY")
    return ChatGPTCompletor(openai_api_key=resolved_key, model=model)