import asyncio
import os
from abc import ABC, abstractmethod
from typing import Optional

import aiohttp
import openai
import requests
from datasets import load_dataset


def download_dataset_hf(name: str, split: Optional[str] = None):
    """Load dataset ``name`` via ``datasets.load_dataset``.

    Args:
        name: Dataset identifier forwarded to ``load_dataset``.
        split: Optional split name forwarded to ``load_dataset``; when ``None``
            the library's default behaviour applies.

    Returns:
        Whatever ``load_dataset`` returns for these arguments.
    """
    return load_dataset(name, split=split)


class APIModelBase(ABC):
    """Common interface for chat-completion API backends.

    Subclasses turn a single user text into a provider-specific message list
    (``create_prompt``), wrap it into a request body (``get_payload``) and
    answer a batch of texts either blocking (``respond``) or as a coroutine
    (``respond_async``).

    ``args`` is a config object; this base class reads ``temperature``,
    ``top_p`` and ``frequency_penalty`` from it (subclasses read more).
    """

    def __init__(self, args):
        self.args = args
        # Computed once at construction; subclasses override
        # get_generation_params to change which sampling params are sent.
        self.generation_params = self.get_generation_params()

    def get_generation_params(self):
        """Return the sampling parameters taken from ``self.args``."""
        return {
            "temperature": self.args.temperature,
            "top_p": self.args.top_p,
            "frequency_penalty": self.args.frequency_penalty,
        }

    @abstractmethod
    def get_payload(self, messages):
        """Build the provider request body for ``messages``."""

    @abstractmethod
    def create_prompt(self, text):
        """Convert a single user ``text`` into the provider's message list."""

    @abstractmethod
    def respond(self, texts):
        """Return one answer per element of ``texts`` (blocking)."""

    @abstractmethod
    def respond_async(self, texts):
        """Coroutine returning one answer per element of ``texts``."""


class OpenaiModel(APIModelBase):
    """OpenAI-compatible chat-completions backend.

    The API key is read from the ``OPENAI_API_KEY`` environment variable and
    ``args.hostname`` is used as the client ``base_url``. Sampling parameters
    come from the inherited ``get_generation_params``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.system_prompt = self.args.system_prompt

    def get_payload(self, messages):
        """Return keyword arguments for ``client.chat.completions.create``."""
        return {
            "model": self.args.model_openai,
            "messages": messages,
            "n": 1,
            **self.generation_params,
            "max_tokens": self.args.max_gen_length,
            "timeout": 10000,
            # Fixed seed so repeated runs request the same sampling.
            "seed": 0,
        }

    def create_prompt(self, text):
        """Return ``[system?, user]`` messages in OpenAI dict format."""
        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})
        return messages

    def respond(self, texts):
        """Answer each text sequentially; returns a list aligned with ``texts``."""
        return [self.generate_answers(text) for text in texts]

    async def respond_async(self, texts):
        """Answer all texts concurrently; returns a list aligned with ``texts``."""
        tasks = [
            asyncio.create_task(self.generate_answers_async(text)) for text in texts
        ]
        return await asyncio.gather(*tasks)

    def generate_answers(self, text):
        """Send one blocking completion request and return the message content."""
        with openai.OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"), base_url=self.args.hostname
        ) as client:
            payload = self.get_payload(self.create_prompt(text))
            completion = client.chat.completions.create(**payload)
            return completion.choices[0].message.content

    async def generate_answers_async(self, text):
        """Send one async completion request and return the message content."""
        async with openai.AsyncOpenAI(
            api_key=os.getenv("OPENAI_API_KEY"), base_url=self.args.hostname
        ) as client:
            payload = self.get_payload(self.create_prompt(text))
            completion = await client.chat.completions.create(**payload)
            return completion.choices[0].message.content


class GigachatModel(APIModelBase):
    """GigaChat backend built on the ``gigachat`` SDK (imported lazily).

    Credentials are read from the ``GIGACHAT_CREDENTIALS`` environment
    variable; ``args.hostname`` is the SDK ``base_url``.
    """

    def __init__(self, args):
        super().__init__(args)
        # Kept as a public attribute for backward compatibility; the request
        # methods below open a fresh client per call via get_client().
        self.client = self.get_client()
        self.system_prompt = None

    def get_generation_params(self):
        """Return base sampling params only if ``args.overwrite_competitor_args``.

        Otherwise an empty dict is returned, so no sampling params are sent.
        """
        if self.args.overwrite_competitor_args:
            return super().get_generation_params()
        return {}

    def create_prompt(self, text):
        """Return ``[system?, user]`` as ``gigachat.models.Messages`` objects."""
        from gigachat.models import Messages, MessagesRole

        messages = []
        if self.system_prompt is not None:
            messages.append(
                Messages(role=MessagesRole.SYSTEM, content=self.system_prompt)
            )
        messages.append(Messages(role=MessagesRole.USER, content=text))
        return messages

    def get_payload(self, messages):
        """Return the chat payload dict passed to ``client.chat``/``achat``."""
        return {
            "model": self.args.model_openai,
            "messages": messages,
            "max_tokens": self.args.max_gen_length,
            **self.generation_params,
        }

    def respond(self, texts):
        """Answer each text sequentially; returns a list aligned with ``texts``."""
        result = []
        for text in texts:
            with self.get_client() as client:
                payload = self.get_payload(self.create_prompt(text))
                result.append(client.chat(payload).choices[0].message.content)
        return result

    async def respond_async(self, texts):
        """Answer all texts concurrently; returns a list aligned with ``texts``."""
        tasks = [asyncio.create_task(self.generate_answers(text)) for text in texts]
        return await asyncio.gather(*tasks)

    async def generate_answers(self, text):
        """Send one async chat request and return the message content."""
        async with self.get_client() as client:
            payload = self.get_payload(self.create_prompt(text))
            completion = await client.achat(payload)
            return completion.choices[0].message.content

    def get_client(self):
        """Construct a new ``GigaChat`` client from ``args`` and the environment."""
        from gigachat import GigaChat

        return GigaChat(
            base_url=self.args.hostname,
            credentials=os.getenv("GIGACHAT_CREDENTIALS"),
            scope="GIGACHAT_API_PERS",
            # NOTE(review): TLS certificate verification is disabled here.
            verify_ssl_certs=False,
            timeout=100,
        )


class YandexGPTModel(APIModelBase):
    """YandexGPT backend over plain HTTP.

    Credentials come from the ``YANDEX_FOLDER_ID`` and ``YANDEX_KEY``
    environment variables. ``respond`` reads ``result.alternatives`` directly
    from the POST to ``args.hostname``; ``respond_async`` instead expects that
    POST to return an operation ``id`` and polls the operations endpoint.

    NOTE(review): both paths POST to the same ``args.hostname`` but parse
    different response shapes — confirm which endpoint is configured.
    """

    def __init__(self, args):
        super().__init__(args)
        self.folder_id = os.getenv("YANDEX_FOLDER_ID")
        self.token = os.getenv("YANDEX_KEY")
        self.model_version = self.args.model_version or "rc"
        self.system_prompt = None

    def get_generation_params(self):
        """Return base sampling params only if ``args.overwrite_competitor_args``.

        Otherwise an empty dict is returned, so no sampling params are sent.
        """
        if self.args.overwrite_competitor_args:
            return super().get_generation_params()
        return {}

    def get_payload(self, messages):
        """Return the JSON request body with a ``gpt://folder/model/version`` URI."""
        model_uri = (
            f"gpt://{self.folder_id}/{self.args.model_openai}/{self.model_version}"
        )
        return {
            "modelUri": model_uri,
            "completionOptions": {
                "stream": False,
                "max_tokens": self.args.max_gen_length,
                **self.generation_params,
            },
            "messages": messages,
        }

    def create_prompt(self, text):
        """Return ``[system?, user]`` messages using Yandex's ``text`` key."""
        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "text": self.system_prompt})
        messages.append({"role": "user", "text": text})
        return messages

    def _headers(self):
        """Return the auth headers shared by every Yandex request."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.token}",
            "x-folder-id": self.folder_id,
        }

    @staticmethod
    def _first_alternative(alternatives):
        """Return the first alternative's text, or ``None`` if there are none."""
        if alternatives:
            return alternatives[0]["message"]["text"]
        return None

    def respond(self, texts):
        """Answer each text sequentially.

        Returns a list aligned index-by-index with ``texts``; an entry is
        ``None`` when the API returned no alternatives for that text.

        Raises:
            requests.HTTPError: on a non-2xx response.
            Exception: when the response body contains an ``error`` field.
        """
        headers = self._headers()
        result = []
        for text in texts:
            payload = self.get_payload(self.create_prompt(text))
            # NOTE(review): no request timeout is set; a stalled server blocks here.
            response = requests.post(self.args.hostname, headers=headers, json=payload)
            response.raise_for_status()
            answer = response.json()
            if "error" in answer:
                raise Exception(f"Operation failed: {answer['error']['message']}")
            # Always append (possibly None) so result stays aligned with texts.
            result.append(self._first_alternative(answer["result"]["alternatives"]))
        return result

    async def respond_async(self, texts):
        """Answer all texts concurrently; returns a list aligned with ``texts``."""
        tasks = [asyncio.create_task(self.generate_answers(text)) for text in texts]
        return await asyncio.gather(*tasks)

    async def generate_answers(self, text):
        """Submit one request, poll its operation until done, return the text.

        Returns the first alternative's text as a string (matching ``respond``
        and the other backends), or ``None`` when there are no alternatives.

        Raises:
            aiohttp.ClientResponseError: on a non-2xx response.
            Exception: when the finished operation contains an ``error`` field.
        """
        operation_url = "https://operation.api.cloud.yandex.net/operations/"
        headers = self._headers()
        async with aiohttp.ClientSession() as session:
            payload = self.get_payload(self.create_prompt(text))
            async with session.post(
                self.args.hostname, headers=headers, json=payload
            ) as response:
                response.raise_for_status()
                operation = await response.json()
            operation_id = operation["id"]

            # NOTE(review): polls without an upper bound until "done" is true.
            while True:
                async with session.get(
                    f"{operation_url}{operation_id}", headers=headers
                ) as response:
                    response.raise_for_status()
                    answer = await response.json()
                if answer["done"]:
                    break
                await asyncio.sleep(1)  # Wait for 1 second before polling again

        if "error" in answer:
            raise Exception(f"Operation failed: {answer['error']['message']}")
        return self._first_alternative(answer["response"]["alternatives"])