# API client wrappers (OpenAI-compatible, GigaChat, YandexGPT) for batch answer generation.
import asyncio
import os
from abc import ABC, abstractmethod

import aiohttp
import openai
import requests
from datasets import load_dataset
def download_dataset_hf(name: str, split: str = None):
    """Download a dataset from the Hugging Face Hub.

    Args:
        name: Hub dataset identifier.
        split: Optional split name, passed straight through to ``load_dataset``.

    Returns:
        Whatever ``datasets.load_dataset`` returns for these arguments.
    """
    return load_dataset(name, split=split)
class APIModelBase(ABC):
    """Abstract base for chat-completion API backends.

    Subclasses implement prompt construction, payload assembly, and the
    sync/async request paths for one concrete provider.
    """

    def __init__(self, args):
        # args: namespace carrying sampling settings (temperature, top_p,
        # frequency_penalty) plus provider-specific fields read by subclasses.
        self.args = args
        self.generation_params = self.get_generation_params()

    def get_generation_params(self):
        """Return the sampling parameters sent with every request.

        Subclasses may override (e.g. to honor ``args.overwrite_competitor_args``).
        """
        return {
            "temperature": self.args.temperature,
            "top_p": self.args.top_p,
            "frequency_penalty": self.args.frequency_penalty,
        }

    # The module imports ``abstractmethod`` but the original left these as
    # silent no-ops returning None; they are the provider interface and are
    # overridden by every subclass, so mark them abstract.
    @abstractmethod
    def get_payload(self, messages):
        """Build the provider-specific request body for *messages*."""

    @abstractmethod
    def create_prompt(self, text):
        """Wrap *text* into the provider's message format."""

    @abstractmethod
    def respond(self, texts):
        """Synchronously generate one answer per input text."""

    @abstractmethod
    def respond_async(self, texts):
        """Asynchronously generate one answer per input text."""
class OpenaiModel(APIModelBase):
    """OpenAI-compatible chat-completions backend (sync and async paths)."""

    def __init__(self, args):
        super().__init__(args)
        # Optional system message prepended to every prompt.
        self.system_prompt = self.args.system_prompt

    # NOTE: the original also overrode get_generation_params with a body
    # byte-identical to APIModelBase's; the redundant override is removed
    # and the inherited implementation is used instead.

    def get_payload(self, messages):
        """Build the ``chat.completions.create`` keyword arguments."""
        return {
            "model": self.args.model_openai,
            "messages": messages,
            "n": 1,
            **self.generation_params,
            "max_tokens": self.args.max_gen_length,
            "timeout": 10000,
            "seed": 0,  # fixed seed for reproducible sampling
        }

    def create_prompt(self, text):
        """Return OpenAI-style messages, prepending the system prompt if set."""
        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": text})
        return messages

    def respond(self, texts):
        """Generate answers sequentially, one request per text."""
        return [self.generate_answers(text) for text in texts]

    async def respond_async(self, texts):
        """Generate answers concurrently; result order matches *texts*."""
        tasks = [asyncio.create_task(self.generate_answers_async(t)) for t in texts]
        result = await asyncio.gather(*tasks)
        assert len(result) == len(texts)
        return result

    def generate_answers(self, text):
        """Run one synchronous completion request; return the answer text."""
        with openai.OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"), base_url=self.args.hostname
        ) as client:
            payload = self.get_payload(self.create_prompt(text))
            completion = client.chat.completions.create(**payload)
        return completion.choices[0].message.content

    async def generate_answers_async(self, text):
        """Run one asynchronous completion request; return the answer text."""
        async with openai.AsyncOpenAI(
            api_key=os.getenv("OPENAI_API_KEY"), base_url=self.args.hostname
        ) as client:
            payload = self.get_payload(self.create_prompt(text))
            completion = await client.chat.completions.create(**payload)
        return completion.choices[0].message.content
class GigachatModel(APIModelBase):
    """GigaChat backend; a fresh client is created per request."""

    def __init__(self, args):
        super().__init__(args)
        # The original duplicated the whole GigaChat(...) construction here;
        # reuse get_client() so the connection settings live in one place.
        # self.client is kept for backward compatibility, though the request
        # paths below open their own short-lived clients.
        self.client = self.get_client()
        self.system_prompt = None

    def get_generation_params(self):
        """Send our sampling settings only when explicitly requested;
        otherwise defer to the provider's defaults."""
        if self.args.overwrite_competitor_args:
            return {
                "temperature": self.args.temperature,
                "top_p": self.args.top_p,
                "frequency_penalty": self.args.frequency_penalty,
            }
        return {}

    def create_prompt(self, text):
        """Return GigaChat ``Messages``, prepending the system prompt if set."""
        from gigachat.models import Messages, MessagesRole

        messages = []
        if self.system_prompt is not None:
            messages.append(
                Messages(role=MessagesRole.SYSTEM, content=self.system_prompt)
            )
        messages.append(Messages(role=MessagesRole.USER, content=text))
        return messages

    def get_payload(self, messages):
        """Build the GigaChat chat payload."""
        return {
            "model": self.args.model_openai,
            "messages": messages,
            "max_tokens": self.args.max_gen_length,
            **self.generation_params,
        }

    def respond(self, texts):
        """Generate answers sequentially with a fresh client per request."""
        result = []
        for text in texts:
            with self.get_client() as client:
                payload = self.get_payload(self.create_prompt(text))
                result.append(client.chat(payload).choices[0].message.content)
        assert len(result) == len(texts)
        return result

    async def respond_async(self, texts):
        """Generate answers concurrently; result order matches *texts*."""
        tasks = [asyncio.create_task(self.generate_answers(t)) for t in texts]
        result = await asyncio.gather(*tasks)
        assert len(result) == len(texts)
        return result

    async def generate_answers(self, text):
        """Run one asynchronous GigaChat request; return the answer text."""
        async with self.get_client() as client:
            payload = self.get_payload(self.create_prompt(text))
            completion = await client.achat(payload)
        return completion.choices[0].message.content

    def get_client(self):
        """Create a fresh GigaChat client (self-signed certs tolerated)."""
        from gigachat import GigaChat

        return GigaChat(
            base_url=self.args.hostname,
            credentials=os.getenv("GIGACHAT_CREDENTIALS"),
            scope="GIGACHAT_API_PERS",
            verify_ssl_certs=False,
            timeout=100,
        )
class YandexGPTModel(APIModelBase):
    """YandexGPT backend: sync completions plus async operation polling."""

    def __init__(self, args):
        super().__init__(args)
        self.folder_id = os.getenv("YANDEX_FOLDER_ID")
        self.token = os.getenv("YANDEX_KEY")
        self.model_version = self.args.model_version or "rc"
        self.system_prompt = None

    def get_generation_params(self):
        """Send our sampling settings only when explicitly requested;
        otherwise defer to the provider's defaults."""
        if self.args.overwrite_competitor_args:
            return {
                "temperature": self.args.temperature,
                "top_p": self.args.top_p,
                "frequency_penalty": self.args.frequency_penalty,
            }
        return {}

    def get_payload(self, messages):
        """Build the YandexGPT completion payload."""
        model_uri = (
            f"gpt://{self.folder_id}/{self.args.model_openai}/{self.model_version}"
        )
        return {
            "modelUri": model_uri,
            "completionOptions": {
                "stream": False,
                # NOTE(review): Yandex docs spell this camelCase ("maxTokens",
                # like "modelUri" above) — confirm the API accepts "max_tokens".
                "max_tokens": self.args.max_gen_length,
                **self.generation_params,
            },
            "messages": messages,
        }

    def create_prompt(self, text):
        """Return Yandex-style messages ("text" key, not "content")."""
        messages = []
        if self.system_prompt is not None:
            messages.append({"role": "system", "text": self.system_prompt})
        messages.append({"role": "user", "text": text})
        return messages

    def _headers(self):
        """Common auth headers for the YandexGPT REST endpoints."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.token}",
            "x-folder-id": self.folder_id,
        }

    def respond(self, texts):
        """Synchronous completions; returns one answer (or None) per text."""
        result = []
        headers = self._headers()
        for text in texts:
            payload = self.get_payload(self.create_prompt(text))
            response = requests.post(self.args.hostname, headers=headers, json=payload)
            response.raise_for_status()
            answer = response.json()
            if "error" in answer:
                raise Exception(f"Operation failed: {answer['error']['message']}")
            alternatives = answer["result"]["alternatives"]
            # Append None instead of silently skipping (as the original did)
            # so results stay aligned one-to-one with the input texts.
            result.append(alternatives[0]["message"]["text"] if alternatives else None)
        return result

    async def respond_async(self, texts):
        """Generate answers concurrently; result order matches *texts*."""
        tasks = [asyncio.create_task(self.generate_answers(t)) for t in texts]
        result = await asyncio.gather(*tasks)
        assert len(result) == len(texts)
        return result

    async def generate_answers(self, text):
        """Start an async completion operation and poll until it finishes.

        Returns the answer text (or None when no alternatives come back).
        """
        operation_url = "https://operation.api.cloud.yandex.net/operations/"
        headers = self._headers()
        async with aiohttp.ClientSession() as session:
            payload = self.get_payload(self.create_prompt(text))
            async with session.post(
                self.args.hostname, headers=headers, json=payload
            ) as response:
                response.raise_for_status()
                operation = await response.json()
            operation_id = operation["id"]
            while True:
                async with session.get(
                    f"{operation_url}{operation_id}", headers=headers
                ) as response:
                    response.raise_for_status()
                    answer = await response.json()
                if answer["done"]:
                    break
                await asyncio.sleep(1)  # wait 1 second between polls
        if "error" in answer:
            raise Exception(f"Operation failed: {answer['error']['message']}")
        # Async operation results arrive under "response" (vs "result" on the
        # sync endpoint). Return a bare string for consistency with respond()
        # and the other backends; the original wrapped it in a one-item list,
        # making respond_async yield a list of lists.
        alternatives = answer["response"]["alternatives"]
        return alternatives[0]["message"]["text"] if alternatives else None
    