Titova Ksenia
add sample data
21865d4
import asyncio
import os
import requests
from abc import ABC, abstractmethod
import openai
import aiohttp
from datasets import load_dataset
def download_dataset_hf(name: str, split: str = None):
    """Load a dataset through ``datasets.load_dataset``.

    Args:
        name: Dataset identifier, forwarded unchanged to ``load_dataset``.
        split: Optional split name; ``None`` is forwarded unchanged.

    Returns:
        Whatever ``load_dataset`` returns for these arguments.
    """
    return load_dataset(name, split=split)
class APIModelBase(ABC):
    """Abstract base for chat-completion API wrappers.

    Subclasses turn raw user texts into provider-specific prompts and request
    payloads. They expose a synchronous batch interface (``respond``) and an
    asynchronous one (``respond_async``, a coroutine).
    """

    def __init__(self, args):
        """Store the run configuration and precompute sampling parameters.

        Args:
            args: Configuration object with attribute access (presumably parsed
                CLI arguments; TODO confirm). The base implementation of
                ``get_generation_params`` reads ``temperature``, ``top_p`` and
                ``frequency_penalty``. Subclasses read further attributes.
        """
        self.args = args
        # Computed once through an overridable hook, so subclasses can decide
        # which sampling parameters are forwarded to their API.
        self.generation_params = self.get_generation_params()

    def get_generation_params(self):
        """Return the sampling parameters merged into each request payload."""
        return {
            "temperature": self.args.temperature,
            "top_p": self.args.top_p,
            "frequency_penalty": self.args.frequency_penalty,
        }

    @abstractmethod
    def get_payload(self, messages):
        """Build the provider request payload for a list of prompt messages."""

    @abstractmethod
    def create_prompt(self, text):
        """Convert a single user text into the provider's message list."""

    @abstractmethod
    def respond(self, texts):
        """Synchronously answer the texts; expected to return one answer per text."""

    @abstractmethod
    async def respond_async(self, texts):
        """Asynchronously answer the texts; expected to return one answer per text.

        Declared as a coroutine because every concrete implementation is
        awaited by callers.
        """
class OpenaiModel(APIModelBase):
    """Chat-completion wrapper for OpenAI-compatible endpoints.

    Requests go to ``args.hostname``. The API key is read from the
    ``OPENAI_API_KEY`` environment variable. Sampling parameters come from
    the inherited ``get_generation_params``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.system_prompt = self.args.system_prompt

    def get_payload(self, messages):
        """Return keyword arguments for ``client.chat.completions.create``."""
        return {
            "model": self.args.model_openai,
            "messages": messages,
            "n": 1,
            **self.generation_params,
            "max_tokens": self.args.max_gen_length,
            "timeout": 10000,
            # Fixed seed, presumably for run-to-run reproducibility.
            "seed": 0,
        }

    def create_prompt(self, text):
        """Build the message list: optional system prompt, then ``text`` as user."""
        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})
        return messages

    def _client_kwargs(self):
        """Return the connection settings shared by the sync and async clients."""
        return {
            "api_key": os.getenv("OPENAI_API_KEY"),
            "base_url": self.args.hostname,
        }

    def respond(self, texts):
        """Answer each text sequentially; replies are returned in input order.

        One ``openai.OpenAI`` client is reused for the whole batch instead of
        opening a new one for every text.
        """
        texts = list(texts)
        if not texts:
            # Avoid constructing a client at all for an empty batch.
            return []
        with openai.OpenAI(**self._client_kwargs()) as client:
            return [self.generate_answers(text, client=client) for text in texts]

    async def respond_async(self, texts):
        """Answer all texts concurrently; replies are returned in input order.

        All requests share one ``openai.AsyncOpenAI`` client (one connection
        pool) rather than each task creating its own.
        """
        texts = list(texts)
        if not texts:
            return []
        async with openai.AsyncOpenAI(**self._client_kwargs()) as client:
            result = await asyncio.gather(
                *(self.generate_answers_async(text, client=client) for text in texts)
            )
        # asyncio.gather returns results in the order the awaitables were given.
        assert len(result) == len(texts)
        return result

    def generate_answers(self, text, client=None):
        """Return the model's reply to ``text``.

        Args:
            text: User message.
            client: Optional open ``openai.OpenAI`` client to reuse. When it is
                omitted, a temporary client is created and closed.
        """
        if client is None:
            with openai.OpenAI(**self._client_kwargs()) as own_client:
                return self.generate_answers(text, client=own_client)
        payload = self.get_payload(self.create_prompt(text))
        completion = client.chat.completions.create(**payload)
        return completion.choices[0].message.content

    async def generate_answers_async(self, text, client=None):
        """Asynchronously return the model's reply to ``text``.

        Args:
            text: User message.
            client: Optional open ``openai.AsyncOpenAI`` client to reuse. When
                it is omitted, a temporary client is created and closed.
        """
        if client is None:
            async with openai.AsyncOpenAI(**self._client_kwargs()) as own_client:
                return await self.generate_answers_async(text, client=own_client)
        payload = self.get_payload(self.create_prompt(text))
        completion = await client.chat.completions.create(**payload)
        return completion.choices[0].message.content
class GigachatModel(APIModelBase):
    """Chat wrapper for the GigaChat API via the ``gigachat`` SDK.

    Requests go to ``args.hostname``. Credentials are read from the
    ``GIGACHAT_CREDENTIALS`` environment variable.
    """

    def __init__(self, args):
        super().__init__(args)
        # Built through ``get_client`` so the connection settings live in one
        # place. This class's own methods open their own clients per call; the
        # attribute is kept for any external code that reads ``self.client``.
        self.client = self.get_client()
        # No system prompt is sent for this provider.
        self.system_prompt = None

    def get_generation_params(self):
        """Return sampling parameters only if ``args.overwrite_competitor_args`` is set.

        Otherwise an empty dict is returned, so the payload carries no
        sampling parameters.
        """
        # NOTE(review): confirm GigaChat honours ``frequency_penalty``; its chat
        # schema may name the equivalent field differently.
        if self.args.overwrite_competitor_args:
            return {
                "temperature": self.args.temperature,
                "top_p": self.args.top_p,
                "frequency_penalty": self.args.frequency_penalty,
            }
        return {}

    def create_prompt(self, text):
        """Build the ``Messages`` list: optional system prompt, then ``text`` as user."""
        from gigachat.models import Messages, MessagesRole

        messages = []
        if self.system_prompt is not None:
            messages.append(Messages(role=MessagesRole.SYSTEM, content=self.system_prompt))
        messages.append(Messages(role=MessagesRole.USER, content=text))
        return messages

    def get_payload(self, messages):
        """Return the request dict passed to ``GigaChat.chat`` / ``achat``."""
        return {
            "model": self.args.model_openai,
            "messages": messages,
            "max_tokens": self.args.max_gen_length,
            **self.generation_params,
        }

    def respond(self, texts):
        """Answer each text sequentially; replies are returned in input order.

        One client is opened for the whole batch instead of a fresh client
        (and fresh authentication) for every text.
        """
        texts = list(texts)
        if not texts:
            return []
        result = []
        with self.get_client() as client:
            for text in texts:
                payload = self.get_payload(self.create_prompt(text))
                result.append(client.chat(payload).choices[0].message.content)
        assert len(result) == len(texts)
        return result

    async def respond_async(self, texts):
        """Answer all texts concurrently; replies are returned in input order."""
        tasks = [asyncio.create_task(self.generate_answers(text)) for text in texts]
        result = await asyncio.gather(*tasks)
        assert len(result) == len(texts)
        return result

    async def generate_answers(self, text):
        """Return GigaChat's reply to ``text`` using a short-lived async client."""
        # NOTE(review): each call opens its own client; sharing one across
        # concurrent tasks may be cheaper, but first verify how the SDK handles
        # concurrent token refresh.
        async with self.get_client() as client:
            payload = self.get_payload(self.create_prompt(text))
            completion = await client.achat(payload)
            return completion.choices[0].message.content

    def get_client(self):
        """Create a new ``GigaChat`` client for ``args.hostname``."""
        from gigachat import GigaChat

        # TLS certificate verification is explicitly disabled (verify_ssl_certs=False).
        return GigaChat(
            base_url=self.args.hostname,
            credentials=os.getenv("GIGACHAT_CREDENTIALS"),
            scope="GIGACHAT_API_PERS",
            verify_ssl_certs=False,
            timeout=100,
        )
class YandexGPTModel(APIModelBase):
    """Wrapper for the YandexGPT text-generation API over raw HTTP.

    Requests go to ``args.hostname``. The folder id and bearer token are read
    from the ``YANDEX_FOLDER_ID`` and ``YANDEX_KEY`` environment variables.
    """

    # Endpoint polled for the status/result of an asynchronous operation.
    _OPERATION_URL = "https://operation.api.cloud.yandex.net/operations/"

    def __init__(self, args):
        super().__init__(args)
        self.folder_id = os.getenv("YANDEX_FOLDER_ID")
        self.token = os.getenv("YANDEX_KEY")
        # Falls back to the "rc" model version when none is configured.
        self.model_version = self.args.model_version or "rc"
        self.system_prompt = None

    def get_generation_params(self):
        """Return sampling parameters only if ``args.overwrite_competitor_args`` is set."""
        # NOTE(review): confirm the Yandex ``completionOptions`` schema accepts
        # ``top_p`` / ``frequency_penalty``.
        if self.args.overwrite_competitor_args:
            return {
                "temperature": self.args.temperature,
                "top_p": self.args.top_p,
                "frequency_penalty": self.args.frequency_penalty,
            }
        return {}

    def get_payload(self, messages):
        """Return the JSON request body for a completion call."""
        model_uri = (
            f"gpt://{self.folder_id}/{self.args.model_openai}/{self.model_version}"
        )
        return {
            "modelUri": model_uri,
            "completionOptions": {
                "stream": False,
                "max_tokens": self.args.max_gen_length,
                **self.generation_params,
            },
            "messages": messages,
        }

    def create_prompt(self, text):
        """Build the message list: optional system prompt, then ``text`` as user."""
        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "text": self.system_prompt})
        messages.append({"role": "user", "text": text})
        return messages

    def _headers(self):
        """Return the HTTP headers sent with every request to the Yandex endpoints."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.token}",
            "x-folder-id": self.folder_id,
        }

    def respond(self, texts):
        """Answer each text with a synchronous POST to ``args.hostname``.

        Returns:
            One entry per input text, in order. An entry is ``None`` when the
            response carries no alternatives, so answers stay aligned with
            their inputs.

        Raises:
            requests.HTTPError: on a non-2xx HTTP status.
            Exception: if the response body contains an ``error`` object.
        """
        headers = self._headers()
        result = []
        for text in texts:
            payload = self.get_payload(self.create_prompt(text))
            response = requests.post(self.args.hostname, headers=headers, json=payload)
            response.raise_for_status()
            answer = response.json()
            if "error" in answer:
                raise Exception(f"Operation failed: {answer['error']['message']}")
            alternatives = answer["result"]["alternatives"]
            result.append(alternatives[0]["message"]["text"] if alternatives else None)
        return result

    async def respond_async(self, texts):
        """Answer all texts concurrently; results are returned in input order."""
        tasks = [asyncio.create_task(self.generate_answers(text)) for text in texts]
        result = await asyncio.gather(*tasks)
        assert len(result) == len(texts)
        return result

    async def generate_answers(self, text, poll_interval=1):
        """Submit ``text`` as an operation and poll until it is done.

        Args:
            text: User message.
            poll_interval: Seconds to sleep between operation-status polls
                (default 1).

        Returns:
            A one-element list with the first alternative's text, or ``None``
            when the finished operation has no alternatives.

        Raises:
            aiohttp.ClientResponseError: on a non-2xx HTTP status.
            Exception: if the finished operation reports an ``error``.
        """
        # NOTE(review): unlike ``respond`` (plain strings), this returns
        # ``[text]``; kept unchanged for existing callers. Confirm it is intended.
        # NOTE(review): ``args.hostname`` must return an operation (``id``) here
        # but a finished ``result`` in ``respond``. Confirm both paths are
        # meant to share one URL.
        headers = self._headers()
        async with aiohttp.ClientSession() as session:
            payload = self.get_payload(self.create_prompt(text))
            async with session.post(
                self.args.hostname, headers=headers, json=payload
            ) as response:
                response.raise_for_status()
                operation = await response.json()
            status_url = f"{self._OPERATION_URL}{operation['id']}"
            # Polls with no upper bound until the operation reports done.
            while True:
                async with session.get(status_url, headers=headers) as response:
                    response.raise_for_status()
                    answer = await response.json()
                if answer["done"]:
                    break
                await asyncio.sleep(poll_interval)
        if "error" in answer:
            raise Exception(f"Operation failed: {answer['error']['message']}")
        alternatives = answer["response"]["alternatives"]
        if alternatives:
            return [alternatives[0]["message"]["text"]]
        return None