# CAIA-evaluate / utils.py
# Zhejian
# bugfix
# 0778ffc
import pandas as pd
import tiktoken
from typing import List, Optional
from email._parseaddr import AddressList as _AddressList
from schemas import BenchmarkItem, EvaluateData, EvaluateItem
from datasets import DatasetDict
def truncate_text(text: str, model: str = "gpt-4.1", max_tokens: Optional[int] = None) -> str:
    """
    Truncate text to at most ``max_tokens`` tokens using tiktoken.

    Args:
        text: Text to be truncated.
        model: Model name whose tokenizer is used, defaults to "gpt-4.1".
            Falls back to the ``cl100k_base`` encoding when tiktoken has no
            encoding registered for the model name.
        max_tokens: Maximum token count. ``None`` disables truncation;
            ``0`` yields an empty string.

    Returns:
        The original text if it already fits within ``max_tokens`` tokens,
        otherwise the decoded prefix made of its first ``max_tokens`` tokens.

    Raises:
        ValueError: If ``max_tokens`` is negative.
    """
    # Only None means "no limit"; 0 is a real (empty) limit.
    if max_tokens is None:
        return text
    if max_tokens < 0:
        # A negative slice would silently drop tokens from the end instead.
        raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")
    if max_tokens == 0:
        return ""
    if not text:
        return text
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # No encoding is registered for this model name; fall back to cl100k_base.
        encoding = tiktoken.get_encoding("cl100k_base")
    # disallowed_special=() encodes special-token strings such as "<|endoftext|>"
    # as ordinary text instead of raising ValueError on arbitrary input.
    tokens = encoding.encode(text, disallowed_special=())
    if len(tokens) <= max_tokens:
        return text
    # The cut can fall inside a multi-byte UTF-8 character; "ignore" drops that
    # incomplete trailing fragment rather than emitting a U+FFFD replacement char.
    return encoding.decode(tokens[:max_tokens], errors="ignore")
def count_tokens(text: str, model: str = "gpt-4.1") -> int:
    """
    Count the number of tokens in a text using tiktoken.

    Args:
        text: Text to count tokens for.
        model: Model name whose tokenizer is used, defaults to "gpt-4.1".
            Falls back to the ``cl100k_base`` encoding when tiktoken has no
            encoding registered for the model name.

    Returns:
        Number of tokens in the text (0 for an empty string).
    """
    if not text:
        # Nothing to encode; skip the encoder lookup entirely.
        return 0
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # No encoding is registered for this model name; fall back to cl100k_base.
        encoding = tiktoken.get_encoding("cl100k_base")
    # disallowed_special=() counts special-token strings such as "<|endoftext|>"
    # as ordinary text instead of raising ValueError on arbitrary input.
    return len(encoding.encode(text, disallowed_special=()))
def parseaddr(addr):
    """
    Split an address string into a (realname, email_address) pair.

    Only the first address found in *addr* is returned. If nothing can be
    parsed, the 2-tuple ('', '') is returned instead.
    """
    parsed = _AddressList(addr).addresslist
    return parsed[0] if parsed else ('', '')
def parse_eval_dataset(dataset:DatasetDict) -> List[BenchmarkItem]:
    """
    Convert the 'train' split of *dataset* into a list of BenchmarkItem objects.

    Each row must provide 'task_id', 'question', 'category', 'level' and an
    'evaluate' mapping whose 'items' entry is a sequence of mappings; each of
    those mappings is passed as keyword arguments to EvaluateItem.
    Items are returned in row order.
    """
    frame = pd.DataFrame(dataset['train'])

    def _to_benchmark_item(record) -> BenchmarkItem:
        # Build one BenchmarkItem from a single DataFrame row.
        return BenchmarkItem(
            task_id=record['task_id'],
            question=record['question'],
            evaluate=EvaluateData(
                items=[EvaluateItem(**entry) for entry in record['evaluate']['items']]
            ),
            category=record['category'],
            level=record['level'],
        )

    return [_to_benchmark_item(record) for _, record in frame.iterrows()]