ru_leaderboard / src /evaluate /generate_answers.py
Titova Ksenia
divide results
8d664ba
import argparse
import asyncio
import json
import uuid
import os
from pathlib import PurePath
from tqdm import tqdm
from src.evaluate.util import APIModelBase, download_dataset_hf
SYSTEM_PROMPT = "Ты — ИИ-помощник. Тебе дано задание: необходимо сгенерировать подробный и развернутый ответ."
TEMPERATURE = 0.0
TOP_P = 0.1
FREQUENCY_PENALTY = 1.2
def load_data(dataset_name, from_local=False, split=None):
if not from_local:
if dataset_name is not None:
data = download_dataset_hf(dataset_name, split=split)
else:
with open(PurePath(dataset_name), "r") as f:
data = [json.loads(line) for line in f]
return data
def write_response_jsonl(response_text, counter, question, model_name, output_filename):
message_id = str(uuid.uuid4())
cur_dict = {
"question_id": question["question_id"][counter],
"cluster": question["cluster"][counter],
"turns": question["turns"][counter],
}
cur_dict["replies"] = {
"message_id": message_id,
"text": response_text,
"model_name": model_name,
}
with open(output_filename, "a") as f:
json.dump(cur_dict, f, ensure_ascii=False)
f.write("\n")
def main(
model: APIModelBase,
output_filename: str,
model_name: str,
dataset_name: str = None,
from_local: bool = False,
dataset_split: str = None,
chunk_size: int = 1,
start_from_chunk=0,
sync: bool = False,
) -> None:
data = load_data(
dataset_name=dataset_name,
from_local=from_local,
split=dataset_split,
)
chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
steps = 0
bar = tqdm(total=len(data) - (start_from_chunk * chunk_size))
for chunk in chunks[start_from_chunk:]:
texts = [dialog["content"] for dialog in chunk["turns"]]
if chunk_size > 1 and not sync:
result = asyncio.run(model.respond_async(texts))
else:
result = model.respond(texts)
steps += len(result)
bar.update(len(result))
for num, response_text in enumerate(result):
write_response_jsonl(
response_text=response_text,
counter=num,
question=chunk,
model_name=model_name,
output_filename=output_filename,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--hostname",
type=str,
default="http://0.0.0.0:6969/api",
help="Хост, где запущена модель",
)
parser.add_argument(
"--output-filename",
type=str,
default=None,
help="Файл в таком же формате, что и входящий, но с добавленными ответам модели",
)
parser.add_argument(
"--model-name",
type=str,
default="gpt-3.5-turbo",
help="Название модели для записи в файл",
)
parser.add_argument(
"--model-openai",
type=str,
default=None,
help="Название модели для доступа по OpenAI API",
)
parser.add_argument(
"--dataset-name",
type=str,
default="Vikhrmodels/arena_hard_ru",
help="Название датасета для файла input_filename из clearml.`",
)
parser.add_argument(
"--dataset-split",
type=str,
default="train",
help="Сплит датасета",
)
parser.add_argument("--model-version", type=str, default=None, help="Версия модели")
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
parser.add_argument("--top-p", type=float, default=TOP_P)
parser.add_argument(
"--frequency-penalty",
type=float,
default=FREQUENCY_PENALTY,
help=(
"Если сервер запущен на flask, то этот параметр на самом деле repetition_penalty. "
"Если сервер запущен на vllm, то это OPENAI frequency_penalty"
),
)
parser.add_argument(
"--overwrite-competitor-args",
default=False,
action=argparse.BooleanOptionalAction,
help="Передавать ли в модели конкурентов заданные параметры генерации вместо дефолтных",
)
parser.add_argument(
"--from-local",
action=argparse.BooleanOptionalAction,
default=False,
help=(
"True, если input-filename лежит локально, в противном "
"случае файл скачается из clearml, но необходимо указать dataset-name."
),
)
parser.add_argument(
"--chunk-size", type=int, default=1, help="Размер батча для генерации"
)
parser.add_argument(
"--sync",
default=False,
action=argparse.BooleanOptionalAction,
)
parser.add_argument("--save-steps", type=int, default=16)
parser.add_argument(
"--max-gen-length",
type=int,
default=4096,
help="Максимальная длина генерируемого текста",
)
parser.add_argument(
"--system-prompt", type=str, default=SYSTEM_PROMPT, help="Системный промпт"
)
parser.add_argument(
"--start-from-chunk",
type=int,
default=0,
help="С какого батча начинать генерацию (если предыдущая попытка упала)",
)
args = parser.parse_args()
if not os.getenv("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY is not set")
if not args.model_openai:
args.model_openai = args.model_name
if "yandex" in args.model_name.lower():
from src.evaluate.util import YandexGPTModel
args.system_prompt = None
model = YandexGPTModel(args)
elif "gigachat" in args.model_name.lower():
from src.evaluate.util import GigachatModel
args.system_prompt = None
if args.model_version:
if args.model_version not in args.model_openai:
args.model_openai = f"{args.model_openai}:{args.model_version}"
if args.model_version not in args.model_name:
args.model_name = f"{args.model_name}:{args.model_version}"
model = GigachatModel(args)
else:
from src.evaluate.util import OpenaiModel
model = OpenaiModel(args)
if not args.output_filename:
os.makedirs("./data/generations/", exist_ok=True)
dataset_name = args.dataset_name.split("/")[-1]
args.output_filename = f"./data/generations/{args.model_name}_{dataset_name}_responses.jsonl"
main(
model=model,
output_filename=args.output_filename,
model_name=args.model_name,
from_local=args.from_local,
dataset_name=args.dataset_name,
dataset_split=args.dataset_split,
chunk_size=args.chunk_size,
start_from_chunk=args.start_from_chunk,
sync=args.sync
)